# Commit 8744a77 (verified) by Oguzz07:
# "Update to GBM-500 best model (70.5% top-3 hit rate)"
"""
Data augmentation and model improvement:
1. Variable subsampling: drop random subsets of variables from existing datasets
2. Hyperparameter tuning for the meta-learner
3. Pairwise ranking approach
"""
import os
import sys
import numpy as np
import pandas as pd
import json
import logging
import warnings
from itertools import combinations
warnings.filterwarnings('ignore')
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)
sys.path.insert(0, '/app')
from causal_selection.data.generator import (
load_bn_model, get_true_dag_adjmat, dag_to_cpdag, sample_dataset,
ALL_NETWORKS, get_network_tier
)
from causal_selection.discovery.algorithms import run_algorithm, ALGORITHM_POOL
from causal_selection.discovery.evaluator import evaluate_algorithm_result
from causal_selection.features.extractor import extract_all_features, FEATURE_NAMES
from causal_selection.meta_learner.trainer import (
load_meta_dataset, train_meta_learner, evaluate_lono_cv,
get_feature_importance, save_model, ALGO_NAMES, RESULTS_DIR
)
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, ExtraTreesRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score
import joblib
def _project_dag(dag, keep_idx):
    """Marginalize a DAG adjacency matrix onto the nodes listed in ``keep_idx``.

    Kept nodes ``a`` and ``b`` are connected (using the same orientation
    convention as ``dag``) when ``dag`` contains a path ``a -> ... -> b`` whose
    interior nodes were all dropped. Direct edges between kept nodes are
    preserved, so the result is a superset of the induced subgraph
    ``dag[np.ix_(keep_idx, keep_idx)]`` and is still acyclic.

    NOTE(review): dependence created by a *dropped common cause* of two kept
    nodes (latent confounding) cannot be expressed in a DAG/CPDAG, so it is
    still absent from the returned matrix.

    Args:
        dag: square adjacency matrix of the full DAG (non-zero = edge).
        keep_idx: indices of the nodes to keep, in output order.

    Returns:
        ``len(keep_idx) x len(keep_idx)`` 0/1 matrix with ``dag``'s dtype.
    """
    adj = np.asarray(dag)
    reach = adj != 0
    keep_idx = list(keep_idx)
    dropped = np.ones(adj.shape[0], dtype=bool)
    dropped[keep_idx] = False
    # Warshall transitive closure restricted to dropped pivots: after pivot k,
    # reach[i, j] is True iff a path i -> ... -> j exists whose interior nodes
    # are all among the dropped pivots processed so far.
    for k in np.flatnonzero(dropped):
        reach |= np.outer(reach[:, k], reach[k, :])
    return reach[np.ix_(keep_idx, keep_idx)].astype(adj.dtype)


def augment_variable_subsampling(networks=None, n_augments_per_net=3,
                                 drop_frac=0.3, n_samples=1000, seed_base=100,
                                 project_latent_paths=True):
    """Create augmented meta-dataset rows by dropping random subsets of variables.

    Each augment does the following:
    - samples ``n_samples`` rows from the full Bayesian network;
    - keeps a random subset of ``max(5, int(n_vars * (1 - drop_frac)))``
      columns (renamed ``X0..Xk``);
    - extracts meta-features;
    - runs every algorithm in ``ALGO_NAMES``;
    - scores each against the CPDAG of the kept variables.

    This creates new 'virtual networks' with different structural properties.
    Networks with fewer than 8 variables are skipped (too few to subsample).

    Args:
        networks: network names to augment; defaults to ``ALL_NETWORKS``
            minus 'cancer', 'earthquake' and 'survey'.
        n_augments_per_net: number of random variable subsets per network.
        drop_frac: fraction of variables to drop, must lie in ``[0, 1)``.
        n_samples: rows sampled for each augment.
        seed_base: augment ``i`` uses seed ``seed_base + i`` both for the
            variable subset and for the data sample.
        project_latent_paths: if True (default), the ground-truth DAG of the
            kept variables gets an edge ``a -> b`` whenever the full DAG has
            a directed path from ``a`` to ``b`` through dropped variables
            only. The sampled data carries those dependencies, so the plain
            induced subgraph would penalise algorithms for recovering them.
            Set False to reproduce the earlier induced-subgraph ground truth.

    Returns:
        Tuple ``(features, shds, nshds, configs)`` of equal-length lists of
        dicts, one entry per successfully completed augment.

    Raises:
        ValueError: if ``drop_frac`` is outside ``[0, 1)``.
    """
    if not 0 <= drop_frac < 1:
        raise ValueError(f"drop_frac must be in [0, 1), got {drop_frac!r}")
    if networks is None:
        networks = [n for n in ALL_NETWORKS if n not in ('cancer', 'earthquake', 'survey')]  # skip tiny
    augmented_features = []
    augmented_shds = []
    augmented_nshds = []
    augmented_configs = []
    for net_name in networks:
        try:
            model = load_bn_model(net_name)
            true_dag, node_names = get_true_dag_adjmat(model)
            n_vars = len(node_names)
            if n_vars < 8:
                logger.info(f"Skipping {net_name} ({n_vars} vars): too few for subsampling")
                continue
            tier = get_network_tier(net_name)
            timeout = {'small': 60, 'medium': 120, 'large': 180}[tier]
        except Exception:
            logger.exception(f"Augmentation failed for {net_name}")
            continue
        n_to_keep = max(5, int(n_vars * (1 - drop_frac)))
        for aug_idx in range(n_augments_per_net):
            seed = seed_base + aug_idx
            # A per-augment try, so one failing subset (e.g. an algorithm
            # crash) does not discard the remaining augments of this network.
            try:
                rng = np.random.RandomState(seed)
                # Select random subset of variables
                keep_idx = sorted(rng.choice(n_vars, n_to_keep, replace=False))
                # Ground-truth graph over the kept variables, then its CPDAG
                if project_latent_paths:
                    sub_dag = _project_dag(true_dag, keep_idx)
                else:
                    sub_dag = true_dag[np.ix_(keep_idx, keep_idx)]
                sub_cpdag = dag_to_cpdag(sub_dag)
                sub_names = [node_names[i] for i in keep_idx]
                # Sample from the full network, then select the kept columns
                df_full = sample_dataset(model, n_samples, seed=seed)
                df_sub = df_full[sub_names].copy()
                df_sub.columns = [f'X{i}' for i in range(len(sub_names))]
                logger.info(f" Augment {net_name} #{aug_idx}: {n_vars}->{n_to_keep} vars")
                features = extract_all_features(df_sub, n_probe_triplets=50)
                shd_row = {}
                nshd_row = {}
                for algo_name in ALGO_NAMES:
                    result = run_algorithm(algo_name, df_sub, timeout_sec=timeout)
                    metrics = evaluate_algorithm_result(result, sub_cpdag)
                    shd_row[algo_name] = metrics['shd']
                    nshd_row[algo_name] = metrics['normalized_shd']
                    status = metrics['status']
                    if status == 'success':
                        logger.info(f" {algo_name:12s}: SHD={metrics['shd']:3d} t={metrics['runtime']:.1f}s")
                    else:
                        logger.info(f" {algo_name:12s}: {status}")
                n_true_edges = int(((sub_cpdag + sub_cpdag.T) > 0).sum() // 2)
            except Exception:
                logger.exception(f"Augmentation failed for {net_name} #{aug_idx}")
                continue
            # Append all four rows together so the lists stay aligned.
            augmented_features.append({name: features.get(name, 0.0) for name in FEATURE_NAMES})
            augmented_shds.append(shd_row)
            augmented_nshds.append(nshd_row)
            augmented_configs.append({
                'network': f'{net_name}_sub{aug_idx}',
                'n_samples': n_samples,
                'seed': seed,
                'n_variables': n_to_keep,
                'n_true_edges': n_true_edges,
            })
    return augmented_features, augmented_shds, augmented_nshds, augmented_configs
def hyperparameter_sweep(model_configs=None):
    """Evaluate several meta-learner configurations with LONO CV and save the best.

    Each configuration is scored by ``evaluate_lono_cv`` (k=3) on the
    normalized-SHD targets. The one with the highest top-3 hit rate wins,
    with the earliest config winning ties. It is retrained on all data with
    ``train_meta_learner`` and persisted via ``save_model``, and its top-10
    feature importances are printed.

    Args:
        model_configs: optional list of ``(display_name, model_type, kwargs)``
            tuples. Defaults to the built-in RF/GBM grid below.

    Returns:
        ``(best_name, best_type, best_kwargs)`` of the selected configuration.

    Raises:
        ValueError: if ``model_configs`` is an empty list.
        RuntimeError: if no configuration produced a comparable hit rate
            (e.g. all NaN).
    """
    if model_configs is None:
        model_configs = [
            ('RF-200', 'rf', {'n_estimators': 200}),
            ('RF-500', 'rf', {'n_estimators': 500}),
            ('RF-200-d10', 'rf', {'n_estimators': 200, 'max_depth': 10}),
            ('RF-200-d5', 'rf', {'n_estimators': 200, 'max_depth': 5}),
            ('RF-200-leaf5', 'rf', {'n_estimators': 200, 'min_samples_leaf': 5}),
            ('GBM-200', 'gbm', {'n_estimators': 200, 'max_depth': 5, 'learning_rate': 0.1}),
            ('GBM-500', 'gbm', {'n_estimators': 500, 'max_depth': 3, 'learning_rate': 0.05}),
            ('GBM-200-lr01', 'gbm', {'n_estimators': 200, 'max_depth': 4, 'learning_rate': 0.01}),
        ]
    if not model_configs:
        raise ValueError("model_configs must contain at least one configuration")
    X, Y_shd, Y_nshd, configs = load_meta_dataset()
    print(f"Data: {X.shape[0]} samples, {X.shape[1]} features, {Y_nshd.shape[1]} algorithms")
    print(f"Networks: {sorted(configs.network.unique())}")
    print(f"\n{'Model':20s} {'Top3 Hit':>10s} {'NDCG@3':>8s} {'Regret':>8s} {'Overlap':>8s}")
    print("-" * 60)
    # Start below any real hit rate. The previous 0 start left best_* as None
    # when every config scored 0.0, making train_meta_learner(**None) crash.
    best = None
    best_hit = float('-inf')
    for name, mtype, kwargs in model_configs:
        results = evaluate_lono_cv(X, Y_nshd, configs, model_type=mtype, k=3, **kwargs)
        o = results['overall']
        print(f"{name:20s} {o['top_k_hit_rate']:10.3f} {o['ndcg_at_k']:8.3f} "
              f"{o['mean_regret']:8.4f} {o['top_k_overlap_rate']:8.3f}")
        if o['top_k_hit_rate'] > best_hit:
            best_hit = o['top_k_hit_rate']
            best = (name, mtype, kwargs)
    if best is None:
        raise RuntimeError("No model configuration produced a comparable top-k hit rate")
    best_name, best_type, best_kwargs = best
    print(f"\nBest model: {best_name} (hit rate={best_hit:.3f})")
    # Train the winning configuration on all data and save it
    model, scaler = train_meta_learner(X, Y_nshd, model_type=best_type, **best_kwargs)
    save_model(model, scaler)
    avg_imp, _ = get_feature_importance(model)
    print("\nTop 10 Features (best model):")
    for feat, imp in sorted(avg_imp.items(), key=lambda x: -x[1])[:10]:
        print(f" {feat:30s}: {imp:.4f}")
    return best_name, best_type, best_kwargs
def _augment_and_save():
    """Run variable-subsampling augmentation and merge it into the saved meta-dataset.

    Augments a fixed set of six networks (2 subsets each, 30% of variables
    dropped, 1000 samples). Rows from earlier runs that carry the same
    augmented network names (``<net>_sub<i>``) are removed first, so
    re-running the augmentation replaces them instead of duplicating them.
    The merged tables overwrite the four CSV files in ``RESULTS_DIR``.

    Returns:
        ``(n_total, n_original, n_augmented)`` row counts after merging.
    """
    feats, shds, nshds, cfgs = augment_variable_subsampling(
        networks=['asia', 'sachs', 'alarm', 'child', 'insurance', 'water'],
        n_augments_per_net=2, drop_frac=0.3, n_samples=1000
    )
    X_orig, Y_shd_orig, Y_nshd_orig, configs_orig = load_meta_dataset()
    X_aug = pd.DataFrame(feats, columns=FEATURE_NAMES)
    Y_shd_aug = pd.DataFrame(shds, columns=ALGO_NAMES)
    Y_nshd_aug = pd.DataFrame(nshds, columns=ALGO_NAMES)
    configs_aug = pd.DataFrame(cfgs)
    if len(configs_aug):
        # Assumes the four loaded tables are row-aligned, as the concat below
        # already does.
        stale = configs_orig['network'].isin(configs_aug['network']).to_numpy()
        if stale.any():
            logger.info(f"Replacing {int(stale.sum())} previously augmented rows")
            keep = ~stale
            X_orig = X_orig.iloc[keep]
            Y_shd_orig = Y_shd_orig.iloc[keep]
            Y_nshd_orig = Y_nshd_orig.iloc[keep]
            configs_orig = configs_orig.iloc[keep]
    X_all = pd.concat([X_orig, X_aug], ignore_index=True)
    Y_shd_all = pd.concat([Y_shd_orig, Y_shd_aug], ignore_index=True)
    Y_nshd_all = pd.concat([Y_nshd_orig, Y_nshd_aug], ignore_index=True)
    configs_all = pd.concat([configs_orig, configs_aug], ignore_index=True)
    X_all.to_csv(os.path.join(RESULTS_DIR, 'meta_features.csv'), index=False)
    Y_shd_all.to_csv(os.path.join(RESULTS_DIR, 'shd_matrix.csv'), index=False)
    Y_nshd_all.to_csv(os.path.join(RESULTS_DIR, 'normalized_shd_matrix.csv'), index=False)
    configs_all.to_csv(os.path.join(RESULTS_DIR, 'configs.csv'), index=False)
    return len(configs_all), len(configs_orig), len(configs_aug)


if __name__ == '__main__':
    # Usage: python <script> [augment | sweep | all]   (default: sweep)
    mode = sys.argv[1] if len(sys.argv) > 1 else 'sweep'
    if mode == 'augment':
        n_total, n_orig, n_aug = _augment_and_save()
        print(f"\nAugmented dataset: {n_total} total configs ({n_orig} original + {n_aug} augmented)")
    elif mode == 'sweep':
        hyperparameter_sweep()
    elif mode == 'all':
        # First augment, then sweep
        print("=" * 80)
        print("STEP 1: DATA AUGMENTATION")
        print("=" * 80)
        n_total, _, _ = _augment_and_save()
        print(f"\nAugmented: {n_total} configs")
        print("\n" + "=" * 80)
        print("STEP 2: HYPERPARAMETER SWEEP")
        print("=" * 80)
        hyperparameter_sweep()
    else:
        sys.exit(f"Unknown mode {mode!r}; expected one of: augment, sweep, all")