| """ |
| Data augmentation and model improvement: |
| 1. Variable subsampling: drop random subsets of variables from existing datasets |
| 2. Hyperparameter tuning for the meta-learner |
| 3. Pairwise ranking approach |
| """ |
import os
import sys
import numpy as np
import pandas as pd
import json
import logging
import warnings
from itertools import combinations

warnings.filterwarnings('ignore')
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)

sys.path.insert(0, '/app')
from causal_selection.data.generator import (
    load_bn_model, get_true_dag_adjmat, dag_to_cpdag, sample_dataset,
    ALL_NETWORKS, get_network_tier
)
from causal_selection.discovery.algorithms import run_algorithm, ALGORITHM_POOL
from causal_selection.discovery.evaluator import evaluate_algorithm_result
from causal_selection.features.extractor import extract_all_features, FEATURE_NAMES
from causal_selection.meta_learner.trainer import (
    load_meta_dataset, train_meta_learner, evaluate_lono_cv,
    get_feature_importance, save_model, ALGO_NAMES, RESULTS_DIR
)

from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score
import joblib


# ---------------------------------------------------------------------------
# 1. Data augmentation: variable subsampling
# ---------------------------------------------------------------------------

def augment_variable_subsampling(networks=None, n_augments_per_net=3,
                                 drop_frac=0.3, n_samples=1000, seed_base=100):
    """Create augmented datasets by dropping random subsets of variables.

    This creates new 'virtual networks' with different structural properties.
    Only applied to networks with at least 8 variables, so that enough
    variables remain after dropping.
    """
    if networks is None:
        networks = [n for n in ALL_NETWORKS if n not in ['cancer', 'earthquake', 'survey']]

    augmented_features = []
    augmented_shds = []
    augmented_nshds = []
    augmented_configs = []

    for net_name in networks:
        try:
            model = load_bn_model(net_name)
            true_dag, node_names = get_true_dag_adjmat(model)
            n_vars = len(node_names)

            if n_vars < 8:
                logger.info(f"Skipping {net_name} ({n_vars} vars): too few for subsampling")
                continue

            n_to_keep = max(5, int(n_vars * (1 - drop_frac)))
            tier = get_network_tier(net_name)
            timeout = {'small': 60, 'medium': 120, 'large': 180}[tier]

            for aug_idx in range(n_augments_per_net):
                rng = np.random.RandomState(seed_base + aug_idx)

                # Randomly choose which variables to keep.
                keep_idx = sorted(rng.choice(n_vars, n_to_keep, replace=False))

                # Restrict the true DAG to the kept variables and convert to a CPDAG.
                sub_dag = true_dag[np.ix_(keep_idx, keep_idx)]
                sub_cpdag = dag_to_cpdag(sub_dag)
                sub_names = [node_names[i] for i in keep_idx]

                # Sample from the full network, then project onto the kept variables.
                df_full = sample_dataset(model, n_samples, seed=seed_base + aug_idx)
                df_sub = df_full[sub_names].copy()
                df_sub.columns = [f'X{i}' for i in range(len(sub_names))]

                logger.info(f"  Augment {net_name} #{aug_idx}: {n_vars}->{n_to_keep} vars")

                # Meta-features of the subsampled dataset.
                features = extract_all_features(df_sub, n_probe_triplets=50)

                # Run every algorithm and score it against the sub-CPDAG.
                shd_row = {}
                nshd_row = {}
                n_sub = len(sub_names)
                max_shd = n_sub * (n_sub - 1) // 2

                for algo_name in ALGO_NAMES:
                    result = run_algorithm(algo_name, df_sub, timeout_sec=timeout)
                    metrics = evaluate_algorithm_result(result, sub_cpdag)
                    shd_row[algo_name] = metrics['shd']
                    nshd_row[algo_name] = metrics['normalized_shd']

                    s = metrics['status']
                    if s == 'success':
                        logger.info(f"    {algo_name:12s}: SHD={metrics['shd']:3d} t={metrics['runtime']:.1f}s")
                    else:
                        logger.info(f"    {algo_name:12s}: {s}")

                feat_row = {name: features.get(name, 0.0) for name in FEATURE_NAMES}
                augmented_features.append(feat_row)
                augmented_shds.append(shd_row)
                augmented_nshds.append(nshd_row)
                augmented_configs.append({
                    'network': f'{net_name}_sub{aug_idx}',
                    'n_samples': n_samples,
                    'seed': seed_base + aug_idx,
                    'n_variables': n_to_keep,
                    'n_true_edges': int(((sub_cpdag + sub_cpdag.T) > 0).sum() // 2),
                })

        except Exception as e:
            logger.error(f"Augmentation failed for {net_name}: {e}")
            import traceback
            traceback.print_exc()

    return augmented_features, augmented_shds, augmented_nshds, augmented_configs


# ---------------------------------------------------------------------------
# 2. Hyperparameter tuning for the meta-learner
# ---------------------------------------------------------------------------

def hyperparameter_sweep():
    """Try different model configurations and evaluate each with leave-one-network-out CV."""
    X, Y_shd, Y_nshd, configs = load_meta_dataset()

    print(f"Data: {X.shape[0]} samples, {X.shape[1]} features, {Y_nshd.shape[1]} algorithms")
    print(f"Networks: {sorted(configs.network.unique())}")

    model_configs = [
        ('RF-200', 'rf', {'n_estimators': 200}),
        ('RF-500', 'rf', {'n_estimators': 500}),
        ('RF-200-d10', 'rf', {'n_estimators': 200, 'max_depth': 10}),
        ('RF-200-d5', 'rf', {'n_estimators': 200, 'max_depth': 5}),
        ('RF-200-leaf5', 'rf', {'n_estimators': 200, 'min_samples_leaf': 5}),
        ('GBM-200', 'gbm', {'n_estimators': 200, 'max_depth': 5, 'learning_rate': 0.1}),
        ('GBM-500', 'gbm', {'n_estimators': 500, 'max_depth': 3, 'learning_rate': 0.05}),
        ('GBM-200-lr01', 'gbm', {'n_estimators': 200, 'max_depth': 4, 'learning_rate': 0.01}),
    ]

    print(f"\n{'Model':20s} {'Top3 Hit':>10s} {'NDCG@3':>8s} {'Regret':>8s} {'Overlap':>8s}")
    print("-" * 60)

    best_hit = -1.0
    best_name = None
    best_type = None
    best_kwargs = None

    for name, mtype, kwargs in model_configs:
        results = evaluate_lono_cv(X, Y_nshd, configs, model_type=mtype, k=3, **kwargs)
        o = results['overall']
        print(f"{name:20s} {o['top_k_hit_rate']:10.3f} {o['ndcg_at_k']:8.3f} "
              f"{o['mean_regret']:8.4f} {o['top_k_overlap_rate']:8.3f}")

        if o['top_k_hit_rate'] > best_hit:
            best_hit = o['top_k_hit_rate']
            best_name = name
            best_type = mtype
            best_kwargs = kwargs

    print(f"\nBest model: {best_name} (hit rate={best_hit:.3f})")

    # Retrain the best configuration on the full dataset and persist it.
    model, scaler = train_meta_learner(X, Y_nshd, model_type=best_type, **best_kwargs)
    save_model(model, scaler)

    avg_imp, _ = get_feature_importance(model)
    print("\nTop 10 Features (best model):")
    for feat, imp in sorted(avg_imp.items(), key=lambda x: -x[1])[:10]:
        print(f"  {feat:30s}: {imp:.4f}")

    return best_name, best_type, best_kwargs
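

# ---------------------------------------------------------------------------
# 3. Pairwise ranking approach -- illustrative sketch only.
# The module docstring lists pairwise ranking as a third improvement, but it is
# not wired into the pipeline here. The sketch below shows one plausible way to
# set it up, assuming the X / Y_nshd frames returned by load_meta_dataset():
# each (dataset, algorithm pair) becomes one binary-classification row whose
# label says which algorithm achieved the lower normalized SHD. The helper name
# build_pairwise_ranking_dataset is hypothetical, not an existing API; a
# standard classifier could be fit on these rows and algorithms ranked by how
# many pairwise comparisons they are predicted to win.
# ---------------------------------------------------------------------------

def build_pairwise_ranking_dataset(X, Y_nshd):
    """Sketch: turn per-algorithm SHD targets into pairwise comparison rows.

    Each output row is [dataset features, one-hot(algo a), one-hot(algo b)]
    with label 1 if algorithm a beats algorithm b (lower normalized SHD).
    """
    n_algos = len(Y_nshd.columns)
    rows, labels = [], []
    for i in range(len(X)):
        feats = X.iloc[i].to_numpy(dtype=float)
        for a, b in combinations(range(n_algos), 2):
            pair = np.zeros(2 * n_algos)
            pair[a] = 1.0             # which algorithm plays the role of "a"
            pair[n_algos + b] = 1.0   # which algorithm plays the role of "b"
            rows.append(np.concatenate([feats, pair]))
            labels.append(int(Y_nshd.iloc[i, a] < Y_nshd.iloc[i, b]))
    return np.asarray(rows), np.asarray(labels)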


if __name__ == '__main__':
    mode = sys.argv[1] if len(sys.argv) > 1 else 'sweep'

    if mode == 'augment':
        # Generate augmented rows from variable-subsampled networks.
        feats, shds, nshds, cfgs = augment_variable_subsampling(
            networks=['asia', 'sachs', 'alarm', 'child', 'insurance', 'water'],
            n_augments_per_net=2, drop_frac=0.3, n_samples=1000
        )

        # Merge the augmented rows with the existing meta-dataset.
        X_orig, Y_shd_orig, Y_nshd_orig, configs_orig = load_meta_dataset()

        X_aug = pd.DataFrame(feats, columns=FEATURE_NAMES)
        Y_shd_aug = pd.DataFrame(shds, columns=ALGO_NAMES)
        Y_nshd_aug = pd.DataFrame(nshds, columns=ALGO_NAMES)
        configs_aug = pd.DataFrame(cfgs)

        X_all = pd.concat([X_orig, X_aug], ignore_index=True)
        Y_shd_all = pd.concat([Y_shd_orig, Y_shd_aug], ignore_index=True)
        Y_nshd_all = pd.concat([Y_nshd_orig, Y_nshd_aug], ignore_index=True)
        configs_all = pd.concat([configs_orig, configs_aug], ignore_index=True)

        # Overwrite the saved meta-dataset with the extended version.
        X_all.to_csv(os.path.join(RESULTS_DIR, 'meta_features.csv'), index=False)
        Y_shd_all.to_csv(os.path.join(RESULTS_DIR, 'shd_matrix.csv'), index=False)
        Y_nshd_all.to_csv(os.path.join(RESULTS_DIR, 'normalized_shd_matrix.csv'), index=False)
        configs_all.to_csv(os.path.join(RESULTS_DIR, 'configs.csv'), index=False)

        print(f"\nAugmented dataset: {len(configs_all)} total configs ({len(configs_orig)} original + {len(configs_aug)} augmented)")

    elif mode == 'sweep':
        hyperparameter_sweep()

    elif mode == 'all':
        # Run the data augmentation, then the hyperparameter sweep.
        print("=" * 80)
        print("STEP 1: DATA AUGMENTATION")
        print("=" * 80)

        feats, shds, nshds, cfgs = augment_variable_subsampling(
            networks=['asia', 'sachs', 'alarm', 'child', 'insurance', 'water'],
            n_augments_per_net=2, drop_frac=0.3, n_samples=1000
        )

        X_orig, Y_shd_orig, Y_nshd_orig, configs_orig = load_meta_dataset()
        X_aug = pd.DataFrame(feats, columns=FEATURE_NAMES)
        Y_shd_aug = pd.DataFrame(shds, columns=ALGO_NAMES)
        Y_nshd_aug = pd.DataFrame(nshds, columns=ALGO_NAMES)
        configs_aug = pd.DataFrame(cfgs)

        X_all = pd.concat([X_orig, X_aug], ignore_index=True)
        Y_shd_all = pd.concat([Y_shd_orig, Y_shd_aug], ignore_index=True)
        Y_nshd_all = pd.concat([Y_nshd_orig, Y_nshd_aug], ignore_index=True)
        configs_all = pd.concat([configs_orig, configs_aug], ignore_index=True)

        X_all.to_csv(os.path.join(RESULTS_DIR, 'meta_features.csv'), index=False)
        Y_shd_all.to_csv(os.path.join(RESULTS_DIR, 'shd_matrix.csv'), index=False)
        Y_nshd_all.to_csv(os.path.join(RESULTS_DIR, 'normalized_shd_matrix.csv'), index=False)
        configs_all.to_csv(os.path.join(RESULTS_DIR, 'configs.csv'), index=False)

        print(f"\nAugmented: {len(configs_all)} configs")

        print("\n" + "=" * 80)
        print("STEP 2: HYPERPARAMETER SWEEP")
        print("=" * 80)
        hyperparameter_sweep()