| """ |
| threshold_optimizer.py — Post-training threshold calibration tool. |
| |
| Run this standalone to re-optimize the probability threshold on new data |
| WITHOUT retraining the model. Useful for: |
| - Adapting to regime changes without full retraining |
| - Testing different optimization objectives |
| - Out-of-sample threshold validation |
| |
| The threshold search maximizes the chosen objective (expectancy, Sharpe, or win rate) over a held-out dataset. |
| |
| Usage: |
| python threshold_optimizer.py --symbols BTC-USDT ETH-USDT --bars 200 |
| python threshold_optimizer.py --objective sharpe |
| """ |
|
|
| import argparse |
| import json |
| import logging |
| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| import pandas as pd |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
| sys.path.insert(0, str(Path(__file__).parent)) |
|
|
| from ml_config import ( |
| THRESHOLD_PATH, |
| THRESHOLD_MIN, |
| THRESHOLD_MAX, |
| THRESHOLD_STEPS, |
| THRESHOLD_OBJECTIVE, |
| TARGET_RR, |
| ROUND_TRIP_COST, |
| FEATURE_COLUMNS, |
| ML_DIR, |
| ) |
| from ml_filter import TradeFilter |
| from feature_builder import build_feature_dict, validate_features |
| from train import build_dataset |
|
|
| logger = logging.getLogger(__name__) |
| logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") |
|
|
|
|
def compute_threshold_curve(
    probs: np.ndarray,
    y_true: np.ndarray,
    rr: float = TARGET_RR,
    cost: float = ROUND_TRIP_COST,
) -> pd.DataFrame:
    """
    Sweep the threshold grid and compute trade metrics at each threshold.

    The grid is ``np.linspace(THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEPS)``.
    At threshold ``t`` a sample counts as a taken trade when its probability
    is ``>= t``. Each taken trade earns ``rr`` when its label equals 1 and
    loses 1.0 otherwise, and ``cost`` is subtracted from every trade.

    Args:
        probs: Predicted probabilities, one per sample (any 1-D array-like).
        y_true: Outcome labels aligned with ``probs``; a label of 1 is
            treated as a win (presumably labels are binary 0/1 — the win
            rate is computed as their mean).
        rr: Payoff of a winning trade, in units of the loss on a losing one.
        cost: Per-trade cost, in the same units.

    Returns:
        DataFrame with one row per grid threshold and the columns
        ``threshold``, ``n_trades``, ``win_rate``, ``expectancy``, ``sharpe``,
        ``precision`` (same value as ``win_rate``) and ``coverage``.
        Thresholds that select fewer than 5 trades get NaN metrics and a
        coverage of 0.0.

    Raises:
        ValueError: If ``probs`` and ``y_true`` do not have the same length.
    """
    # Accept lists / Series as well as arrays; boolean masking below needs ndarrays.
    probs = np.asarray(probs, dtype=float).ravel()
    y_true = np.asarray(y_true).ravel()
    if probs.shape != y_true.shape:
        raise ValueError(
            f"probs and y_true must have the same length, "
            f"got {probs.shape[0]} and {y_true.shape[0]}"
        )

    columns = [
        "threshold", "n_trades", "win_rate", "expectancy",
        "sharpe", "precision", "coverage",
    ]
    min_samples = 5  # below this, the per-threshold statistics are not reported
    n_total = len(y_true)
    records = []

    for t in np.linspace(THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEPS):
        # Every row stores the same rounded threshold, so sparse and regular
        # rows are consistent when the curve is exported or looked up.
        threshold = round(float(t), 4)
        mask = probs >= t
        n = int(mask.sum())

        if n < min_samples:
            records.append({
                "threshold": threshold,
                "n_trades": n,
                "win_rate": np.nan,
                "expectancy": np.nan,
                "sharpe": np.nan,
                "precision": np.nan,
                "coverage": 0.0,
            })
            continue

        y_f = y_true[mask]
        wr = float(y_f.mean())
        expectancy = wr * rr - (1 - wr) * 1.0 - cost
        pnl = np.where(y_f == 1, rr, -1.0) - cost
        pnl_std = pnl.std()
        # A (near-)constant P&L has no meaningful Sharpe; report 0.0 instead
        # of dividing by ~0. The sqrt(252) factor annualizes the ratio.
        sharpe = (pnl.mean() / pnl_std * np.sqrt(252)) if pnl_std > 1e-9 else 0.0

        records.append({
            "threshold": threshold,
            "n_trades": n,
            "win_rate": round(wr, 4),
            "expectancy": round(expectancy, 4),
            "sharpe": round(sharpe, 4),
            "precision": round(wr, 4),
            "coverage": round(n / n_total, 4),
        })

    # Explicit columns keep the schema stable even if the grid is empty.
    return pd.DataFrame(records, columns=columns)
|
|
|
|
def find_optimal_threshold(
    curve: pd.DataFrame,
    objective: str = THRESHOLD_OBJECTIVE,
    min_trades: int = 20,
    default_threshold: float = 0.55,
) -> float:
    """
    Pick the threshold that maximizes ``objective`` on a threshold curve.

    Only rows with at least ``min_trades`` trades and a non-NaN objective
    value are considered. On ties the first maximal row wins.

    Args:
        curve: DataFrame with at least the columns ``threshold``,
            ``n_trades`` and the column named by ``objective``.
        objective: Name of the metric column to maximize.
        min_trades: Minimum number of trades a threshold must select to be
            eligible.
        default_threshold: Value returned when no row is eligible.

    Returns:
        The selected threshold, or ``default_threshold`` (with a logged
        warning) when the curve is empty or no row qualifies.

    Raises:
        KeyError: If a non-empty ``curve`` lacks the ``n_trades`` or
            ``objective`` column.
    """
    if curve.empty:
        valid = curve
    else:
        valid = curve[curve["n_trades"] >= min_trades].dropna(subset=[objective])

    if valid.empty:
        logger.warning(
            "No valid threshold found — using default %s", default_threshold
        )
        return default_threshold

    best_row = valid.loc[valid[objective].idxmax()]
    return float(best_row["threshold"])
|
|
|
|
def plot_threshold_curves(curve: pd.DataFrame, optimal: float, save_path: Path) -> None:
    """
    Render a 2x2 grid of metric-vs-threshold plots and save it as an image.

    One panel each is drawn for expectancy, Sharpe, win rate and trade count;
    rows where the panel's metric is NaN are skipped. A dashed vertical line
    marks ``optimal`` in every panel.

    Args:
        curve: Threshold curve with a ``threshold`` column plus the columns
            ``expectancy``, ``sharpe``, ``win_rate`` and ``n_trades``.
        optimal: Threshold to highlight.
        save_path: Destination image file; its parent directory is created
            if it does not exist.
    """
    metrics = ["expectancy", "sharpe", "win_rate", "n_trades"]
    titles = ["Expectancy per Trade", "Annualized Sharpe", "Win Rate", "# Trades"]

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    try:
        fig.suptitle("Threshold Optimization", fontsize=14, fontweight="bold")

        for ax, metric, title in zip(axes.flatten(), metrics, titles):
            valid = curve.dropna(subset=[metric])
            ax.plot(valid["threshold"], valid[metric], lw=2, color="#1a6bff")
            ax.axvline(
                optimal, color="orange", linestyle="--", lw=1.5,
                label=f"Optimal={optimal:.3f}",
            )
            ax.axhline(0, color="gray", linestyle=":", lw=0.8)
            ax.set_title(title, fontsize=11)
            ax.set_xlabel("Threshold")
            ax.legend(fontsize=9)
            ax.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(save_path, dpi=120, bbox_inches="tight")
    finally:
        # Close this specific figure even when plotting or saving fails, so a
        # caller that swallows the error does not accumulate open figures.
        plt.close(fig)

    logger.info("Threshold curve plot saved → %s", save_path)
|
|
|
|
def main(args):
    """
    Re-optimize the probability threshold for the already-trained model.

    Steps: load the trained filter, build a dataset for the requested
    symbols, score every row, sweep the threshold grid, pick the threshold
    that maximizes ``args.objective``, then write the result to
    ``THRESHOLD_PATH`` and the full curve (CSV and PNG) into ``ML_DIR``.

    Args:
        args: Parsed CLI namespace providing ``symbols`` (list of symbols or
            None for the built-in default list), ``bars`` and ``objective``.

    Exits with status 1 when no trained model is available, when the dataset
    is empty, or when the threshold grid is empty.
    """
    trade_filter = TradeFilter.load_or_none()
    if trade_filter is None:
        logger.error("No trained model found. Run train.py first.")
        sys.exit(1)

    symbols = args.symbols or ["BTC-USDT", "ETH-USDT", "SOL-USDT", "BNB-USDT"]
    dataset = build_dataset(symbols, bars=args.bars)
    if dataset is None or len(dataset) == 0:
        logger.error("Dataset is empty; cannot optimize the threshold.")
        sys.exit(1)

    y = dataset["label"].values.astype(np.int32)

    # One plain-float dict per row. itertuples avoids building a Series per
    # row (as iterrows does) while producing the same keys and values.
    feature_dicts = [
        dict(zip(FEATURE_COLUMNS, map(float, values)))
        for values in dataset[FEATURE_COLUMNS].itertuples(index=False, name=None)
    ]
    probs = np.asarray(trade_filter.predict_batch(feature_dicts), dtype=np.float64)

    logger.info(f"Generated {len(probs)} predictions | mean_prob={probs.mean():.4f}")

    curve = compute_threshold_curve(probs, y)
    if curve.empty:
        logger.error("Threshold grid is empty; check THRESHOLD_STEPS.")
        sys.exit(1)

    optimal = find_optimal_threshold(curve, objective=args.objective)

    # Look the metrics up at the nearest grid point rather than by exact
    # equality: the fallback threshold returned when no row qualifies is not
    # guaranteed to lie on the grid, and an empty match would crash here.
    nearest_idx = (curve["threshold"] - optimal).abs().idxmin()
    best_row = curve.loc[nearest_idx]
    if abs(float(best_row["threshold"]) - optimal) > 1e-4:
        logger.warning(
            "Threshold %.4f is not on the search grid; reporting metrics "
            "from the nearest grid point %.4f",
            optimal, float(best_row["threshold"]),
        )

    logger.info("\n=== THRESHOLD OPTIMIZATION RESULT ===")
    logger.info(f" Objective: {args.objective}")
    logger.info(f" Optimal threshold: {optimal:.4f}")
    logger.info(f" Win rate: {best_row['win_rate']:.4f}")
    logger.info(f" Expectancy: {best_row['expectancy']:.4f}")
    logger.info(f" Sharpe: {best_row['sharpe']:.4f}")
    logger.info(f" # Trades: {int(best_row['n_trades'])}")
    logger.info(f" Coverage: {best_row['coverage']:.2%}")

    # Persist the selected threshold together with its metrics.
    ML_DIR.mkdir(parents=True, exist_ok=True)
    Path(THRESHOLD_PATH).parent.mkdir(parents=True, exist_ok=True)
    # NOTE(review): if the chosen row has NaN metrics (too few trades), the
    # values below are written as NaN, which strict JSON parsers reject —
    # confirm what the consumer of this file expects before changing it.
    thresh_data = {
        "threshold": optimal,
        "objective": args.objective,
        "win_rate_at_threshold": float(best_row["win_rate"]),
        "expectancy_at_threshold": float(best_row["expectancy"]),
        "sharpe_at_threshold": float(best_row["sharpe"]),
        "n_trades_at_threshold": int(best_row["n_trades"]),
    }
    with open(THRESHOLD_PATH, "w", encoding="utf-8") as f:
        json.dump(thresh_data, f, indent=2)
    logger.info(f"Threshold updated → {THRESHOLD_PATH}")

    # Save the full curve for offline analysis.
    curve_path = ML_DIR / "threshold_curve.csv"
    curve.to_csv(curve_path, index=False)

    # Plotting is best-effort: a failure must not undo the saved threshold.
    plot_path = ML_DIR / "threshold_curve.png"
    try:
        plot_threshold_curves(curve, optimal, plot_path)
    except Exception as e:
        logger.warning(f"Plot failed: {e}")
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: declare the options table-style, parse them,
    # and hand the resulting namespace straight to main().
    cli = argparse.ArgumentParser(description="Optimize probability threshold")
    cli_options = (
        ("--symbols", {"nargs": "+", "default": None}),
        ("--bars", {"type": int, "default": 200}),
        (
            "--objective",
            {
                "choices": ["expectancy", "sharpe", "win_rate"],
                "default": THRESHOLD_OBJECTIVE,
            },
        ),
    )
    for flag, settings in cli_options:
        cli.add_argument(flag, **settings)
    main(cli.parse_args())
|
|