| """ |
| Inference pipeline: given a new discrete dataset, predict the top-3 causal discovery algorithms. |
| """ |
| import numpy as np |
| import pandas as pd |
| import logging |
| import json |
|
|
| from causal_selection.features.extractor import extract_all_features, FEATURE_NAMES, features_to_vector |
| from causal_selection.meta_learner.trainer import load_model, ALGO_NAMES |
| from causal_selection.discovery.algorithms import ALGORITHM_POOL |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def predict_best_algorithms(df, k=3, model=None, scaler=None, verbose=True):
    """Given a new discrete dataset, predict the top-k best causal discovery algorithms.

    Predicted scores are nSHD (normalized structural Hamming distance), so
    LOWER is better and the ranking is ascending in the predicted score.

    Args:
        df: pd.DataFrame with integer-encoded discrete columns
        k: number of top algorithms to recommend
        model: pre-loaded model (optional, loaded from disk if None)
        scaler: pre-loaded scaler (optional)
        verbose: print details

    Returns:
        dict with:
            - 'top_k': list of (algo_name, predicted_score) tuples, best first
            - 'full_ranking': list of all (algo_name, predicted_score)
            - 'meta_features': dict of extracted features
            - 'confidence': estimated confidence based on prediction spread
    """
    if model is None or scaler is None:
        model, scaler = load_model()

    if verbose:
        print(f"Dataset shape: {df.shape}")
        print("Extracting meta-features...")

    features = extract_all_features(df)
    feature_vector = features_to_vector(features).reshape(1, -1)

    X_scaled = scaler.transform(feature_vector)
    predicted_scores = model.predict(X_scaled)[0]

    # Ascending argsort: smallest predicted nSHD (= best expected result) first.
    ranking_indices = np.argsort(predicted_scores)
    full_ranking = [(ALGO_NAMES[i], float(predicted_scores[i])) for i in ranking_indices]
    top_k = full_ranking[:k]

    # Confidence heuristics from the spread of predictions. Reuse the
    # already-sorted, float-cast scores from full_ranking rather than
    # re-sorting the numpy array — plain floats keep the 'confidence'
    # entries JSON-serializable (numpy scalars are not).
    ranked_scores = [score for _, score in full_ranking]
    if len(ranked_scores) > 1:
        spread = ranked_scores[-1] - ranked_scores[0]
        gap_top1_top2 = ranked_scores[1] - ranked_scores[0]
    else:
        spread = 0.0
        gap_top1_top2 = 0.0

    result = {
        'top_k': top_k,
        'full_ranking': full_ranking,
        'meta_features': features,
        'confidence': {
            'score_spread': spread,
            'top1_top2_gap': gap_top1_top2,
            'recommendation': _get_confidence_text(gap_top1_top2, spread),
        }
    }

    if verbose:
        _print_report(result, k, features)

    return result


def _print_report(result, k, features):
    """Print the human-readable recommendation report (verbose mode only)."""
    top_k = result['top_k']
    full_ranking = result['full_ranking']

    print(f"\n{'='*60}")
    print(f"TOP-{k} ALGORITHM RECOMMENDATIONS")
    print(f"{'='*60}")
    for rank, (algo, score) in enumerate(top_k, 1):
        algo_info = ALGORITHM_POOL[algo]
        print(f"\n #{rank}: {algo}")
        print(f" Predicted nSHD: {score:.4f}")
        print(f" Family: {algo_info['family']}")
        print(f" Output: {algo_info['output_type']}")
        print(f" Library: {algo_info['library']}")

    print(f"\n{'='*60}")
    print("FULL RANKING")
    print(f"{'='*60}")
    for rank, (algo, score) in enumerate(full_ranking, 1):
        # Mark the recommended top-k entries in the full listing.
        marker = " <<<" if rank <= k else ""
        print(f" {rank:2d}. {algo:15s} nSHD={score:.4f}{marker}")

    print(f"\nConfidence: {result['confidence']['recommendation']}")

    print(f"\n{'='*60}")
    print("DATASET CHARACTERISTICS")
    print(f"{'='*60}")
    print(f" Variables: {features['n_variables']:.0f}")
    print(f" Samples: {features['n_samples']:.0f}")
    print(f" N/P ratio: {features['n_over_p']:.1f}")
    print(f" Avg cardinality: {features['avg_cardinality']:.1f}")
    print(f" Density proxy: {features['density_proxy']:.3f}")
    print(f" Mean MI: {features['mean_pairwise_MI']:.4f}")
    print(f" V-structure proxy: {features['v_structure_proxy']:.3f}")
|
|
|
|
| def _get_confidence_text(gap, spread): |
| """Generate human-readable confidence assessment.""" |
| if spread < 0.01: |
| return "LOW - All algorithms predicted to perform similarly. Consider running top-3 and comparing." |
| elif gap > 0.05: |
| return "HIGH - Clear winner predicted. Top-1 algorithm strongly recommended." |
| elif gap > 0.02: |
| return "MEDIUM - Top algorithms are close. Running top-3 recommended for comparison." |
| else: |
| return "LOW-MEDIUM - Marginal differences between top algorithms. Run all top-3." |
|
|
|
|
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    from causal_selection.data.generator import load_bn_model, sample_dataset

    # Smoke-test the pipeline on the Sachs benchmark network.
    bn = load_bn_model('sachs')
    dataset = sample_dataset(bn, 2000, seed=99)
    recommendation = predict_best_algorithms(dataset, k=3, verbose=True)
|
|