# Causal Discovery Algorithm Selection Meta-Learner

A meta-learning system that predicts the **top-3 causal discovery algorithms** for any discrete observational dataset, based on dataset meta-features.

## 🎯 What It Does

Given a new discrete dataset (a pandas DataFrame), the system:

1. **Extracts 34 meta-features** (entropy, mutual information, χ² statistics, CI-test probes, etc.)
2. **Predicts normalized SHD** for each of 9 algorithms via trained models
3. **Ranks and returns the top-3** algorithms expected to produce the most accurate CPDAG

## 📊 Performance (Leave-One-Network-Out Cross-Validation)

### Best Model: Pairwise-GBM Ranking

| Metric | Value |
|--------|-------|
| **Top-3 Hit Rate** | **71.3%** (the true best algorithm is in the predicted top-3) |
| **Mean Regret** | **0.011** (small normalized-SHD gap vs. oracle selection) |
| **Median Regret** | **0.000** (for most datasets the selection matches the oracle) |

### Model Comparison (178 configs, 14 networks + augmented)

| Model | Top-3 Hit Rate | NDCG@3 | Mean Regret |
|-------|----------------|--------|-------------|
| **Pairwise-GBM** | **71.3%** | – | 0.011 |
| GBM-300-lr01 | 67.4% | 0.957 | 0.011 |
| RF-200 | 66.9% | 0.961 | 0.007 |
| RF-500 | 66.3% | 0.962 | 0.007 |
| GBM-500-lr05 | 65.2% | 0.948 | 0.013 |

### Progression

| Stage | Configs | Networks | Top-3 Hit Rate |
|-------|---------|----------|----------------|
| Initial (small nets) | 65 | 4 | 68.2% |
| All 14 networks | 122 | 14 | 70.5% |
| + Data augmentation | 178 | 14 + aug | **71.3%** |

## 🧪 Algorithm Pool (9 algorithms)

| Algorithm | Family | Library | Output | Wins |
|-----------|--------|---------|--------|------|
| **GES** | Score-based | causal-learn | CPDAG | 47% |
| **PC** | Constraint-based | causal-learn | CPDAG | 32% |
| **FCI** | Constraint-based | causal-learn | PAG | 8% |
| **K2** | Score-based | pgmpy | DAG | 6% |
| **HC** | Score-based (greedy) | pgmpy | DAG | 3% |
| **Tabu** | Score-based (meta) | pgmpy | DAG | 2% |
| **GRaSP** | Permutation-based | causal-learn | CPDAG | 1% |
| **BOSS** | Permutation-based | causal-learn | CPDAG | 1% |
| **MMHC** | Hybrid | pgmpy | DAG | <1% |

## 🔬 Key Insight: Dependency Parsing Connection

This project was inspired by a structural parallel between **NLP dependency parsing** and **causal discovery**:

- Both predict **directed graphs** over nodes (words/variables)
- Both have **ground-truth annotations** (treebanks / bnlearn networks)
- Both use **arc-level evaluation** (UAS/LAS ↔ SHD/F1)

The biaffine pairwise scoring mechanism from Dozat & Manning (2017) was independently reinvented by AVICI and CauScale for causal structure learning, which validates this connection.

### Top Predictive Meta-Features

1. `n_variables` (30%): network size (number of nodes in the graph)
2. `max_pairwise_MI` (24%): strongest pairwise dependency (≈ a biaffine arc score)
3. `max_cramers_v` (8%): strongest association strength
4. `max_entropy` (7%): variable complexity

### Three Ideas Borrowed from Parsing

1. **Biaffine-style pairwise features**: MI and Cramér's V between all variable pairs play the role of parsing's arc scores
2. **Pairwise ranking** (our best model): for each algorithm pair (A, B), predict which wins, then count wins to rank; inspired by pairwise tournament-style parser selection
3. **Cross-domain transfer**: train on well-characterized bnlearn networks, then predict on new, unseen datasets (analogous to cross-lingual parser transfer)

## 🚀 Quick Start

```python
import pandas as pd

from causal_selection.meta_learner.predictor import predict_best_algorithms

# Load your discrete dataset
df = pd.read_csv("my_discrete_data.csv")

# Get top-3 recommendations
result = predict_best_algorithms(df, k=3)
# Prints ranked algorithms with predicted accuracy and confidence
```

## 📁 Project Structure

```
causal_selection/
├── data/
│   ├── generator.py           # Load bnlearn networks, sample data, DAG→CPDAG
│   ├── bif_files/             # 14 bnlearn BIF files (asia through win95pts)
│   └── results/               # Benchmark CSVs: meta-features, SHD matrices
├── discovery/
│   ├── algorithms.py          # 9 algorithm adapters with timeout handling
│   └── evaluator.py           # SHD, F1, precision, recall computation
├── features/
│   └── extractor.py           # 34 meta-features across 5 tiers
├── meta_learner/
│   ├── trainer.py             # Multi-output RF/GBM + LONO-CV evaluation
│   └── predictor.py           # Inference: dataset → top-3 prediction
├── models/
│   ├── meta_learner.pkl       # Trained GBM (multi-output fallback)
│   ├── pairwise_model.pkl     # Pairwise ranking GBM (best model)
│   └── scaler.pkl             # Feature scaler
├── benchmark.py               # Full benchmark orchestration
├── run_benchmark.py           # Resumable benchmark runner
└── augment_and_improve.py     # Data augmentation + model improvement
```

## 📈 Benchmark Data

- **14 bnlearn networks**: asia, cancer, earthquake, sachs, survey, alarm, barley, child, insurance, mildew, water, hailfinder, hepar2, win95pts
- **178 dataset configs**: 122 original + 56 augmented (variable subsampling, sample-size variation, noise injection)
- **1,600+ algorithm runs**: 9 algorithms × 178 configs with a per-algorithm timeout

### Data Augmentation Strategies

- **Variable subsampling**: drop 20-40% of variables to create virtual sub-networks
- **Sample-size variation**: generate N = 300, 750, 1500, 3000 for each network
- **Noise injection**: randomly flip 5-10% of categorical values

## 🔧 Dependencies

```
causal-learn>=0.1.4
pgmpy>=0.1.25
scikit-learn>=1.8
pandas
numpy
scipy
joblib
```

## 📚 References

- **Causal-Copilot** (arXiv:2504.13263): closest existing algorithm-selection system
- **AVICI** (arXiv:2205.12934): amortized causal structure learning (biaffine architecture)
- **CauScale** (arXiv:2602.08629): scalable neural causal discovery
- **Dozat & Manning** (arXiv:1611.01734): deep biaffine attention for dependency parsing
- **TreeCRF** (arXiv:2005.00975): global structural training loss for parsing
- **SATzilla** (arXiv:1401.2474): algorithm selection via meta-learning
- **bnlearn** (bnlearn.com): Bayesian network benchmark repository

## 🔮 Future Work (Phase 2)

1. **Biaffine neural encoder**: pre-train a neural feature extractor that learns variable-pair "arc scores"
2. **Portfolio regret loss** (TreeCRF-inspired): global ranking optimization instead of per-algorithm MSE
3. **Hyperparameter co-selection**: predict not just which algorithm but also its optimal hyperparameters (CASH)
4. **Ensemble prediction**: run the top-3 algorithms and vote on edges across their CPDAGs

## License

MIT
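The top predictive meta-features are dominated by pairwise dependency statistics. A minimal sketch of how `max_pairwise_MI` and `max_cramers_v` can be computed for a discrete DataFrame, using only the listed dependencies (illustrative only; the project's actual extractor lives in `features/extractor.py` and computes 34 features):

```python
from itertools import combinations

import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency
from sklearn.metrics import mutual_info_score

def pairwise_dependency_features(df: pd.DataFrame) -> dict:
    """Max pairwise MI (in nats) and Cramér's V over all variable pairs."""
    mi_vals, cv_vals = [], []
    for a, b in combinations(df.columns, 2):
        # Mutual information between two discrete columns
        mi_vals.append(mutual_info_score(df[a], df[b]))
        # Cramér's V from the chi-squared statistic of the contingency table
        table = pd.crosstab(df[a], df[b])
        chi2 = chi2_contingency(table, correction=False)[0]
        n = table.to_numpy().sum()
        denom = n * (min(table.shape) - 1)
        cv_vals.append(np.sqrt(chi2 / denom) if denom > 0 else 0.0)
    return {
        "n_variables": df.shape[1],
        "max_pairwise_MI": max(mi_vals),
        "max_cramers_v": max(cv_vals),
    }
```

These per-pair scores are exactly the "arc score" analogy: one scalar strength per ordered-less variable pair, summarized by max/mean into fixed-length features.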
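The tournament logic behind the best model (Pairwise-GBM) can be sketched independently of the learned classifier. Here `beats(a, b)` is a stand-in for the trained pairwise GBM (which, in the real system, decides from the 34 meta-features whether algorithm `a` will achieve lower normalized SHD than `b`); the ranking itself is just win-counting:

```python
from itertools import combinations

ALGORITHMS = ["GES", "PC", "FCI", "K2", "HC", "Tabu", "GRaSP", "BOSS", "MMHC"]

def rank_by_tournament(beats, k=3):
    """Rank algorithms by counting pairwise wins, return the top-k.

    `beats(a, b)` -> True if `a` is predicted to beat `b` on the
    current dataset (a hypothetical stand-in for the pairwise GBM).
    """
    wins = {algo: 0 for algo in ALGORITHMS}
    for a, b in combinations(ALGORITHMS, 2):
        winner = a if beats(a, b) else b
        wins[winner] += 1
    # Sort by win count, descending; ties keep pool order (stable sort)
    ranking = sorted(ALGORITHMS, key=lambda algo: wins[algo], reverse=True)
    return ranking[:k]
```

With 9 algorithms this is 36 binary predictions per dataset, which is cheap next to actually running any one discovery algorithm.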
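For reference, one common definition of the SHD used as the prediction target: compare two (CP)DAG adjacency matrices pair by pair, counting each missing, extra, or differently oriented edge as one error. This is a sketch under that definition; the project's metric lives in `discovery/evaluator.py` and may differ in details:

```python
import numpy as np

def shd(true_adj: np.ndarray, est_adj: np.ndarray) -> int:
    """Structural Hamming distance between two (CP)DAG adjacency matrices.

    Encoding: adj[i, j] = 1 means an edge i -> j; an undirected edge has
    adj[i, j] = adj[j, i] = 1. Any disagreement on a node pair (missing,
    extra, or misoriented edge) counts as exactly one error.
    """
    n = true_adj.shape[0]
    dist = 0
    for i in range(n):
        for j in range(i + 1, n):
            pair_true = (true_adj[i, j], true_adj[j, i])
            pair_est = (est_adj[i, j], est_adj[j, i])
            if pair_true != pair_est:
                dist += 1
    return dist
```

Dividing by the number of node pairs, `n * (n - 1) / 2`, gives a normalized SHD in [0, 1], which is the scale the regret numbers above are on.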
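Two of the three augmentation strategies (variable subsampling and noise injection) can be sketched directly on a sampled DataFrame; sample-size variation requires resampling from the Bayesian network itself and is omitted here. Function names are illustrative, not the actual `augment_and_improve.py` API:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

def subsample_variables(df: pd.DataFrame, frac_drop: float = 0.3) -> pd.DataFrame:
    """Drop a random ~30% of columns to create a virtual sub-network."""
    n_keep = max(2, int(len(df.columns) * (1 - frac_drop)))
    keep = rng.choice(df.columns, size=n_keep, replace=False)
    return df[list(keep)]

def inject_noise(df: pd.DataFrame, flip_frac: float = 0.05) -> pd.DataFrame:
    """Randomly resample ~5% of the values in each categorical column."""
    out = df.copy()
    for col in out.columns:
        mask = rng.random(len(out)) < flip_frac
        out.loc[mask, col] = rng.choice(out[col].unique(), size=mask.sum())
    return out
```

Each augmented DataFrame then goes through the same meta-feature extraction and benchmarking as an original config, which is how 122 configs grew to 178.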