"""
Inference pipeline: given a new discrete dataset, predict the top-k (default 3)
causal discovery algorithms.
"""
import numpy as np
import pandas as pd
import logging
import json

from causal_selection.features.extractor import extract_all_features, FEATURE_NAMES, features_to_vector
from causal_selection.meta_learner.trainer import load_model, ALGO_NAMES
from causal_selection.discovery.algorithms import ALGORITHM_POOL

logger = logging.getLogger(__name__)


def predict_best_algorithms(df, k=3, model=None, scaler=None, verbose=True):
    """Given a new discrete dataset, predict the top-k best causal discovery algorithms.

    Args:
        df: pd.DataFrame with integer-encoded discrete columns
        k: number of top algorithms to recommend (non-negative; values larger
            than the number of known algorithms are clamped)
        model: pre-loaded model (optional, loaded from disk if None)
        scaler: pre-loaded scaler (optional, loaded from disk if None)
        verbose: print details

    Returns:
        dict with:
            - 'top_k': list of (algo_name, predicted_score) tuples, best first
            - 'full_ranking': list of all (algo_name, predicted_score), best first
            - 'meta_features': dict of extracted features
            - 'confidence': dict with 'score_spread', 'top1_top2_gap' (plain
              floats) and a human-readable 'recommendation'

    Raises:
        ValueError: if k is negative, or if the model's number of outputs does
            not match ALGO_NAMES.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    # Load from disk only what the caller did not supply, so an explicitly
    # passed model (or scaler) is never silently replaced.
    if model is None or scaler is None:
        loaded_model, loaded_scaler = load_model()
        model = loaded_model if model is None else model
        scaler = loaded_scaler if scaler is None else scaler

    # Extract meta-features
    if verbose:
        print(f"Dataset shape: {df.shape}")
        print("Extracting meta-features...")
    features = extract_all_features(df)
    feature_vector = features_to_vector(features).reshape(1, -1)

    # Scale and predict; one row in, so take the first prediction row.
    X_scaled = scaler.transform(feature_vector)
    predicted_scores = np.asarray(model.predict(X_scaled)[0], dtype=float).ravel()  # normalized SHD predictions
    if predicted_scores.size != len(ALGO_NAMES):
        raise ValueError(
            f"Model produced {predicted_scores.size} scores but ALGO_NAMES has "
            f"{len(ALGO_NAMES)} entries; model and algorithm list are out of sync."
        )

    # Rank algorithms (lower predicted score = better)
    ranking_indices = np.argsort(predicted_scores)
    full_ranking = [(ALGO_NAMES[i], float(predicted_scores[i])) for i in ranking_indices]
    k = min(k, len(full_ranking))  # so the report header reflects what is actually returned
    top_k = full_ranking[:k]

    # Confidence: how much better is top-1 vs others? full_ranking is already
    # sorted ascending, so reuse it instead of re-sorting the raw array.
    ranked_scores = [score for _, score in full_ranking]
    if len(ranked_scores) > 1:
        spread = ranked_scores[-1] - ranked_scores[0]
        gap_top1_top2 = ranked_scores[1] - ranked_scores[0]
    else:
        spread = 0.0
        gap_top1_top2 = 0.0

    result = {
        'top_k': top_k,
        'full_ranking': full_ranking,
        'meta_features': features,
        'confidence': {
            'score_spread': spread,
            'top1_top2_gap': gap_top1_top2,
            'recommendation': _get_confidence_text(gap_top1_top2, spread, k=k),
        },
    }

    if verbose:
        _print_report(result, k)

    return result


def _print_report(result, k):
    """Print the top-k recommendations, full ranking and key dataset properties."""
    top_k = result['top_k']
    full_ranking = result['full_ranking']
    features = result['meta_features']

    print(f"\n{'='*60}")
    print(f"TOP-{k} ALGORITHM RECOMMENDATIONS")
    print(f"{'='*60}")
    for rank, (algo, score) in enumerate(top_k, 1):
        algo_info = ALGORITHM_POOL[algo]
        print(f"\n  #{rank}: {algo}")
        print(f"      Predicted nSHD: {score:.4f}")
        print(f"      Family: {algo_info['family']}")
        print(f"      Output: {algo_info['output_type']}")
        print(f"      Library: {algo_info['library']}")

    print(f"\n{'='*60}")
    print("FULL RANKING")
    print(f"{'='*60}")
    for rank, (algo, score) in enumerate(full_ranking, 1):
        marker = " <<<" if rank <= k else ""
        print(f"  {rank:2d}. {algo:15s}  nSHD={score:.4f}{marker}")

    print(f"\nConfidence: {result['confidence']['recommendation']}")

    # Key dataset properties
    print(f"\n{'='*60}")
    print("DATASET CHARACTERISTICS")
    print(f"{'='*60}")
    print(f"  Variables: {features['n_variables']:.0f}")
    print(f"  Samples: {features['n_samples']:.0f}")
    print(f"  N/P ratio: {features['n_over_p']:.1f}")
    print(f"  Avg cardinality: {features['avg_cardinality']:.1f}")
    print(f"  Density proxy: {features['density_proxy']:.3f}")
    print(f"  Mean MI: {features['mean_pairwise_MI']:.4f}")
    print(f"  V-structure proxy: {features['v_structure_proxy']:.3f}")


def _get_confidence_text(gap, spread, k=3):
    """Generate a human-readable confidence assessment.

    Args:
        gap: predicted-score difference between the top-1 and top-2 algorithms.
        spread: difference between the worst and best predicted scores.
        k: number of recommended algorithms, used in the advice text
            (default 3 keeps the original wording).
    """
    if spread < 0.01:
        return f"LOW - All algorithms predicted to perform similarly. Consider running top-{k} and comparing."
    elif gap > 0.05:
        return "HIGH - Clear winner predicted. Top-1 algorithm strongly recommended."
    elif gap > 0.02:
        return f"MEDIUM - Top algorithms are close. Running top-{k} recommended for comparison."
    else:
        return f"LOW-MEDIUM - Marginal differences between top algorithms. Run all top-{k}."


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    # Demo: sample 2000 rows from the 'sachs' network and predict on them.
    from causal_selection.data.generator import load_bn_model, sample_dataset

    bn_model = load_bn_model('sachs')
    df = sample_dataset(bn_model, 2000, seed=99)
    result = predict_best_algorithms(df, k=3, verbose=True)