SmartContractAudit / server /tasks /task2 /environment.py
ajaxwin
refactor: Task3 reward model changed, agent adjusted for new model
48661cd
"""
environment.py (Task 2 – Property Discovery)
----------------------------------------------
OpenEnv-compliant RL environment.
Episode setup:
- One function from a Solidity contract that has a known property.
- The agent sees: contract description + function name + function signature.
- The agent must discover the natural-language property of the function.
Actions & rewards:
get_function_code -0.06 (always positive topic context)
get_function_natspec -0.08 (strongest hint — natspec has param/return docs)
get_file_natspec -0.03 (broad contract-level context)
get_related_functions -0.06 (shows callers/callees)
get_io -0.04 (structured input/output description)
get_similar_rule -0.20 (shows a similar property from another contract)
submit_property scored 0–5 (ONE attempt, ends episode)
repeated_query -0.40
Episode ends when:
- submit_property is called (scored), OR
- max_steps is reached without submission (reward = -1.0)
"""
from __future__ import annotations
import random
from typing import Any, Dict, List, Optional, Set
from math import log2, floor
from data.data_loader import load_contracts, sample_property_episode
from env.base_env import BaseEnv
from env.schemas import (
Action,
ActionType,
Observation,
Reward,
ResetResult,
StateResult,
StepResult,
)
from .grader import Task2Grader
from server.tasks.task2 import actions
# Identifier for this task; reported in ResetResult.info, StateResult and
# every Observation built by Task2Environment.
TASK_ID = "task2_property_discovery"
# Action types this task declares as available. Each entry has a matching
# handler in Task2Environment._dispatch.
# NOTE(review): the module docstring lists a `get_io` action, while this list
# (and the dispatch table) use GET_SIGNATURE — confirm which name is current.
AVAILABLE_ACTIONS = [
ActionType.GET_FUNCTION_CODE,
ActionType.GET_FUNCTION_NATSPEC,
ActionType.GET_FILE_NATSPEC,
ActionType.GET_RELATED_FUNCTIONS,
ActionType.GET_SIGNATURE,
ActionType.GET_SIMILAR_RULE,
ActionType.SUBMIT_PROPERTY,
]
class Task2Environment(BaseEnv):
    """Task 2: Property Discovery.

    Each episode samples one contract and one target function from it. The
    agent issues actions through :meth:`step`; every action is routed by
    :meth:`_dispatch` to a handler in the ``actions`` module, which returns
    the text shown to the agent and the reward for that step.
    """

    def __init__(self, contracts_path: Optional[str] = None) -> None:
        """Load the contract dataset and initialise empty episode state.

        Args:
            contracts_path: Optional path passed to ``load_contracts``. When
                falsy (``None`` or ``""``), ``load_contracts()`` is called
                with no argument so it uses its own default.
        """
        self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts()
        self._rng = random.Random()
        self._max_steps: int = 40

        # Episode state – initialised by reset()
        self._contract: Dict[str, Any] = {}
        self._target_fn: Dict[str, Any] = {}
        self._grader: Optional[Task2Grader] = None
        self._step_count: int = 0
        self._cum_reward: float = 0.0
        self._done: bool = False
        self._query_hist: List[str] = []
        self._seen: Set[str] = set()

    # ── OpenEnv interface ────────────────────────────────────────────────────

    def reset(self, seed: Optional[int] = None) -> ResetResult:
        """Start a new episode and return its initial observation.

        Args:
            seed: When not ``None``, re-seeds the environment's private RNG
                before sampling, making the sampled episode reproducible.

        Returns:
            A ``ResetResult`` carrying the first observation and an ``info``
            dict with the task id.
        """
        if seed is not None:
            self._rng.seed(seed)

        self._contract, self._target_fn = sample_property_episode(
            self._contracts, self._rng
        )

        # The grader's `n` is floor(log2(number of functions in the contract)).
        # log2 raises ValueError for 0, so the sampled contract must list at
        # least one function.
        self._grader = Task2Grader(
            function_name=self._target_fn["name"],
            property=self._target_fn["property"],
            n=floor(log2(len(self._contract["functions"]))),
        )

        # Clear all per-episode bookkeeping.
        self._step_count = 0
        self._cum_reward = 0.0
        self._done = False
        self._query_hist = []
        self._seen = set()

        obs = self._build_obs(
            last_action=None,
            last_result=(
                f"New episode started.\n"
                f"Contract : {self._contract['contract_name']}\n"
                f"Function : {self._target_fn['name']} "
                f"({self._target_fn.get('signature', '')})\n"
                f"Your task : Discover the natural-language property of "
                f"'{self._target_fn['name']}' and submit it with submit_property action."
            ),
        )
        return ResetResult(observation=obs, info={"task_id": TASK_ID})

    def step(self, action: Action) -> StepResult:
        """Apply one agent action and return the resulting transition.

        Args:
            action: The action to execute; its ``action_type`` selects the
                handler and its ``params`` are forwarded to it.

        Returns:
            A ``StepResult`` with the new observation, the step reward, the
            done flag, and an ``info`` dict holding the step index and the
            cumulative reward.

        Raises:
            RuntimeError: If the episode is already done, or if
                ``_max_steps`` steps have already been taken this episode.
        """
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")
        # `>=` (not `>`): the counter is incremented after this check, so a
        # strict comparison would let one step more than `_max_steps` through.
        if self._step_count >= self._max_steps:
            raise RuntimeError("Exceeded maximum number of steps allowed. Call reset() to start a new episode.")
        # NOTE(review): the module docstring says the episode ends with
        # reward -1.0 when max_steps is reached without a submission; nothing
        # in this class applies that penalty — confirm whether the handlers in
        # `actions` are responsible for it.

        self._step_count += 1
        result_text, reward = self._dispatch(action)
        self._cum_reward += reward.value

        # Keep a short trace of each query; the result is truncated to 100 chars.
        self._query_hist.append(f"[{action.action_type}] → {result_text[:100]}")

        obs = self._build_obs(
            last_action=action.action_type,
            last_result=result_text,
        )
        return StepResult(
            observation=obs,
            reward=reward,
            done=self._done,
            info={
                "step": self._step_count,
                "cumulative_reward": self._cum_reward,
            },
        )

    def state(self) -> StateResult:
        """Return a snapshot of the current episode state.

        The query history is copied, so callers cannot mutate the
        environment's own list through the returned object.
        """
        return StateResult(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            target_function=self._target_fn.get("name", ""),
            step_count=self._step_count,
            cumulative_reward=self._cum_reward,
            done=self._done,
            query_history=list(self._query_hist),
        )

    # ── Internal helpers ─────────────────────────────────────────────────────

    def _build_obs(self, last_action: Optional[str], last_result: str) -> Observation:
        """Assemble the observation handed back to the agent.

        Args:
            last_action: The action type just executed, or ``None`` for the
                initial observation produced by :meth:`reset`.
            last_result: Text result of that action (or the episode banner).

        Returns:
            An ``Observation`` whose ``extra`` dict carries the target
            function name and signature, the contract's Solidity version
            (empty string when absent) and a fixed usage hint.
        """
        return Observation(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            last_action=last_action,
            last_action_result=last_result,
            done=self._done,
            extra={
                "target_function": self._target_fn.get("name", ""),
                "target_signature": self._target_fn.get("signature", ""),
                "solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""),
                "hint": (
                    "Discover the property of the target function. "
                    "Use get_function_code, get_function_natspec, or get_similar_rule for hints. "
                    "Submit with submit_property, params={'property': '<your property text>'}. "
                    "ONE submission attempt only."
                ),
            },
        )

    def _qkey(self, at: str, params: Dict[str, Any]) -> str:
        """Build a key identifying an (action type, params) query.

        Parameter items are sorted so the key does not depend on the
        insertion order of ``params``.
        """
        return f"{at}:{sorted(params.items())}"

    def _is_repeated(self, key: str) -> bool:
        """Report whether ``key`` was seen before in this episode.

        A key that has not been seen is recorded as a side effect, so the
        next call with the same key returns ``True``. Not called within this
        class; presumably used by the action handlers, which receive the
        environment instance — verify against ``actions``.
        """
        if key in self._seen:
            return True
        self._seen.add(key)
        return False

    def _dispatch(self, action: Action) -> tuple[str, Reward]:
        """Route ``action`` to its handler in the ``actions`` module.

        Every handler is called as ``handler(self, qkey, params)``. Action
        types without an entry in the table are passed to
        ``actions.unknown_action``, which additionally receives the action
        type.

        Returns:
            The ``(result_text, reward)`` pair produced by the handler.
        """
        at = action.action_type
        params = action.params
        qkey = self._qkey(at, params)

        handlers = {
            ActionType.GET_FUNCTION_CODE: actions.get_function_code,
            ActionType.GET_FUNCTION_NATSPEC: actions.get_function_natspec,
            ActionType.GET_FILE_NATSPEC: actions.get_file_natspec,
            ActionType.GET_RELATED_FUNCTIONS: actions.get_related_functions_action,
            ActionType.GET_SIGNATURE: actions.get_signature,
            ActionType.GET_SIMILAR_RULE: actions.get_similar_rule_action,
            ActionType.SUBMIT_PROPERTY: actions.submit_property,
        }

        handler = handlers.get(at)
        if handler is None:
            return actions.unknown_action(self, qkey, params, at)
        return handler(self, qkey, params)