"""
environment.py (Task 3 — Rule Checker)
-----------------------------------------
OpenEnv-compliant RL environment.
Episode setup
─────────────
- A Solidity contract is selected that contains at least one function
violating a known property.
- The agent sees: contract description + the property in natural English.
- The agent must identify which function breaks that property.
Observation at reset
────────────────────
  extra.property_english — the violated property in plain English
  extra.hint             — instructions for the agent
Actions & rewards
─────────────────
  list_functions              -0.05   see all function names
  get_function_metadata       -0.05   signature / visibility / modifiers / params
  get_function_code           -0.10   full Solidity source of any function
  get_state_variables         -0.05   list or inspect state variables
  get_call_graph              -0.08   function call graph
  get_property_specification  -0.03   formal pre/post-condition version of property
  submit_function             terminal: +5.0 / +1.5 / -1.5 (ONE attempt)
  repeated_query              -0.40
Difficulty: Easy
The property text directly names the invariant broken; reading 2-3 functions
should let most agents identify the culprit quickly.
"""
from __future__ import annotations
import random
from typing import Any, Dict, List, Optional, Set
from data.data_loader import load_contracts, sample_task3_episode
from env.base_env import BaseEnv
from env.schemas import (
Action,
ActionType,
Observation,
Reward,
ResetResult,
StateResult,
StepResult,
)
from .grader import Task3Grader
from server.tasks.task3 import actions
# Identifier this environment reports for the task (used in observations,
# reset info and state snapshots).
TASK_ID = "task3_rule_checker"

# Action types this task declares: six investigation queries followed by the
# single terminal submission.
AVAILABLE_ACTIONS = [
    # Investigation queries.
    ActionType.LIST_FUNCTIONS,
    ActionType.GET_FUNCTION_METADATA,
    ActionType.GET_FUNCTION_CODE,
    ActionType.GET_STATE_VARIABLE,
    ActionType.GET_CALL_GRAPH,
    ActionType.GET_PROPERTY_SPECIFICATION,
    # Terminal submission.
    ActionType.SUBMIT_FUNCTION,
]
class Task3Environment(BaseEnv):
    """Task 3: Rule Checker — identify the function that violates a given property.

    Each episode pairs one contract with one target function sampled by
    ``sample_task3_episode``. The agent investigates through query actions and
    finishes with ``submit_function``. The per-action logic lives in
    ``server.tasks.task3.actions``; this class owns the episode state and the
    reset / step / state interface.
    """

    def __init__(self, contracts_path: Optional[str] = None) -> None:
        """Load the contract corpus and create empty episode state.

        Args:
            contracts_path: Optional path forwarded to ``load_contracts``.
                When falsy, ``load_contracts()`` is called without arguments.
        """
        self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts()
        self._rng = random.Random()
        self._max_steps = 20

        # Episode state — initialised by reset()
        self._contract: Dict[str, Any] = {}
        self._target_fn: Dict[str, Any] = {}
        self._grader: Optional[Task3Grader] = None
        self._step_count: int = 0
        self._cum_reward: float = 0.0
        self._done: bool = False
        self._query_hist: List[str] = []
        self._seen: Set[str] = set()

    # ── OpenEnv interface ────────────────────────────────────────────────────

    def reset(self, seed: Optional[int] = None) -> ResetResult:
        """Start a new episode and return its opening observation.

        Args:
            seed: When not ``None``, re-seeds the environment's private RNG
                before the episode is sampled.

        Returns:
            A ``ResetResult`` holding the introductory observation and an
            ``info`` dict containing the task id.
        """
        if seed is not None:
            self._rng.seed(seed)

        self._contract, self._target_fn = sample_task3_episode(
            self._contracts, self._rng
        )
        self._grader = Task3Grader(
            target_function=self._target_fn,
            property_specification=self._target_fn.get("property_specification", ""),
            max_steps=self._max_steps,
        )

        # Clear all per-episode bookkeeping.
        self._step_count = 0
        self._cum_reward = 0.0
        self._done = False
        self._query_hist = []
        self._seen = set()

        obs = self._build_obs(
            last_action=None,
            last_result=(
                f"New episode started.\n"
                f"Contract : {self._contract['contract_name']}\n\n"
                f"Property : {self._target_fn.get('property', '')}\n\n"
                f"Find the function in this contract that violates the property above.\n"
                f"Use list_functions then get_function_code to investigate.\n"
                f"Submit with submit_function, params={{\"function_name\": \"...\"}}.\n"
                f"Only ONE submission allowed."
            ),
        )
        return ResetResult(observation=obs, info={"task_id": TASK_ID})

    def step(self, action: Action) -> StepResult:
        """Apply one action and return the resulting transition.

        Args:
            action: The action to dispatch; its ``action_type`` and ``params``
                are read.

        Returns:
            A ``StepResult`` with the new observation, the reward for this
            action, the done flag, and an ``info`` dict carrying the step
            index and the cumulative reward.

        Raises:
            RuntimeError: If the episode is already done, or if the step
                budget (``_max_steps``) has been used up.
        """
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")
        # The counter is incremented *after* this guard, so the comparison must
        # be ``>=``: with ``>`` a (max_steps + 1)-th action would still run.
        if self._step_count >= self._max_steps:
            raise RuntimeError("Exceeded maximum number of steps allowed. Call reset() to start a new episode.")

        self._step_count += 1
        result_text, reward = self._dispatch(action)
        self._cum_reward += reward.value
        # Keep a short log line per action; the result text is capped at 100 chars.
        # NOTE(review): the separator character in this string looks
        # mis-encoded — confirm the intended glyph before changing it.
        self._query_hist.append(f"[{action.action_type}] β {result_text[:100]}")

        obs = self._build_obs(
            last_action=action.action_type,
            last_result=result_text,
        )
        return StepResult(
            observation=obs,
            reward=reward,
            done=self._done,
            info={"step": self._step_count, "cumulative_reward": self._cum_reward},
        )

    def state(self) -> StateResult:
        """Return a snapshot of the current episode state.

        The snapshot includes the target function name, and the query history
        is copied so callers cannot mutate the internal list.
        """
        return StateResult(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            target_function=self._target_fn.get("name", ""),
            step_count=self._step_count,
            cumulative_reward=self._cum_reward,
            done=self._done,
            query_history=list(self._query_hist),
        )

    # ── Internal helpers ─────────────────────────────────────────────────────

    def _build_obs(self, last_action: Optional[str], last_result: str) -> Observation:
        """Assemble an ``Observation`` from the current episode state.

        Args:
            last_action: Action type of the step just taken, or ``None`` for
                the observation produced by ``reset``.
            last_result: Text shown to the agent as the outcome of that action.
        """
        return Observation(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            last_action=last_action,
            last_action_result=last_result,
            done=self._done,
            extra={
                "property_english": self._target_fn.get("property", ""),
                "solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""),
                "hint": (
                    "Read the property, then inspect function code to find which one violates it. "
                    "Submit with: submit_function, params={'function_name': '<name>'}. "
                    "ONE submission per episode."
                ),
            },
        )

    def _qkey(self, at: str, params: Dict[str, Any]) -> str:
        """Build a key for an (action type, params) pair.

        Parameter items are sorted so the key does not depend on dict
        insertion order.
        """
        return f"{at}:{sorted(params.items())}"

    def _is_repeated(self, key: str) -> bool:
        """Report whether ``key`` was already seen this episode.

        A key that has not been seen is recorded as a side effect, so the
        next call with the same key returns ``True``.
        """
        if key in self._seen:
            return True
        self._seen.add(key)
        return False

    def _dispatch(self, action: Action) -> tuple[str, Reward]:
        """Route ``action`` to its handler in the ``actions`` module.

        Handlers are called as ``handler(env, qkey, params)``. An action type
        without a registered handler goes to ``actions.unknown_action``, which
        also receives the action type.

        Returns:
            The ``(result_text, reward)`` pair produced by the handler.
        """
        at = action.action_type
        params = action.params
        qkey = self._qkey(at, params)

        # Mapping from ActionType to handler function
        handlers = {
            ActionType.LIST_FUNCTIONS: actions.list_functions,
            ActionType.GET_FUNCTION_METADATA: actions.get_function_metadata,
            ActionType.GET_FUNCTION_CODE: actions.get_function_code,
            ActionType.GET_STATE_VARIABLE: actions.get_state_variable,
            ActionType.GET_CALL_GRAPH: actions.get_call_graph,
            ActionType.GET_PROPERTY_SPECIFICATION: actions.get_property_specification,
            ActionType.SUBMIT_FUNCTION: actions.submit_function,
        }

        handler = handlers.get(at)
        if handler is None:
            return actions.unknown_action(self, qkey, params, at)
        return handler(self, qkey, params)