# Causal Discovery Algorithm Selection Meta-Learner
A meta-learning system that predicts the three best-suited causal discovery algorithms for a given discrete observational dataset, based on dataset meta-features.
## 🎯 What it Does
Given a new discrete dataset (pandas DataFrame), the system:
- Extracts 34 meta-features (entropy, mutual information, chi² statistics, CI test probes, etc.)
- Predicts the normalized Structural Hamming Distance (SHD) for each of 9 algorithms via trained models
- Ranks and returns the top-3 algorithms expected to produce the most accurate CPDAG
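The meta-feature step can be sketched as follows. This is an illustrative stand-in for two of the simpler features (`n_variables`, `max_entropy`), not the project's actual `features/extractor.py` API:

```python
from collections import Counter
from math import log2

import pandas as pd

def column_entropy(s: pd.Series) -> float:
    """Shannon entropy (in bits) of a discrete column."""
    n = len(s)
    return -sum((c / n) * log2(c / n) for c in Counter(s).values())

def extract_meta_features(df: pd.DataFrame) -> dict:
    """A few dataset-level meta-features of the kind the system uses."""
    ents = [column_entropy(df[c]) for c in df.columns]
    return {
        "n_variables": df.shape[1],
        "n_samples": df.shape[0],
        "max_entropy": max(ents),
        "mean_entropy": sum(ents) / len(ents),
    }

# Toy dataset: two maximally uncertain binary columns, one constant column
df = pd.DataFrame({"a": [0, 0, 1, 1], "b": [0, 1, 0, 1], "c": [0, 0, 0, 0]})
features = extract_meta_features(df)
```

The real extractor adds pairwise and conditional-independence probes on top of these per-column statistics.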
## 📊 Performance (Leave-One-Network-Out Cross-Validation)

**Best Model: Pairwise-GBM Ranking**
| Metric | Value |
|---|---|
| Top-3 Hit Rate | 71.3% (true best algorithm is in predicted top-3) |
| Mean Regret | 0.011 (tiny SHD gap vs oracle selection) |
| Median Regret | 0.000 (majority of predictions are perfect) |
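As a reference for how these two metrics relate, here is a minimal sketch of computing top-k hit rate and regret from an SHD matrix; the function name and array layout are assumptions for illustration, not the project's evaluation code:

```python
import numpy as np

def selection_metrics(shd, ranking, k=3):
    """shd: (n_datasets, n_algorithms) matrix of true normalized SHDs.
    ranking: (n_datasets, n_algorithms) predicted algorithm indices, best first.
    Returns (top-k hit rate, mean regret, median regret)."""
    top_k = ranking[:, :k]
    # Hit: the truly best algorithm appears among the predicted top-k
    hits = [shd[i].argmin() in top_k[i] for i in range(len(shd))]
    # Regret: best achievable SHD within the predicted top-k minus the oracle SHD
    chosen = np.array([shd[i, top_k[i]].min() for i in range(len(shd))])
    regret = chosen - shd.min(axis=1)
    return float(np.mean(hits)), float(regret.mean()), float(np.median(regret))

shd = np.array([[0.1, 0.3, 0.2, 0.5],
                [0.4, 0.1, 0.3, 0.2]])
ranking = np.array([[2, 0, 1, 3],   # true best (algo 0) is inside the top-3
                    [3, 2, 0, 1]])  # true best (algo 1) is missed
hit, mean_r, med_r = selection_metrics(shd, ranking, k=3)
```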
### Model Comparison (178 configs, 14 networks + augmented)
| Model | Top-3 Hit Rate | NDCG@3 | Mean Regret |
|---|---|---|---|
| Pairwise-GBM | 71.3% | — | 0.011 |
| GBM-300-lr01 | 67.4% | 0.957 | 0.011 |
| RF-200 | 66.9% | 0.961 | 0.007 |
| RF-500 | 66.3% | 0.962 | 0.007 |
| GBM-500-lr05 | 65.2% | 0.948 | 0.013 |
### Progression
| Stage | Configs | Networks | Top-3 Hit Rate |
|---|---|---|---|
| Initial (small nets) | 65 | 4 | 68.2% |
| All 14 networks | 122 | 14 | 70.5% |
| + Data augmentation | 178 | 14+aug | 71.3% |
## 🧪 Algorithm Pool (9 algorithms)
| Algorithm | Family | Library | Output | Wins |
|---|---|---|---|---|
| GES | Score-based | causal-learn | CPDAG | 47% |
| PC | Constraint-based | causal-learn | CPDAG | 32% |
| FCI | Constraint-based | causal-learn | PAG | 8% |
| K2 | Score-based | pgmpy | DAG | 6% |
| HC | Score-based (greedy) | pgmpy | DAG | 3% |
| Tabu | Score-based (meta) | pgmpy | DAG | 2% |
| GRaSP | Permutation-based | causal-learn | CPDAG | 1% |
| BOSS | Permutation-based | causal-learn | CPDAG | 1% |
| MMHC | Hybrid | pgmpy | DAG | <1% |
## 🔬 Key Insight: Dependency Parsing Connection
This project was inspired by a structural parallel between NLP dependency parsing and causal discovery:
- Both predict directed graphs over nodes (words/variables)
- Both have ground-truth annotations (treebanks/bnlearn networks)
- Both use arc-level evaluation (UAS/LAS ↔ SHD/F1)
The biaffine pairwise scoring mechanism from Dozat & Manning (2017) was independently reinvented by AVICI and CauScale for causal structure learning — validating this connection.
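To make the arc-level analogy concrete, here is a simplified SHD over directed adjacency matrices: like UAS, it scores the graph one arc at a time. This counts additions, deletions, and reversals once each, and is a sketch rather than the project's `evaluator.py` implementation (which also handles CPDAG undirected edges):

```python
import numpy as np

def shd(pred, true):
    """Structural Hamming Distance between two directed adjacency matrices:
    one count per variable pair whose edge status (absent / i->j / j->i) differs."""
    n = pred.shape[0]
    diff = 0
    for i in range(n):
        for j in range(i + 1, n):
            if (pred[i, j], pred[j, i]) != (true[i, j], true[j, i]):
                diff += 1
    return diff

# True graph: 0 -> 1 -> 2.  Predicted: 0 -> 1, 2 -> 1 (reversal), 0 -> 2 (extra edge).
true = np.zeros((3, 3), dtype=int); true[0, 1] = 1; true[1, 2] = 1
pred = np.zeros((3, 3), dtype=int); pred[0, 1] = 1; pred[2, 1] = 1; pred[0, 2] = 1
distance = shd(pred, true)
```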
### Top Predictive Meta-Features

- `n_variables` (30%) — network size (how many nodes in the graph)
- `max_pairwise_MI` (24%) — strongest pairwise dependency (≈ biaffine arc score)
- `max_cramers_v` (8%) — strongest association strength
- `max_entropy` (7%) — variable complexity
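A minimal sketch of the single most informative pairwise feature, `max_pairwise_MI`, computed from contingency tables; the helper names are illustrative, not the extractor's real API:

```python
import numpy as np
import pandas as pd

def mutual_information(x, y):
    """Mutual information (nats) between two discrete columns."""
    joint = pd.crosstab(x, y).to_numpy() / len(x)
    px = joint.sum(axis=1, keepdims=True)
    py = joint.sum(axis=0, keepdims=True)
    nz = joint > 0  # skip empty cells to avoid log(0)
    return float((joint[nz] * np.log(joint[nz] / (px @ py)[nz])).sum())

def max_pairwise_mi(df):
    """Strongest pairwise dependency over all variable pairs (the 'arc score' max)."""
    cols = list(df.columns)
    return max(mutual_information(df[a], df[b])
               for i, a in enumerate(cols) for b in cols[i + 1:])

# 'a' and 'b' are identical (MI = ln 2); 'c' is independent of both (MI = 0)
df = pd.DataFrame({"a": [0, 0, 1, 1], "b": [0, 0, 1, 1], "c": [0, 1, 0, 1]})
strongest = max_pairwise_mi(df)
```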
### Three Ideas Borrowed from Parsing
- Biaffine-style pairwise features: MI and Cramér's V between all variable pairs = parsing's arc scores
- Pairwise ranking (our best model): For each algorithm pair (A,B), predict which wins → count wins to rank. Inspired by pairwise tournament-style parser selection
- Cross-domain transfer: Train on well-characterized bnlearn networks → predict on new unseen datasets (= cross-lingual parser transfer)
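The second idea, ranking by pairwise win counts, reduces to a few lines once a pairwise winner predictor exists. In the real system that predictor is a GBM classifier over meta-features; here it is a toy lookup, and `ALGOS`/`rank_by_wins` are hypothetical names:

```python
from itertools import combinations

ALGOS = ["GES", "PC", "FCI"]  # illustrative subset of the 9-algorithm pool

def rank_by_wins(pairwise_winner):
    """pairwise_winner(a, b) -> name of the predicted winner for matchup (a, b).
    Rank algorithms by total predicted wins (a Borda-style count)."""
    wins = {a: 0 for a in ALGOS}
    for a, b in combinations(ALGOS, 2):
        wins[pairwise_winner(a, b)] += 1
    return sorted(ALGOS, key=lambda a: -wins[a])

# Toy oracle: GES beats everyone, PC beats FCI
beats = {("GES", "PC"): "GES", ("GES", "FCI"): "GES", ("PC", "FCI"): "PC"}
ranking = rank_by_wins(lambda a, b: beats[(a, b)])
```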
## 🚀 Quick Start

```python
from causal_selection.meta_learner.predictor import predict_best_algorithms
import pandas as pd

# Load your discrete dataset
df = pd.read_csv("my_discrete_data.csv")

# Get top-3 recommendations: ranked algorithms with predicted accuracy and confidence
result = predict_best_algorithms(df, k=3)
```
## 📁 Project Structure

```
causal_selection/
├── data/
│   ├── generator.py           # Load bnlearn networks, sample data, DAG→CPDAG
│   ├── bif_files/             # 14 bnlearn BIF files (asia through win95pts)
│   └── results/               # Benchmark CSVs: meta-features, SHD matrices
├── discovery/
│   ├── algorithms.py          # 9 algorithm adapters with timeout handling
│   └── evaluator.py           # SHD, F1, Precision, Recall computation
├── features/
│   └── extractor.py           # 34 meta-features across 5 tiers
├── meta_learner/
│   ├── trainer.py             # Multi-Output RF/GBM + LONO-CV evaluation
│   └── predictor.py           # Inference: dataset → top-3 prediction
├── models/
│   ├── meta_learner.pkl       # Trained GBM (multi-output fallback)
│   ├── pairwise_model.pkl     # Pairwise ranking GBM (best model)
│   └── scaler.pkl             # Feature scaler
├── benchmark.py               # Full benchmark orchestration
├── run_benchmark.py           # Resumable benchmark runner
└── augment_and_improve.py     # Data augmentation + model improvement
```
## 📈 Benchmark Data
- 14 bnlearn networks: asia, cancer, earthquake, sachs, survey, alarm, barley, child, insurance, mildew, water, hailfinder, hepar2, win95pts
- 178 dataset configs: 122 original + 56 augmented (variable subsampling, sample-size variation, noise injection)
- 1,600+ algorithm runs: 9 algorithms × 178 configs with per-algorithm timeout
### Data Augmentation Strategies
- Variable subsampling: Drop 20-40% of variables to create virtual sub-networks
- Sample-size variation: Generate N=300, 750, 1500, 3000 for each network
- Noise injection: Randomly flip 5-10% of categorical values
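The first and third strategies can be sketched as below. Function names are illustrative (not `augment_and_improve.py`'s API), and the noise sketch re-draws values from each column's observed categories rather than guaranteeing a flip to a *different* value:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

def subsample_variables(df, drop_frac=0.3):
    """Drop a fraction of columns to create a virtual sub-network."""
    n_keep = max(2, int(len(df.columns) * (1 - drop_frac)))
    keep = rng.choice(df.columns, size=n_keep, replace=False)
    return df[list(keep)]

def inject_noise(df, flip_frac=0.05):
    """Re-draw a fraction of each column's values from its observed categories."""
    out = df.copy()
    for col in out.columns:
        mask = rng.random(len(out)) < flip_frac
        out.loc[mask, col] = rng.choice(df[col].unique(), size=mask.sum())
    return out

df = pd.DataFrame(rng.integers(0, 3, size=(100, 5)), columns=list("abcde"))
sub = subsample_variables(df, drop_frac=0.4)    # keeps 3 of 5 columns
noisy = inject_noise(df, flip_frac=0.1)
```

Sample-size variation is simpler still: re-sample each network at the target N (300, 750, 1500, 3000).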
## 🔧 Dependencies

```
causal-learn>=0.1.4
pgmpy>=0.1.25
scikit-learn>=1.8
pandas
numpy
scipy
joblib
```
## 📚 References

- Causal-Copilot (arXiv:2504.13263) — closest existing algorithm selection system
- AVICI (arXiv:2205.12934) — amortized causal structure learning (biaffine architecture)
- CauScale (arXiv:2602.08629) — scalable neural causal discovery
- Dozat & Manning (arXiv:1611.01734) — deep biaffine attention for dependency parsing
- TreeCRF (arXiv:2005.00975) — global structural training loss for parsing
- SATzilla (arXiv:1401.2474) — algorithm selection via meta-learning
- bnlearn (bnlearn.com) — Bayesian network benchmark repository
## 🔮 Future Work (Phase 2)
- Biaffine neural encoder: Pre-train a neural feature extractor that learns variable-pair "arc scores"
- Portfolio regret loss (TreeCRF-inspired): Global ranking optimization instead of per-algorithm MSE
- Hyperparameter co-selection: Predict not just which algorithm but optimal hyperparameters (CASH)
- Ensemble prediction: Run top-3 and vote on edges across their CPDAGs
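One possible realization of the ensemble idea is skeleton-level majority voting over the top-3 models' adjacency matrices; this is a speculative sketch of the future-work item, not implemented code:

```python
import numpy as np

def vote_edges(adjs, threshold=0.5):
    """Majority vote over skeleton edges of several adjacency matrices.
    An undirected edge survives if at least `threshold` of the graphs contain it
    (in either orientation)."""
    skeletons = [((a + a.T) > 0).astype(float) for a in adjs]
    votes = np.mean(skeletons, axis=0)
    return (votes >= threshold).astype(int)

g1 = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])  # edges 0-1, 1-2
g2 = np.array([[0, 0, 0], [1, 0, 0], [0, 0, 0]])  # edge 0-1 (reversed)
g3 = np.array([[0, 1, 0], [0, 0, 0], [0, 0, 0]])  # edge 0-1
consensus = vote_edges([g1, g2, g3])  # only 0-1 reaches the 50% threshold
```

Orientation would then need a separate resolution step (e.g. keeping only directions that a majority of the voting CPDAGs agree on).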
## License
MIT