# Causal Discovery Algorithm Selection Meta-Learner

A meta-learning system that predicts the **top-3 best causal discovery algorithms** for any discrete observational dataset, based on dataset meta-features.

## 🎯 What It Does

Given a new discrete dataset (pandas DataFrame), the system:
1. **Extracts 34 meta-features** (entropy, mutual information, chi² statistics, CI test probes, etc.)
2. **Predicts normalized SHD** for each of 9 algorithms via trained models
3. **Ranks and returns the top-3** algorithms expected to produce the most accurate CPDAG
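The three steps above can be sketched end-to-end. Everything here is a toy stand-in: the feature extractor computes only a subset of the real 34 meta-features, and the per-algorithm "models" are hypothetical lambdas in place of the trained regressors.

```python
import math
from collections import Counter

def entropy(column):
    """Shannon entropy (bits) of one discrete column."""
    n = len(column)
    return -sum(c / n * math.log2(c / n) for c in Counter(column).values())

def extract_meta_features(rows):
    """Step 1: dataset -> meta-feature vector (toy subset of the real 34)."""
    cols = list(zip(*rows))
    return {
        "n_variables": len(cols),
        "n_samples": len(rows),
        "max_entropy": max(entropy(c) for c in cols),
    }

def rank_algorithms(features, models, k=3):
    """Steps 2-3: predict normalized SHD per algorithm, return the k lowest."""
    predicted_shd = {name: model(features) for name, model in models.items()}
    return sorted(predicted_shd, key=predicted_shd.get)[:k]

# Toy stand-ins for the trained per-algorithm SHD models.
models = {
    "GES": lambda f: 0.10 + 0.001 * f["n_variables"],
    "PC":  lambda f: 0.12 + 0.002 * f["n_variables"],
    "HC":  lambda f: 0.30,
    "K2":  lambda f: 0.25,
}
rows = [(0, 1, 1), (1, 0, 1), (0, 0, 0), (1, 1, 0)]
top3 = rank_algorithms(extract_meta_features(rows), models)
print(top3)  # ['GES', 'PC', 'K2']
```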

## 📊 Performance (Leave-One-Network-Out Cross-Validation)

### Best Model: Pairwise-GBM Ranking

| Metric | Value |
|--------|-------|
| **Top-3 Hit Rate** | **71.3%** (true best algorithm is in predicted top-3) |
| **Mean Regret** | **0.011** (tiny SHD gap vs. oracle selection) |
| **Median Regret** | **0.000** (majority of predictions are perfect) |
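For reference, both headline metrics can be computed from a per-config SHD table like this (a sketch; `true_shd` maps algorithm name to normalized SHD on one dataset, and regret is measured against the oracle best):

```python
def top3_hit(true_shd, predicted_top3):
    """Did the true best algorithm (lowest SHD) land in the predicted top-3?"""
    return min(true_shd, key=true_shd.get) in predicted_top3

def regret(true_shd, chosen):
    """Normalized-SHD gap between the chosen algorithm and the oracle best."""
    return true_shd[chosen] - min(true_shd.values())

true_shd = {"GES": 0.05, "PC": 0.08, "HC": 0.20}
print(top3_hit(true_shd, ["PC", "GES", "HC"]))   # True
print(round(regret(true_shd, "PC"), 3))          # 0.03
```

A median regret of 0.000 then simply means that on most held-out configs the selected algorithm ties the oracle.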

### Model Comparison (178 configs, 14 networks + augmented)

| Model | Top-3 Hit Rate | NDCG@3 | Mean Regret |
|-------|---------------|--------|-------------|
| **Pairwise-GBM** | **71.3%** | — | 0.011 |
| GBM-300-lr01 | 67.4% | 0.957 | 0.011 |
| RF-200 | 66.9% | 0.961 | 0.007 |
| RF-500 | 66.3% | 0.962 | 0.007 |
| GBM-500-lr05 | 65.2% | 0.948 | 0.013 |

### Progression

| Stage | Configs | Networks | Top-3 Hit Rate |
|-------|---------|----------|---------------|
| Initial (small nets) | 65 | 4 | 68.2% |
| All 14 networks | 122 | 14 | 70.5% |
| + Data augmentation | 178 | 14+aug | **71.3%** |

## 🧪 Algorithm Pool (9 algorithms)

| Algorithm | Family | Library | Output | Wins |
|-----------|--------|---------|--------|------|
| **GES** | Score-based | causal-learn | CPDAG | 47% |
| **PC** | Constraint-based | causal-learn | CPDAG | 32% |
| **FCI** | Constraint-based | causal-learn | PAG | 8% |
| **K2** | Score-based | pgmpy | DAG | 6% |
| **HC** | Score-based (greedy) | pgmpy | DAG | 3% |
| **Tabu** | Score-based (meta) | pgmpy | DAG | 2% |
| **GRaSP** | Permutation-based | causal-learn | CPDAG | 1% |
| **BOSS** | Permutation-based | causal-learn | CPDAG | 1% |
| **MMHC** | Hybrid | pgmpy | DAG | <1% |

## 🔬 Key Insight: Dependency Parsing Connection

This project was inspired by a structural parallel between **NLP dependency parsing** and **causal discovery**:
- Both predict **directed graphs** over nodes (words/variables)
- Both have **ground-truth annotations** (treebanks/bnlearn networks)
- Both use **arc-level evaluation** (UAS/LAS ↔ SHD/F1)

The biaffine pairwise scoring mechanism from Dozat & Manning (2017) was independently reinvented by AVICI and CauScale for causal structure learning — validating this connection.

### Top Predictive Meta-Features
1. `n_variables` (30%) — network size (how many nodes in the graph)
2. `max_pairwise_MI` (24%) — strongest pairwise dependency (≈ biaffine arc score)
3. `max_cramers_v` (8%) — strongest association strength
4. `max_entropy` (7%) — variable complexity
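To make the "arc score" analogy concrete, `max_pairwise_MI` can be computed from scratch for discrete columns. This is an illustrative sketch only; the real `extractor.py` may differ in estimator details.

```python
import math
from collections import Counter
from itertools import combinations

def mutual_information(x, y):
    """MI (bits) between two discrete columns -- the pair's 'arc score'."""
    n = len(x)
    px, py, pxy = Counter(x), Counter(y), Counter(zip(x, y))
    return sum(c / n * math.log2((c / n) / ((px[a] / n) * (py[b] / n)))
               for (a, b), c in pxy.items())

def max_pairwise_mi(columns):
    """The max_pairwise_MI meta-feature: strongest pairwise dependency."""
    return max(mutual_information(columns[i], columns[j])
               for i, j in combinations(range(len(columns)), 2))

cols = [
    [0, 0, 1, 1],  # X
    [0, 0, 1, 1],  # Y = X, so MI(X, Y) = H(X) = 1 bit
    [0, 1, 0, 1],  # Z independent of X and Y, so MI = 0
]
print(round(max_pairwise_mi(cols), 3))  # 1.0
```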

### Three Ideas Borrowed from Parsing
1. **Biaffine-style pairwise features**: MI and Cramér's V between all variable pairs ≈ parsing's arc scores
2. **Pairwise ranking** (our best model): for each algorithm pair (A, B), predict which wins → count wins to rank; inspired by tournament-style pairwise parser selection
3. **Cross-domain transfer**: train on well-characterized bnlearn networks → predict on new unseen datasets (= cross-lingual parser transfer)
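The win-counting step of idea 2 can be sketched as a round-robin tournament. The real system trains a GBM per pairwise comparison; here `beats` is a toy stand-in for that model.

```python
from itertools import combinations

def rank_by_wins(algorithms, beats):
    """Rank algorithms by total wins over all pairwise match-ups.

    `beats(a, b)` -> True if the pairwise model predicts a outperforms b
    (lower SHD) on this dataset.
    """
    wins = {a: 0 for a in algorithms}
    for a, b in combinations(algorithms, 2):
        wins[a if beats(a, b) else b] += 1
    return sorted(algorithms, key=lambda a: -wins[a])

# Toy pairwise predictor: an earlier name in this list beats a later one.
order = ["GES", "PC", "FCI", "HC"]
beats = lambda a, b: order.index(a) < order.index(b)
print(rank_by_wins(["HC", "FCI", "GES", "PC"], beats)[:3])  # ['GES', 'PC', 'FCI']
```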

## 🚀 Quick Start

```python
from causal_selection.meta_learner.predictor import predict_best_algorithms
import pandas as pd

# Load your discrete dataset
df = pd.read_csv("my_discrete_data.csv")

# Get top-3 recommendations: prints the ranked algorithms
# with predicted accuracy and confidence
result = predict_best_algorithms(df, k=3)
```

## 📁 Project Structure

```
causal_selection/
├── data/
│   ├── generator.py           # Load bnlearn networks, sample data, DAG→CPDAG
│   ├── bif_files/             # 14 bnlearn BIF files (asia through win95pts)
│   └── results/               # Benchmark CSVs: meta-features, SHD matrices
├── discovery/
│   ├── algorithms.py          # 9 algorithm adapters with timeout handling
│   └── evaluator.py           # SHD, F1, Precision, Recall computation
├── features/
│   └── extractor.py           # 34 meta-features across 5 tiers
├── meta_learner/
│   ├── trainer.py             # Multi-Output RF/GBM + LONO-CV evaluation
│   └── predictor.py           # Inference: dataset → top-3 prediction
├── models/
│   ├── meta_learner.pkl       # Trained GBM (multi-output fallback)
│   ├── pairwise_model.pkl     # Pairwise ranking GBM (best model)
│   └── scaler.pkl             # Feature scaler
├── benchmark.py               # Full benchmark orchestration
├── run_benchmark.py           # Resumable benchmark runner
└── augment_and_improve.py     # Data augmentation + model improvement
```
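As background for what `evaluator.py` reports, here is one common SHD definition, sketched under an assumption: variants differ on whether a reversed edge counts as one error or two, and this version counts it as one.

```python
def shd(pred, true):
    """Structural Hamming Distance between two graphs given as adjacency
    matrices (entry [i][j] = 1 means a directed edge i -> j). Each node
    pair whose edge status differs (missing, extra, or reversed) adds
    one error."""
    n = len(true)
    errors = 0
    for i in range(n):
        for j in range(i + 1, n):
            if (pred[i][j], pred[j][i]) != (true[i][j], true[j][i]):
                errors += 1
    return errors

true = [[0, 1, 0],
        [0, 0, 1],
        [0, 0, 0]]   # X -> Y -> Z
pred = [[0, 0, 0],
        [1, 0, 1],
        [0, 0, 0]]   # Y -> X (reversed), Y -> Z (correct)
print(shd(pred, true))  # 1
```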

## 📈 Benchmark Data

- **14 bnlearn networks**: asia, cancer, earthquake, sachs, survey, alarm, barley, child, insurance, mildew, water, hailfinder, hepar2, win95pts
- **178 dataset configs**: 122 original + 56 augmented (variable subsampling, sample-size variation, noise injection)
- **1,600+ algorithm runs**: 9 algorithms × 178 configs with per-algorithm timeout

### Data Augmentation Strategies
- **Variable subsampling**: drop 20–40% of variables to create virtual sub-networks
- **Sample-size variation**: generate N = 300, 750, 1500, 3000 for each network
- **Noise injection**: randomly flip 5–10% of categorical values
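The noise-injection strategy might look like the following sketch. The resample-within-column choice and the exact flip mechanics are assumptions for illustration, not the project's actual `augment_and_improve.py` code.

```python
import numpy as np
import pandas as pd

def inject_noise(df, flip_frac=0.10, seed=0):
    """Replace ~flip_frac of each categorical column's cells with values
    resampled from that column's observed levels."""
    rng = np.random.default_rng(seed)
    noisy = df.copy()
    for col in noisy.columns:
        mask = rng.random(len(noisy)) < flip_frac      # cells chosen to flip
        levels = df[col].unique()
        noisy.loc[mask, col] = rng.choice(levels, size=int(mask.sum()))
    return noisy

df = pd.DataFrame({"A": ["yes", "no"] * 500, "B": ["lo", "hi"] * 500})
noisy = inject_noise(df)
changed = (noisy != df).to_numpy().mean()
print(f"{changed:.1%} of cells changed")  # below 10%: a resample may keep the old value
```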

## 🔧 Dependencies

```
causal-learn>=0.1.4
pgmpy>=0.1.25
scikit-learn>=1.8
pandas
numpy
scipy
joblib
```

## 📚 References

- **Causal-Copilot** (arXiv:2504.13263) — closest existing algorithm-selection system
- **AVICI** (arXiv:2205.12934) — amortized causal structure learning (biaffine architecture)
- **CauScale** (arXiv:2602.08629) — scalable neural causal discovery
- **Dozat & Manning** (arXiv:1611.01734) — deep biaffine attention for dependency parsing
- **TreeCRF** (arXiv:2005.00975) — global structural training loss for parsing
- **SATzilla** (arXiv:1401.2474) — algorithm selection via meta-learning
- **bnlearn** (bnlearn.com) — Bayesian network benchmark repository

## 🔮 Future Work (Phase 2)
1. **Biaffine neural encoder**: pre-train a neural feature extractor that learns variable-pair "arc scores"
2. **Portfolio regret loss** (TreeCRF-inspired): global ranking optimization instead of per-algorithm MSE
3. **Hyperparameter co-selection**: predict not just which algorithm but also its optimal hyperparameters (CASH)
4. **Ensemble prediction**: run the top-3 and vote on edges across their CPDAGs

## License

MIT