# Oguzz07's picture
# Add causal_selection/meta_learner/predictor.py
# eabf58d verified
"""
Inference pipeline: given a new discrete dataset, predict the top-3 causal discovery algorithms.
"""
import numpy as np
import pandas as pd
import logging
import json
from causal_selection.features.extractor import extract_all_features, FEATURE_NAMES, features_to_vector
from causal_selection.meta_learner.trainer import load_model, ALGO_NAMES
from causal_selection.discovery.algorithms import ALGORITHM_POOL
# Module-level logger named after this module. In the visible code the
# prediction report is written with print(), not through this logger.
logger = logging.getLogger(__name__)
def predict_best_algorithms(df, k=3, model=None, scaler=None, verbose=True):
    """Given a new discrete dataset, predict the top-k best causal discovery algorithms.

    Meta-features are extracted from ``df``, scaled with ``scaler`` and passed
    to the meta-learner ``model``. The model yields one predicted normalized
    SHD (nSHD) score per entry of ``ALGO_NAMES``. Lower scores rank higher.

    Args:
        df: pd.DataFrame with integer-encoded discrete columns.
        k: number of top algorithms to recommend. Must be >= 1. A value
            larger than the number of algorithms returns the whole ranking.
        model: pre-loaded meta-learner (optional, loaded from disk if None).
        scaler: pre-loaded feature scaler (optional, loaded from disk if None).
        verbose: if True, print a human-readable report to stdout.

    Returns:
        dict with:
            - 'top_k': list of (algo_name, predicted_score) tuples, best first
            - 'full_ranking': list of all (algo_name, predicted_score), best first
            - 'meta_features': dict of extracted features
            - 'confidence': dict with float 'score_spread' and 'top1_top2_gap'
              plus a human-readable 'recommendation'

    Raises:
        ValueError: if ``k < 1``, or if the model returns a different number
            of scores than there are names in ``ALGO_NAMES``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    # Load from disk only what the caller did not supply. A caller-provided
    # model (or scaler) must not be silently replaced by the saved one.
    if model is None or scaler is None:
        loaded_model, loaded_scaler = load_model()
        if model is None:
            model = loaded_model
        if scaler is None:
            scaler = loaded_scaler

    if verbose:
        print(f"Dataset shape: {df.shape}")
        print("Extracting meta-features...")
    features = extract_all_features(df)
    # reshape(1, -1): a single-row 2-D input for scaler/model.
    feature_vector = features_to_vector(features).reshape(1, -1)

    X_scaled = scaler.transform(feature_vector)
    # Scores for the single input row, as a flat float array
    # (expected in ALGO_NAMES order).
    predicted_scores = np.asarray(model.predict(X_scaled)[0], dtype=float).ravel()
    if predicted_scores.shape[0] != len(ALGO_NAMES):
        raise ValueError(
            f"Model returned {predicted_scores.shape[0]} scores but "
            f"ALGO_NAMES lists {len(ALGO_NAMES)} algorithms"
        )

    # Lower predicted nSHD = better. A stable sort keeps ALGO_NAMES order on ties.
    ranking_indices = np.argsort(predicted_scores, kind="stable")
    full_ranking = [(ALGO_NAMES[i], float(predicted_scores[i])) for i in ranking_indices]
    top_k = full_ranking[:k]

    # Confidence: how far is top-1 ahead of the runner-up and of the worst?
    scores_sorted = predicted_scores[ranking_indices]
    if len(scores_sorted) > 1:
        spread = float(scores_sorted[-1] - scores_sorted[0])
        gap_top1_top2 = float(scores_sorted[1] - scores_sorted[0])
    else:
        spread = gap_top1_top2 = 0.0

    result = {
        'top_k': top_k,
        'full_ranking': full_ranking,
        'meta_features': features,
        'confidence': {
            'score_spread': spread,
            'top1_top2_gap': gap_top1_top2,
            'recommendation': _get_confidence_text(gap_top1_top2, spread),
        },
    }

    if verbose:
        _print_report(result, k)
    return result


def _print_report(result, k):
    """Print the recommendation report for a predict_best_algorithms result."""
    top_k = result['top_k']
    features = result['meta_features']
    rule = '=' * 60

    print(f"\n{rule}")
    print(f"TOP-{len(top_k)} ALGORITHM RECOMMENDATIONS")
    print(rule)
    for rank, (algo, score) in enumerate(top_k, 1):
        algo_info = ALGORITHM_POOL[algo]
        print(f"\n #{rank}: {algo}")
        print(f" Predicted nSHD: {score:.4f}")
        print(f" Family: {algo_info['family']}")
        print(f" Output: {algo_info['output_type']}")
        print(f" Library: {algo_info['library']}")

    print(f"\n{rule}")
    print("FULL RANKING")
    print(rule)
    for rank, (algo, score) in enumerate(result['full_ranking'], 1):
        marker = " <<<" if rank <= k else ""
        print(f" {rank:2d}. {algo:15s} nSHD={score:.4f}{marker}")
    print(f"\nConfidence: {result['confidence']['recommendation']}")

    # Key dataset properties
    print(f"\n{rule}")
    print("DATASET CHARACTERISTICS")
    print(rule)
    print(f" Variables: {features['n_variables']:.0f}")
    print(f" Samples: {features['n_samples']:.0f}")
    print(f" N/P ratio: {features['n_over_p']:.1f}")
    print(f" Avg cardinality: {features['avg_cardinality']:.1f}")
    print(f" Density proxy: {features['density_proxy']:.3f}")
    print(f" Mean MI: {features['mean_pairwise_MI']:.4f}")
    print(f" V-structure proxy: {features['v_structure_proxy']:.3f}")
def _get_confidence_text(gap, spread):
"""Generate human-readable confidence assessment."""
if spread < 0.01:
return "LOW - All algorithms predicted to perform similarly. Consider running top-3 and comparing."
elif gap > 0.05:
return "HIGH - Clear winner predicted. Top-1 algorithm strongly recommended."
elif gap > 0.02:
return "MEDIUM - Top algorithms are close. Running top-3 recommended for comparison."
else:
return "LOW-MEDIUM - Marginal differences between top algorithms. Run all top-3."
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # Demo: sample 2000 rows from the 'sachs' network and recommend algorithms.
    from causal_selection.data.generator import load_bn_model, sample_dataset
    # Named bn_model (not `model`) so it is not confused with the meta-learner
    # that predict_best_algorithms loads through its `model` argument.
    bn_model = load_bn_model('sachs')
    df = sample_dataset(bn_model, 2000, seed=99)
    result = predict_best_algorithms(df, k=3, verbose=True)