| from __future__ import annotations |
| import abc |
| import inspect |
| import random |
| import re |
| from dataclasses import dataclass |
| from typing import Any, Callable, Dict, List, Tuple |
|
|
|
|
| |
| _INDEX_RE = re.compile(r'^(.*?)\[(.*?)\]$') |
|
|
|
|
| |
| |
| |
class OptimizableField:
    """Named handle on one live runtime value, accessed through callbacks."""

    def __init__(self,
                 name: str,
                 getter: Callable[[], Any],
                 setter: Callable[[Any], None]):
        # ``name`` identifies the field; the two callables perform the actual
        # read and write of the underlying value.
        self.name = name
        self._get = getter
        self._set = setter

    def get(self) -> Any:
        """Return whatever the getter callback currently yields."""
        current = self._get()
        return current

    def set(self, value: Any) -> None:
        """Pass *value* to the setter callback."""
        self._set(value)
|
|
|
|
class PromptRegistry:
    """Central registry for all runtime-patchable fields.

    Fields are kept by name in ``self.fields``.  They can be added either as
    ready-made field objects (``register_field``) or by pointing at a live
    value with a root object plus a path string (``register_path``).
    """

    def __init__(self) -> None:
        self.fields: Dict[str, OptimizableField] = {}

    def register_field(self, field: OptimizableField):
        """Store *field* under ``field.name``, replacing any earlier entry."""
        self.fields[field.name] = field

    def get(self, name: str) -> Any:
        """Return the current value of field *name* (``KeyError`` if unknown)."""
        return self.fields[name].get()

    def set(self, name: str, value: Any):
        """Write *value* through field *name* (``KeyError`` if unknown)."""
        self.fields[name].set(value)

    def names(self) -> List[str]:
        """Return the registered field names in insertion order."""
        return list(self.fields)

    def register_path(self, root: Any, path: str, *, name: str | None = None):
        """Register the value addressed by *path* below *root* in one call.

        *path* is a dotted string such as ``'encoder.layers[3].dropout_p'``.
        Each segment is an attribute name, optionally followed by a single
        subscript (``[3]``, ``[-1]``, ``['key']``).

        Parameters
        ----------
        root : Any
            Object the path starts from.
        path : str
            Dotted path to the value.
        name : str | None
            Registry key; defaults to the last dotted segment of *path*.

        Returns
        -------
        OptimizableField
            The field that was created and registered.

        Notes
        -----
        The container holding the value is resolved once, at registration
        time; the getter and setter keep using that same object afterwards.
        """
        key = name or path.split(".")[-1]
        parent, leaf, subscripted = self._resolve(root, path)

        # A leaf written as a subscript is always accessed by item, whatever
        # the container type (tuple, deque, custom ``__getitem__`` objects...).
        # A plain leaf name keeps the historical rule: item access on list and
        # dict parents, attribute access on everything else.
        by_item = subscripted or isinstance(parent, (list, dict))

        def getter():
            return parent[leaf] if by_item else getattr(parent, leaf)

        def setter(v):
            if by_item:
                parent[leaf] = v
            else:
                setattr(parent, leaf, v)

        field = OptimizableField(key, getter, setter)
        self.register_field(field)
        return field

    def _walk(self, root, path: str, create_missing=False):
        """Resolve *path* and return ``(parent, leaf)``.

        ``parent`` is the object holding the final value and ``leaf`` is the
        attribute name, index or key to use on it.  ``create_missing`` is
        accepted for interface compatibility and is currently unused.
        """
        parent, leaf, _ = self._resolve(root, path)
        return parent, leaf

    def _resolve(self, root, path: str) -> Tuple[Any, Any, bool]:
        """Return ``(parent, leaf, subscripted)`` for *path* below *root*."""
        *trail, last = path.split(".")
        cur = root
        for segment in trail:
            attr, index, has_index = self._split_segment(segment)
            if has_index:
                if attr:
                    cur = getattr(cur, attr)
                cur = cur[index]
            else:
                cur = getattr(cur, segment)

        attr, index, has_index = self._split_segment(last)
        if not has_index:
            return cur, last, False
        if attr:
            cur = getattr(cur, attr)
        return cur, index, True

    @staticmethod
    def _split_segment(segment: str) -> Tuple[str, Any, bool]:
        """Split one path segment into ``(attr, index, has_index)``.

        ``'layers[3]'`` gives ``('layers', 3, True)``, ``"tbl['k']"`` gives
        ``('tbl', 'k', True)`` and a plain ``'name'`` gives
        ``('name', None, False)``.
        """
        if not (segment.endswith("]") and "[" in segment):
            return segment, None, False

        # Split at the FIRST '[', so everything up to the closing bracket is
        # the raw subscript text.
        attr, _, raw = segment[:-1].partition("[")
        raw = raw.strip()

        quote = raw[:1]
        if quote in ("'", '"') and raw.endswith(quote):
            # Quoted subscript: always a string key.
            return attr, raw[1:-1], True

        digits = raw[1:] if raw.startswith("-") else raw
        if digits.isdecimal():
            # Integer subscript, optionally negative (e.g. ``layers[-1]``).
            return attr, int(raw), True

        # Anything else is used verbatim as a string key.
        return attr, raw, True
|
|
|
|
| |
| |
| |
| |
class CodeBlock:
    """Named wrapper around a plain synchronous callable.

    Parameters
    ----------
    name : str
        Logical name (handy in logs and while debugging).
    func : Callable[[dict], Any]
        Ordinary synchronous function that takes the cfg dict.
    """

    def __init__(self, name: str, func: Callable[[Dict[str, Any]], Any]):
        self.name = name
        self._func = func

    def run(self, cfg: Dict[str, Any]) -> Any:
        """Synchronously execute the wrapped function with *cfg*."""
        wrapped = self._func
        outcome = wrapped(cfg)
        return outcome

    def __call__(self, cfg: Dict[str, Any]) -> Any:
        # Calling the block directly is shorthand for ``run``.
        return self.run(cfg)

    def __repr__(self):
        return f"<CodeBlock {self.name} (sync)>"
|
|
|
|
|
|
|
|
| |
| |
| |
class BaseCodeBlockOptimizer(abc.ABC):
    """
    Abstract optimiser that:
      * performs sequential trials
      * writes sampled cfg back to runtime via PromptRegistry
      * validates that registered names appear in CodeBlock signature

    Subclasses supply ``sample_cfg`` (what to try next) and ``update`` (how
    to learn from a scored trial); ``run`` drives the trial loop.
    """

    def __init__(self,
                 registry: PromptRegistry,
                 metric: str,
                 maximize: bool = True,
                 max_trials: int = 30):
        self.registry = registry
        self.metric = metric          # stored only; not read by this class
        self.maximize = maximize      # True: higher score wins; False: lower
        self.max_trials = max_trials  # number of iterations of the trial loop

    @abc.abstractmethod
    def sample_cfg(self) -> Dict[str, Any]:
        """Return a cfg dict (may include subset of registry names)."""

    @abc.abstractmethod
    def update(self, cfg: Dict[str, Any], score: float):
        """Update internal optimiser state."""

    def _apply_cfg(self, cfg: Dict[str, Any]):
        """Write every cfg entry whose key is a registered field to the registry.

        Keys that are not registered are skipped without error.
        """
        for k, v in cfg.items():
            if k in self.registry.fields:
                self.registry.set(k, v)

    def _check_codeblock_compat(self, code_block: CodeBlock):
        """Warn when registry names cannot be matched to the wrapped function.

        The check is advisory: it never raises, and it is skipped when the
        function's signature cannot be introspected, when the function takes
        ``**kwargs``, or when it has a parameter named ``cfg``.

        NOTE(review): the wrapped function appears to be called with the cfg
        dict as its single positional argument, so a one-parameter function
        whose parameter is not named ``cfg`` will still trigger the warning.
        Confirm that is intended.
        """
        try:
            sig = inspect.signature(code_block._func)
        except (TypeError, ValueError):
            # Some callables (certain builtins / extension objects) expose no
            # signature; an advisory check must not abort the optimisation.
            return

        params = sig.parameters.values()
        has_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params)
        accepts_cfg_dict = "cfg" in sig.parameters

        if has_kwargs or accepts_cfg_dict:
            return

        allowed_keys = set(sig.parameters)
        unknown = set(self.registry.names()) - allowed_keys
        if unknown:
            import warnings
            warnings.warn(f"PromptRegistry fields {unknown} are not present in "
                          f"{code_block.name}() signature; they will be ignored.")

    def run(self,
            code_block: CodeBlock,
            evaluator: Callable[[Dict[str, Any], Any], float]
            ) -> Tuple[Dict[str, Any] | None, List[Tuple[Dict[str, Any], float]]]:
        """Run ``max_trials`` sequential trials and return the best cfg.

        Each trial samples a cfg, applies it through the registry, runs
        *code_block* with it, scores the result with ``evaluator(cfg, result)``
        and feeds ``(cfg, score)`` to ``update``.

        Returns
        -------
        (best_cfg, history)
            ``history`` lists ``(cfg, score)`` for every trial in order.
            ``best_cfg`` is the cfg with the highest (or lowest, when
            ``maximize`` is false) score; it is ``None`` when no trial beat
            the initial infinite bound, e.g. when ``max_trials`` is 0.
        """
        self._check_codeblock_compat(code_block)

        best_cfg = None
        best_score = -float("inf") if self.maximize else float("inf")
        history: List[Tuple[Dict[str, Any], float]] = []

        for _ in range(self.max_trials):
            cfg = self.sample_cfg()
            self._apply_cfg(cfg)
            result = code_block.run(cfg)
            score = evaluator(cfg, result)
            self.update(cfg, score)

            history.append((cfg, score))
            better = score > best_score if self.maximize else score < best_score
            if better:
                best_cfg, best_score = cfg, score

        return best_cfg, history
|
|
|
|
|
|
| |
| |
| |
def bind_cfg(obj: Any, cfg: Dict[str, Any]) -> None:
    """Assign every *cfg* value onto *obj*, following dotted keys.

    A key such as ``"a.b.c"`` is resolved by reading ``obj.a`` and then
    ``.b`` with ``getattr`` before doing ``setattr(..., "c", value)``; a key
    without dots is set directly on *obj*.
    """
    for dotted, value in cfg.items():
        # Everything before the last dot names the chain of owners to follow.
        *owners, attr = dotted.split(".")
        target = obj
        for owner in owners:
            target = getattr(target, owner)
        setattr(target, attr, value)
|
|
|
|
|
|
| |
| |
| |
| |
| |
@dataclass
class Sampler:
    """Pair of sampling settings, each with a default value."""

    # Both fields are plain floats; no range validation is performed here.
    temperature: float = 0.7
    top_p: float = 0.9
|
|
class Workflow:
    """Demo workflow holding prompt text and sampler settings."""

    def __init__(self):
        self.system_prompt = "You are a helpful assistant."
        self.few_shot = "Q: 1+1=?\nA: 2"
        self.sampler = Sampler()

    def execute(self):
        """Placeholder hook; does nothing and returns ``None``."""
        return None

    def run(self):
        """Build the prompt from the current fields and attach a random score.

        Returns a dict with the assembled ``"prompt"`` string and a
        ``"score"`` drawn from ``random.uniform(0, 1)``.
        """
        text = f"{self.system_prompt}\n{self.few_shot}\nUser: Hi"
        outcome = {"prompt": text}
        outcome["score"] = random.uniform(0, 1)
        return outcome
|
|
|
|
| |
class RandomSearchOptimizer(BaseCodeBlockOptimizer):
    """Optimiser that draws every trial configuration independently at random."""

    def sample_cfg(self) -> Dict[str, Any]:
        """Draw one configuration: two uniform floats and one of two prompts."""
        temperature = random.uniform(0.3, 1.3)
        top_p = random.uniform(0.5, 1.0)
        prompt = random.choice([
            "You are a helpful assistant.",
            "You are a super-concise assistant."
        ])
        return {
            "sampler_temperature": temperature,
            "sampler_top_p": top_p,
            "sys_prompt": prompt,
        }

    def update(self, cfg, score):
        """Ignore the scored trial; this optimiser keeps no state."""
        return None
|
|
|
|
class GreedyLoggerOptimizer(BaseCodeBlockOptimizer):
    """Random sampler that remembers the best trial and prints each new best."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Start from the worst possible score for the chosen direction so the
        # first finite score always counts as an improvement.
        worst = -float("inf") if self.maximize else float("inf")
        self.best = None
        self.best_score = worst

    def sample_cfg(self):
        """Draw one configuration: two uniform floats and one of two prompts."""
        temperature = random.uniform(0.3, 1.3)
        top_p = random.uniform(0.5, 1.0)
        prompt = random.choice([
            "You are a helpful assistant.",
            "You are a super-concise assistant."
        ])
        return {
            "sampler_temperature": temperature,
            "sampler_top_p": top_p,
            "sys_prompt": prompt,
        }

    def update(self, cfg, score):
        """Record *cfg* and print it when *score* strictly beats the best so far."""
        if self.maximize:
            improved = score > self.best_score
        else:
            improved = score < self.best_score
        if not improved:
            return
        self.best = cfg
        self.best_score = score
        print(f"[New Best] score={score:.3f} cfg={cfg}")
|
|
|
|
|
|
| |
def main():
    """Demo: random-search the workflow's prompt and sampler settings.

    Registers three workflow values in a registry, runs ten random-search
    trials against ``Workflow.run`` and prints the trial history and the
    best configuration found.
    """
    flow = Workflow()

    registry = PromptRegistry()
    # The registry keys must equal the keys the optimiser samples
    # ("sys_prompt", "sampler_temperature", "sampler_top_p"); cfg entries
    # without a matching registered field are skipped when a cfg is applied,
    # so the sampler fields are registered under explicit names rather than
    # the default last path segment ("temperature" / "top_p").
    registry.register_path(flow, "system_prompt", name="sys_prompt")
    registry.register_path(flow, "sampler.temperature", name="sampler_temperature")
    registry.register_path(flow, "sampler.top_p", name="sampler_top_p")

    # The workflow reads its settings from its own attributes, so the cfg
    # argument itself is not used by the wrapped callable.
    code_block = CodeBlock("run_workflow", lambda cfg: flow.run())

    def evaluator(cfg, result) -> float:
        return result["score"]

    opt = RandomSearchOptimizer(registry, metric="score", max_trials=10)
    best_cfg, history = opt.run(code_block, evaluator)

    print("\n=== Trial history ===")
    for i, (cfg, score) in enumerate(history, 1):
        print(f"{i:02d}: score={score:.3f}, cfg={cfg}")

    print("\n=== Best ===")
    print(best_cfg)


if __name__ == "__main__":
    main()
|
|