# Causal Discovery Algorithm Selection Meta-Learner
A meta-learning system that predicts the **top-3 best causal discovery algorithms** for any discrete observational dataset, based on dataset meta-features.
## 🎯 What it Does
Given a new discrete dataset (pandas DataFrame), the system:
1. **Extracts 34 meta-features** (entropy, mutual information, chi² statistics, CI test probes, etc.)
2. **Predicts normalized SHD** for each of 9 algorithms via trained models
3. **Ranks and returns the top-3** algorithms expected to produce the most accurate CPDAG
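Steps 2 and 3 revolve around normalized SHD. The sketch below shows one way to compute it between two CPDAGs and to turn per-algorithm predictions into a top-k list. Several details are assumptions for illustration and may differ from `evaluator.py`: the adjacency encoding, the SHD convention (one unit per node pair whose edge status differs), and normalization by the number of node pairs.

```python
import numpy as np


def edge_state(A: np.ndarray, i: int, j: int) -> str:
    """Classify the relation between nodes i and j in a CPDAG adjacency matrix.

    Assumed encoding: A[i, j] == 1 and A[j, i] == 0 means i -> j;
    A[i, j] == A[j, i] == 1 means an undirected edge i - j.
    """
    a, b = A[i, j], A[j, i]
    if a and b:
        return "undirected"
    if a:
        return "forward"
    if b:
        return "backward"
    return "none"


def shd(true_cpdag: np.ndarray, est_cpdag: np.ndarray) -> int:
    """Count node pairs whose edge state differs (missing, extra, or wrong mark)."""
    n = true_cpdag.shape[0]
    return sum(
        edge_state(true_cpdag, i, j) != edge_state(est_cpdag, i, j)
        for i in range(n)
        for j in range(i + 1, n)
    )


def normalized_shd(true_cpdag: np.ndarray, est_cpdag: np.ndarray) -> float:
    """Hypothetical normalization: SHD divided by the number of node pairs."""
    n = true_cpdag.shape[0]
    pairs = n * (n - 1) // 2
    return shd(true_cpdag, est_cpdag) / pairs if pairs else 0.0


def top_k(pred_nshd: dict, k: int = 3) -> list:
    """Lower predicted normalized SHD is better, so sort ascending."""
    return sorted(pred_nshd, key=pred_nshd.get)[:k]
```

For example, `top_k({"GES": 0.10, "PC": 0.20, "FCI": 0.30, "K2": 0.05})` returns `["K2", "GES", "PC"]`.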
## 📊 Performance (Leave-One-Network-Out Cross-Validation)
### Best Model: Pairwise-GBM Ranking
| Metric | Value |
|--------|-------|
| **Top-3 Hit Rate** | **71.3%** (true best algorithm is in predicted top-3) |
| **Mean Regret** | **0.011** (small normalized-SHD gap relative to oracle selection) |
| **Median Regret** | **0.000** (at least half of the configs incur zero regret) |
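A minimal sketch of how the two headline metrics could be computed from a matrix of true normalized SHD and a matrix of predicted scores. The function name is hypothetical. Regret is assumed here to mean the true normalized SHD of the predicted #1 algorithm minus that of the oracle best; the project may instead measure it on the best of the predicted top-3.

```python
import numpy as np


def top3_hit_rate_and_regret(true_nshd: np.ndarray, pred_score: np.ndarray, k: int = 3):
    """Both inputs have shape (n_configs, n_algorithms); lower values are better."""
    pred_rank = np.argsort(pred_score, axis=1)  # predicted order per config, best first
    true_best = np.argmin(true_nshd, axis=1)    # oracle choice per config
    rows = np.arange(len(true_best))
    # Hit: the oracle's algorithm appears among the first k predicted algorithms.
    hits = np.array([true_best[c] in pred_rank[c, :k] for c in rows])
    # Assumed regret: true nSHD of the predicted #1 minus the oracle's nSHD.
    regret = true_nshd[rows, pred_rank[:, 0]] - true_nshd[rows, true_best]
    return hits.mean(), regret.mean(), np.median(regret)
```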
### Model Comparison (178 configs, 14 networks + augmented)
| Model | Top-3 Hit Rate | NDCG@3 | Mean Regret |
|-------|---------------|--------|-------------|
| **Pairwise-GBM** | **71.3%** | — | 0.011 |
| GBM-300-lr01 | 67.4% | 0.957 | 0.011 |
| RF-200 | 66.9% | 0.961 | 0.007 |
| RF-500 | 66.3% | 0.962 | 0.007 |
| GBM-500-lr05 | 65.2% | 0.948 | 0.013 |
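NDCG@3 can be sketched with scikit-learn's `ndcg_score`. The relevance definition used here (1 minus the true normalized SHD, clipped at zero) is an assumption, since the README does not state how relevance was derived.

```python
import numpy as np
from sklearn.metrics import ndcg_score


def ndcg_at_3(true_nshd: np.ndarray, pred_nshd: np.ndarray) -> float:
    """Inputs have shape (n_configs, n_algorithms); lower normalized SHD is better."""
    # Assumed relevance: 1 - normalized SHD, clipped so it stays non-negative.
    relevance = np.clip(1.0 - true_nshd, 0.0, None)
    # ndcg_score ranks by "higher is better", so negate the predicted SHD.
    return ndcg_score(relevance, -pred_nshd, k=3)
```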
### Progression
| Stage | Configs | Networks | Top-3 Hit Rate |
|-------|---------|----------|---------------|
| Initial (small nets) | 65 | 4 | 68.2% |
| All 14 networks | 122 | 14 | 70.5% |
| + Data augmentation | 178 | 14+aug | **71.3%** |
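Leave-one-network-out CV maps directly onto scikit-learn's `LeaveOneGroupOut`, with the source network as the group. The sketch below is an assumption about how `trainer.py` might do it, using the RF-200 setting from the comparison table. It also assumes each augmented config keeps its parent network's label, so that variants of the held-out network never appear in training.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import LeaveOneGroupOut


def lono_predictions(X: np.ndarray, Y: np.ndarray, networks: np.ndarray) -> np.ndarray:
    """Out-of-fold predictions where each fold holds out one whole network.

    X: (n_configs, n_features) meta-features
    Y: (n_configs, n_algorithms) normalized SHD targets
    networks: (n_configs,) source-network label per config
    """
    oof = np.zeros_like(Y, dtype=float)
    for train_idx, test_idx in LeaveOneGroupOut().split(X, Y, groups=networks):
        model = RandomForestRegressor(n_estimators=200, random_state=0)
        model.fit(X[train_idx], Y[train_idx])  # random forests handle multi-output targets natively
        oof[test_idx] = model.predict(X[test_idx])
    return oof
```

The out-of-fold matrix can then be fed to any of the ranking metrics above, so every score reflects a network the model never saw.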
## 🧪 Algorithm Pool (9 algorithms)
| Algorithm | Family | Library | Output | Wins (% of configs where best) |
|-----------|--------|---------|--------|------|
| **GES** | Score-based | causal-learn | CPDAG | 47% |
| **PC** | Constraint-based | causal-learn | CPDAG | 32% |
| **FCI** | Constraint-based | causal-learn | PAG | 8% |
| **K2** | Score-based | pgmpy | DAG | 6% |
| **HC** | Score-based (greedy) | pgmpy | DAG | 3% |
| **Tabu** | Score-based (meta) | pgmpy | DAG | 2% |
| **GRaSP** | Permutation-based | causal-learn | CPDAG | 1% |
| **BOSS** | Permutation-based | causal-learn | CPDAG | 1% |
| **MMHC** | Hybrid | pgmpy | DAG | <1% |
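Several algorithms in the pool return DAGs (and FCI a PAG), while accuracy is measured against CPDAGs. A standard way to convert a DAG into its CPDAG is to keep the skeleton, orient the v-structures, and then apply Meek's orientation rules R1 to R3 until nothing changes. The sketch below implements that procedure for the DAG case only. It is not necessarily what `generator.py` does, and the encoding (both `cp[i, j]` and `cp[j, i]` set means undirected) is an assumption.

```python
import itertools

import numpy as np


def dag_to_cpdag(dag: np.ndarray) -> np.ndarray:
    """dag[i, j] == 1 means i -> j. Returns a CPDAG in which a pair with both
    cp[i, j] and cp[j, i] set is undirected (assumed encoding)."""
    n = dag.shape[0]
    adj = (dag + dag.T) > 0   # skeleton
    cp = adj.astype(int)      # start with every skeleton edge undirected

    def orient(i, j):
        cp[j, i] = 0          # keep only i -> j

    def directed(i, j):
        return cp[i, j] == 1 and cp[j, i] == 0

    def undirected(i, j):
        return cp[i, j] == 1 and cp[j, i] == 1

    # 1) v-structures a -> c <- b with a and b non-adjacent
    for c in range(n):
        parents = np.flatnonzero(dag[:, c])
        for a, b in itertools.combinations(parents, 2):
            if not adj[a, b]:
                orient(a, c)
                orient(b, c)

    # 2) Meek rules R1-R3, repeated until no edge changes
    changed = True
    while changed:
        changed = False
        for i, j in itertools.permutations(range(n), 2):
            if not undirected(i, j):
                continue
            # R1: k -> i - j with k, j non-adjacent  =>  i -> j
            r1 = any(directed(k, i) and not adj[k, j] for k in range(n) if k != j)
            # R2: i -> k -> j with i - j  =>  i -> j
            r2 = any(directed(i, k) and directed(k, j) for k in range(n))
            # R3: i - k -> j and i - l -> j with k, l non-adjacent  =>  i -> j
            ks = [k for k in range(n) if undirected(i, k) and directed(k, j)]
            r3 = any(not adj[k, l] for k, l in itertools.combinations(ks, 2))
            if r1 or r2 or r3:
                orient(i, j)
                changed = True
    return cp
```

For a collider `0 -> 2 <- 1` with a child `2 -> 3`, all three edges come out directed; for a plain chain `0 -> 1 -> 2`, both edges stay undirected because the chain's Markov equivalence class also contains its reversal.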
## 🔬 Key Insight: Dependency Parsing Connection
This project was inspired by a structural parallel between **NLP dependency parsing** and **causal discovery**:
- Both predict **directed graphs** over nodes (words/variables)
- Both have **ground-truth annotations** (treebanks/bnlearn networks)
- Both use **arc-level evaluation** (UAS/LAS ↔ SHD/F1)
The biaffine pairwise scoring mechanism from Dozat & Manning (2017) has closely related counterparts, developed independently, in AVICI and CauScale for causal structure learning, which supports this connection.
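For readers unfamiliar with the parsing side, a biaffine arc scorer in the spirit of Dozat & Manning assigns every ordered node pair a score made of a bilinear term plus a head-only bias. The numpy sketch below uses illustrative shapes and random weights. It is not part of this project's code, which currently uses hand-crafted pairwise statistics rather than learned node representations.

```python
import numpy as np


def biaffine_scores(H_dep: np.ndarray, H_head: np.ndarray, U: np.ndarray, u: np.ndarray) -> np.ndarray:
    """S[i, j] = H_dep[i] @ U @ H_head[j] + H_head[j] @ u.

    H_dep, H_head: (n_nodes, d) representations of each node as dependent / head.
    U: (d, d) bilinear weights; u: (d,) head-prior weights.
    """
    bilinear = H_dep @ U @ H_head.T       # (n, n) pairwise interaction term
    head_bias = H_head @ u                # (n,) how plausible each node is as a head
    return bilinear + head_bias[None, :]  # broadcast the head bias across rows
```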
### Top Predictive Meta-Features (share of model feature importance)
1. `n_variables` (30%) — network size (how many nodes in the graph)
2. `max_pairwise_MI` (24%) — strongest pairwise dependency (≈ biaffine arc score)
3. `max_cramers_v` (8%) — strongest association strength
4. `max_entropy` (7%) — variable complexity
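The top pairwise features can be approximated with standard SciPy and scikit-learn calls. This sketch reuses the README's feature names, but the exact definitions in `extractor.py` are not shown here. The choices below (natural-log units, Cramér's V without continuity correction, maximum over all variable pairs) are assumptions.

```python
import itertools

import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency, entropy
from sklearn.metrics import mutual_info_score


def cramers_v(x: pd.Series, y: pd.Series) -> float:
    """Cramér's V from the chi-squared statistic of the contingency table."""
    table = pd.crosstab(x, y).to_numpy()
    chi2 = chi2_contingency(table, correction=False)[0]
    n = table.sum()
    denom = n * (min(table.shape) - 1)
    return float(np.sqrt(chi2 / denom)) if denom > 0 else 0.0


def pairwise_meta_features(df: pd.DataFrame) -> dict:
    """A hypothetical subset of the meta-features, computed in nats."""
    ent = [entropy(df[c].value_counts(normalize=True).to_numpy()) for c in df.columns]
    mi, cv = [], []
    for a, b in itertools.combinations(df.columns, 2):
        mi.append(mutual_info_score(df[a], df[b]))
        cv.append(cramers_v(df[a], df[b]))
    return {
        "n_variables": df.shape[1],
        "max_entropy": float(max(ent)),
        "max_pairwise_MI": float(max(mi)),
        "max_cramers_v": float(max(cv)),
    }
```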
### Three Ideas Borrowed from Parsing
1. **Biaffine-style pairwise features**: MI and Cramér's V computed for every variable pair play the role of parsing's arc scores.
2. **Pairwise ranking** (our best model): For each algorithm pair (A, B), predict which of the two achieves the lower SHD, then rank algorithms by their number of predicted wins. This is inspired by pairwise, tournament-style parser selection.
3. **Cross-domain transfer**: Train on well-characterized bnlearn networks, then predict on new, unseen datasets (analogous to cross-lingual parser transfer).
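The pairwise ranking idea can be sketched as one binary classifier per algorithm pair, with the final ranking given by each algorithm's number of predicted wins. The `PairwiseRanker` class is hypothetical. The tie rule (equal SHD counts as a win for the second algorithm) and the handling of pairs where one algorithm always wins are illustrative choices, not documented behavior of `pairwise_model.pkl`.

```python
import itertools

import numpy as np
from sklearn.ensemble import GradientBoostingClassifier


class PairwiseRanker:
    """Hypothetical pairwise ranker: one classifier per algorithm pair (a, b)."""

    def __init__(self, algorithms):
        self.algorithms = list(algorithms)
        self.models = {}

    def fit(self, X: np.ndarray, nshd: np.ndarray) -> "PairwiseRanker":
        # nshd has shape (n_configs, n_algorithms); lower normalized SHD is better.
        for a, b in itertools.combinations(range(len(self.algorithms)), 2):
            y = (nshd[:, a] < nshd[:, b]).astype(int)  # 1 if a beats b; ties go to b
            if y.min() == y.max():
                # Only one outcome seen in training: store it as a constant.
                self.models[(a, b)] = int(y[0])
            else:
                self.models[(a, b)] = GradientBoostingClassifier(random_state=0).fit(X, y)
        return self

    def rank(self, x: np.ndarray, k: int = 3) -> list:
        """Rank algorithms for one dataset's meta-feature vector by predicted wins."""
        wins = np.zeros(len(self.algorithms))
        for (a, b), model in self.models.items():
            if isinstance(model, int):
                a_wins = model
            else:
                a_wins = model.predict(x.reshape(1, -1))[0]
            wins[a if a_wins else b] += 1
        order = np.argsort(-wins, kind="stable")  # most wins first; ties keep pool order
        return [self.algorithms[i] for i in order[:k]]
```

With 9 algorithms this trains 36 classifiers, and the win count turns their pairwise votes into a full ranking without needing calibrated per-algorithm SHD predictions.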
## 🚀 Quick Start
```python
from causal_selection.meta_learner.predictor import predict_best_algorithms
import pandas as pd
# Load your discrete dataset
df = pd.read_csv("my_discrete_data.csv")
# Get top-3 recommendations
result = predict_best_algorithms(df, k=3)
# Prints ranked algorithms with predicted accuracy and confidence
```
## 📁 Project Structure
```
causal_selection/
├── data/
│ ├── generator.py # Load bnlearn networks, sample data, DAG→CPDAG
│ ├── bif_files/ # 14 bnlearn BIF files (asia through win95pts)
│ └── results/ # Benchmark CSVs: meta-features, SHD matrices
├── discovery/
│ ├── algorithms.py # 9 algorithm adapters with timeout handling
│ └── evaluator.py # SHD, F1, Precision, Recall computation
├── features/
│ └── extractor.py # 34 meta-features across 5 tiers
├── meta_learner/
│ ├── trainer.py # Multi-Output RF/GBM + LONO-CV evaluation
│ └── predictor.py # Inference: dataset → top-3 prediction
├── models/
│ ├── meta_learner.pkl # Trained GBM (multi-output fallback)
│ ├── pairwise_model.pkl # Pairwise ranking GBM (best model)
│ └── scaler.pkl # Feature scaler
├── benchmark.py # Full benchmark orchestration
├── run_benchmark.py # Resumable benchmark runner
└── augment_and_improve.py # Data augmentation + model improvement
```
## 📈 Benchmark Data
- **14 bnlearn networks**: asia, cancer, earthquake, sachs, survey, alarm, barley, child, insurance, mildew, water, hailfinder, hepar2, win95pts
- **178 dataset configs**: 122 original + 56 augmented (variable subsampling, sample-size variation, noise injection)
- **1,600+ algorithm runs**: 9 algorithms × 178 configs with per-algorithm timeout
### Data Augmentation Strategies
- **Variable subsampling**: Drop 20-40% of variables to create virtual sub-networks
- **Sample-size variation**: Generate N=300, 750, 1500, 3000 for each network
- **Noise injection**: Randomly flip 5-10% of categorical values
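Two of the augmentation strategies need only pandas and numpy. In this sketch, "flipping" a value is assumed to mean replacing it with a different category already observed in that column, and both function names are hypothetical. Sample-size variation is omitted because it depends on sampling from the BIF networks.

```python
import numpy as np
import pandas as pd


def subsample_variables(df: pd.DataFrame, drop_frac: float, rng: np.random.Generator) -> pd.DataFrame:
    """Drop a random fraction of columns to form a 'virtual sub-network'."""
    n_keep = df.shape[1] - int(round(drop_frac * df.shape[1]))
    keep = set(rng.choice(df.columns.to_numpy(), size=n_keep, replace=False))
    return df[[c for c in df.columns if c in keep]]  # preserve original column order


def inject_noise(df: pd.DataFrame, flip_frac: float, rng: np.random.Generator) -> pd.DataFrame:
    """Replace roughly flip_frac of each column's values with another observed category."""
    out = df.copy()
    for j, col in enumerate(out.columns):
        cats = np.asarray(out[col].unique())
        if len(cats) < 2:
            continue  # constant column: there is no other category to flip to
        for i in np.flatnonzero(rng.random(len(out)) < flip_frac):
            out.iat[i, j] = rng.choice(cats[cats != out.iat[i, j]])
    return out
```

Dropping variables also changes the ground-truth graph, so each subsampled config needs its own reference CPDAG over the remaining variables.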
## 🔧 Dependencies
```
causal-learn>=0.1.4
pgmpy>=0.1.25
scikit-learn>=1.8
pandas
numpy
scipy
joblib
```
## 📚 References
- **Causal-Copilot** (arxiv:2504.13263) — Closest existing algorithm selection system
- **AVICI** (arxiv:2205.12934) — Amortized causal structure learning (biaffine architecture)
- **CauScale** (arxiv:2602.08629) — Scalable neural causal discovery
- **Dozat & Manning** (arxiv:1611.01734) — Deep Biaffine Attention for dependency parsing
- **TreeCRF** (arxiv:2005.00975) — Global structural training loss for parsing
- **SATzilla** (arxiv:1401.2474) — Algorithm selection via meta-learning
- **bnlearn** (bnlearn.com) — Bayesian network benchmark repository
## 🔮 Future Work (Phase 2)
1. **Biaffine neural encoder**: Pre-train a neural feature extractor that learns variable-pair "arc scores"
2. **Portfolio regret loss** (TreeCRF-inspired): Global ranking optimization instead of per-algorithm MSE
3. **Hyperparameter co-selection**: Predict not just which algorithm to run but also its optimal hyperparameters (the CASH problem: Combined Algorithm Selection and Hyperparameter optimization)
4. **Ensemble prediction**: Run top-3 and vote on edges across their CPDAGs
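One simple reading of "vote on edges" is an entry-wise majority over the adjacency matrices of the top-3 outputs. The sketch below assumes all graphs share one node order and one encoding, and the function name is hypothetical. Note that the voted result is not guaranteed to be a valid CPDAG, so a real ensemble would likely need a repair step.

```python
import numpy as np


def majority_vote_cpdag(cpdags: list) -> np.ndarray:
    """Keep an edge mark only if more than half of the input graphs contain it.

    Assumed encoding: g[i, j] == 1 alone means i -> j; g[i, j] and g[j, i]
    both set means an undirected edge i - j.
    """
    votes = np.sum(np.stack(cpdags), axis=0)
    return (votes > len(cpdags) / 2).astype(int)
```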
## License
MIT