"""
Data augmentation and meta-learner model selection.

1. Variable subsampling: drop random subsets of variables from existing
   benchmark networks to create new "virtual network" datasets.
2. Hyperparameter sweep for the meta-learner, scored by the leave-one-network-out
   top-k hit rate reported by evaluate_lono_cv.

A pairwise-ranking approach was planned for this module but is not
implemented here.

Usage: python <this script> [augment | sweep | all]   (default: sweep)
"""
import os
import sys
import numpy as np
import pandas as pd
import json
import logging
import warnings
from itertools import combinations

warnings.filterwarnings('ignore')
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)

sys.path.insert(0, '/app')

from causal_selection.data.generator import (
    load_bn_model, get_true_dag_adjmat, dag_to_cpdag, sample_dataset,
    ALL_NETWORKS, get_network_tier,
)
from causal_selection.discovery.algorithms import run_algorithm, ALGORITHM_POOL
from causal_selection.discovery.evaluator import evaluate_algorithm_result
from causal_selection.features.extractor import extract_all_features, FEATURE_NAMES
from causal_selection.meta_learner.trainer import (
    load_meta_dataset, train_meta_learner, evaluate_lono_cv,
    get_feature_importance, save_model, ALGO_NAMES, RESULTS_DIR,
)

from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score
import joblib

# Networks excluded from subsampling by default (too small to drop variables from).
_TINY_NETWORKS = ('cancer', 'earthquake', 'survey')
# Networks with fewer variables than this are skipped by augment_variable_subsampling.
_MIN_VARS_FOR_SUBSAMPLING = 8
# Lower bound on the number of variables kept in each subsampled dataset.
_MIN_KEPT_VARS = 5
# Per-algorithm timeout in seconds, keyed by network size tier.
_TIER_TIMEOUTS = {'small': 60, 'medium': 120, 'large': 180}
# Networks augmented by the command-line 'augment' and 'all' modes.
_CLI_AUGMENT_NETWORKS = ('asia', 'sachs', 'alarm', 'child', 'insurance', 'water')


def _subsample_and_score(model, net_name, true_dag, node_names, n_to_keep,
                         n_samples, seed, aug_idx, timeout):
    """Build one variable-subsampled dataset and score every algorithm on it.

    Returns (feature_row, shd_row, normalized_shd_row, config_row) for a single
    augmented "virtual network".
    """
    n_vars = len(node_names)
    rng = np.random.RandomState(seed)

    # Pick the variables to keep; sorting preserves the original variable order.
    keep_idx = sorted(rng.choice(n_vars, n_to_keep, replace=False))

    # Restrict the ground-truth DAG to the kept variables and recompute its CPDAG.
    sub_dag = true_dag[np.ix_(keep_idx, keep_idx)]
    sub_cpdag = dag_to_cpdag(sub_dag)
    sub_names = [node_names[i] for i in keep_idx]

    # Sample from the full network, then keep only the selected columns,
    # renamed to anonymous X0..Xk.
    df_full = sample_dataset(model, n_samples, seed=seed)
    df_sub = df_full[sub_names].copy()
    df_sub.columns = [f'X{i}' for i in range(len(sub_names))]
    logger.info(f" Augment {net_name} #{aug_idx}: {n_vars}->{n_to_keep} vars")

    features = extract_all_features(df_sub, n_probe_triplets=50)

    shd_row = {}
    nshd_row = {}
    for algo_name in ALGO_NAMES:
        result = run_algorithm(algo_name, df_sub, timeout_sec=timeout)
        metrics = evaluate_algorithm_result(result, sub_cpdag)
        shd_row[algo_name] = metrics['shd']
        nshd_row[algo_name] = metrics['normalized_shd']
        status = metrics['status']
        if status == 'success':
            logger.info(f" {algo_name:12s}: SHD={metrics['shd']:3d} t={metrics['runtime']:.1f}s")
        else:
            logger.info(f" {algo_name:12s}: {status}")

    feat_row = {name: features.get(name, 0.0) for name in FEATURE_NAMES}
    config_row = {
        'network': f'{net_name}_sub{aug_idx}',
        'n_samples': n_samples,
        'seed': seed,
        'n_variables': n_to_keep,
        # Symmetrize so an adjacency i-j counts once whatever its orientation.
        'n_true_edges': int(((sub_cpdag + sub_cpdag.T) > 0).sum() // 2),
    }
    return feat_row, shd_row, nshd_row, config_row


def augment_variable_subsampling(networks=None, n_augments_per_net=3, drop_frac=0.3,
                                 n_samples=1000, seed_base=100):
    """Create augmented datasets by dropping random subsets of variables.

    Each augment keeps a random subset of a network's variables. The
    ground-truth CPDAG is recomputed on that subset, and every algorithm in
    ALGO_NAMES is scored against it. The result is a new "virtual network"
    with different structural properties.

    Args:
        networks: network names to augment. Defaults to ALL_NETWORKS minus
            the tiny networks 'cancer', 'earthquake' and 'survey'.
        n_augments_per_net: number of subsampled datasets per network.
        drop_frac: fraction of variables to drop. At least 5 variables, and at
            most all of them, are kept.
        n_samples: rows sampled for each augmented dataset.
        seed_base: augment k of every network uses seed ``seed_base + k`` for
            both variable selection and data sampling.

    Returns:
        (features, shds, normalized_shds, configs): four aligned lists of
        dicts, with one entry per augment that completed successfully.

    Networks with fewer than 8 variables are skipped. Failures are logged, not
    raised: a network whose setup fails is skipped entirely, while a failing
    augment drops only that augment.
    """
    if networks is None:
        networks = [n for n in ALL_NETWORKS if n not in _TINY_NETWORKS]

    augmented_features = []
    augmented_shds = []
    augmented_nshds = []
    augmented_configs = []

    for net_name in networks:
        try:
            model = load_bn_model(net_name)
            true_dag, node_names = get_true_dag_adjmat(model)
            n_vars = len(node_names)
            if n_vars < _MIN_VARS_FOR_SUBSAMPLING:
                logger.info(f"Skipping {net_name} ({n_vars} vars): too few for subsampling")
                continue
            # Clamp so an unusual drop_frac can never request more variables than exist.
            n_to_keep = min(n_vars, max(_MIN_KEPT_VARS, int(n_vars * (1 - drop_frac))))
            timeout = _TIER_TIMEOUTS[get_network_tier(net_name)]
        except Exception:
            logger.exception("Augmentation failed for %s", net_name)
            continue

        for aug_idx in range(n_augments_per_net):
            try:
                feat_row, shd_row, nshd_row, config_row = _subsample_and_score(
                    model, net_name, true_dag, node_names, n_to_keep,
                    n_samples, seed_base + aug_idx, aug_idx, timeout,
                )
            except Exception:
                # A single failing augment must not discard the remaining ones.
                logger.exception("Augmentation failed for %s #%d", net_name, aug_idx)
                continue
            # Append all four together so the returned lists stay row-aligned.
            augmented_features.append(feat_row)
            augmented_shds.append(shd_row)
            augmented_nshds.append(nshd_row)
            augmented_configs.append(config_row)

    return augmented_features, augmented_shds, augmented_nshds, augmented_configs


def _hit_rate_sort_key(value):
    """Map a hit rate to a sortable float, ranking NaN below every real value."""
    value = float(value)
    return -np.inf if np.isnan(value) else value


def hyperparameter_sweep():
    """Evaluate several meta-learner configs with LONO CV, then save the best.

    Each config is scored with evaluate_lono_cv (k=3) on the normalized-SHD
    targets. The config with the highest top-k hit rate is retrained on all
    data and saved via save_model; ties go to the earliest config in the list.

    Returns:
        (best_name, best_type, best_kwargs) of the selected config.
    """
    X, _, Y_nshd, configs = load_meta_dataset()
    print(f"Data: {X.shape[0]} samples, {X.shape[1]} features, {Y_nshd.shape[1]} algorithms")
    print(f"Networks: {sorted(configs.network.unique())}")
    # NOTE(review): augmented rows carry names like 'alarm_sub0'. If
    # evaluate_lono_cv builds folds from the `network` column, those rows are
    # held out separately from their parent network. Information could then
    # leak into the parent's fold -- verify.

    model_configs = [
        ('RF-200', 'rf', {'n_estimators': 200}),
        ('RF-500', 'rf', {'n_estimators': 500}),
        ('RF-200-d10', 'rf', {'n_estimators': 200, 'max_depth': 10}),
        ('RF-200-d5', 'rf', {'n_estimators': 200, 'max_depth': 5}),
        ('RF-200-leaf5', 'rf', {'n_estimators': 200, 'min_samples_leaf': 5}),
        ('GBM-200', 'gbm', {'n_estimators': 200, 'max_depth': 5, 'learning_rate': 0.1}),
        ('GBM-500', 'gbm', {'n_estimators': 500, 'max_depth': 3, 'learning_rate': 0.05}),
        ('GBM-200-lr01', 'gbm', {'n_estimators': 200, 'max_depth': 4, 'learning_rate': 0.01}),
    ]

    print(f"\n{'Model':20s} {'Top3 Hit':>10s} {'NDCG@3':>8s} {'Regret':>8s} {'Overlap':>8s}")
    print("-" * 60)

    scored = []
    for name, mtype, kwargs in model_configs:
        results = evaluate_lono_cv(X, Y_nshd, configs, model_type=mtype, k=3, **kwargs)
        o = results['overall']
        print(f"{name:20s} {o['top_k_hit_rate']:10.3f} {o['ndcg_at_k']:8.3f} "
              f"{o['mean_regret']:8.4f} {o['top_k_overlap_rate']:8.3f}")
        scored.append((name, mtype, kwargs, o['top_k_hit_rate']))

    # max() returns the first maximal entry, so ties keep the earliest config.
    # Unlike a scan seeded with best_hit = 0, it always selects something, even
    # when every hit rate is 0 or NaN.
    best_name, best_type, best_kwargs, best_hit = max(
        scored, key=lambda row: _hit_rate_sort_key(row[3]))

    print(f"\nBest model: {best_name} (hit rate={best_hit:.3f})")

    # Retrain the winning config on all data and persist it.
    model, scaler = train_meta_learner(X, Y_nshd, model_type=best_type, **best_kwargs)
    save_model(model, scaler)

    avg_imp, _ = get_feature_importance(model)
    print("\nTop 10 Features (best model):")
    for feat, imp in sorted(avg_imp.items(), key=lambda x: -x[1])[:10]:
        print(f" {feat:30s}: {imp:.4f}")

    return best_name, best_type, best_kwargs


def _merge_and_save_augmented(feats, shds, nshds, cfgs):
    """Merge augmented rows into the meta-dataset and rewrite its CSVs in RESULTS_DIR.

    Rows already on disk for a virtual network that was just regenerated are
    replaced rather than duplicated, so re-running augmentation is safe.

    Returns:
        (n_total, n_original, n_augmented) config counts after the merge.
    """
    X_orig, Y_shd_orig, Y_nshd_orig, configs_orig = load_meta_dataset()

    X_aug = pd.DataFrame(feats, columns=FEATURE_NAMES)
    Y_shd_aug = pd.DataFrame(shds, columns=ALGO_NAMES)
    Y_nshd_aug = pd.DataFrame(nshds, columns=ALGO_NAMES)
    configs_aug = pd.DataFrame(cfgs)

    if 'network' in configs_orig.columns and 'network' in configs_aug.columns:
        stale = configs_orig['network'].isin(configs_aug['network']).to_numpy()
        if stale.any():
            logger.info("Replacing %d previously augmented rows", int(stale.sum()))
            keep = ~stale
            # Positional boolean masks keep the four tables row-aligned.
            X_orig = X_orig[keep]
            Y_shd_orig = Y_shd_orig[keep]
            Y_nshd_orig = Y_nshd_orig[keep]
            configs_orig = configs_orig[keep]

    configs_all = pd.concat([configs_orig, configs_aug], ignore_index=True)
    outputs = {
        'meta_features.csv': pd.concat([X_orig, X_aug], ignore_index=True),
        'shd_matrix.csv': pd.concat([Y_shd_orig, Y_shd_aug], ignore_index=True),
        'normalized_shd_matrix.csv': pd.concat([Y_nshd_orig, Y_nshd_aug], ignore_index=True),
        'configs.csv': configs_all,
    }
    for filename, frame in outputs.items():
        frame.to_csv(os.path.join(RESULTS_DIR, filename), index=False)

    return len(configs_all), len(configs_orig), len(configs_aug)


def _run_augmentation():
    """Run the CLI's standard augmentation and merge the results to disk."""
    feats, shds, nshds, cfgs = augment_variable_subsampling(
        networks=list(_CLI_AUGMENT_NETWORKS),
        n_augments_per_net=2, drop_frac=0.3, n_samples=1000,
    )
    return _merge_and_save_augmented(feats, shds, nshds, cfgs)


if __name__ == '__main__':
    mode = sys.argv[1] if len(sys.argv) > 1 else 'sweep'

    if mode == 'augment':
        n_total, n_orig, n_aug = _run_augmentation()
        print(f"\nAugmented dataset: {n_total} total configs ({n_orig} original + {n_aug} augmented)")

    elif mode == 'sweep':
        hyperparameter_sweep()

    elif mode == 'all':
        # First augment, then sweep.
        print("=" * 80)
        print("STEP 1: DATA AUGMENTATION")
        print("=" * 80)
        n_total, _, _ = _run_augmentation()
        print(f"\nAugmented: {n_total} configs")

        print("\n" + "=" * 80)
        print("STEP 2: HYPERPARAMETER SWEEP")
        print("=" * 80)
        hyperparameter_sweep()

    else:
        sys.exit(f"Unknown mode {mode!r}; expected 'augment', 'sweep' or 'all'")