Oguzz07's picture
Add causal_selection/data/generator.py
9efaa0b verified
"""
Data generation module: load bnlearn networks, sample datasets, extract ground truth.
"""
import os
import numpy as np
import pandas as pd
from pgmpy.readwrite import BIFReader
from pgmpy.sampling import BayesianModelSampling
import warnings
import logging
# NOTE(review): this installs a process-wide "ignore" filter for all warnings
# as soon as the module is imported, not just for pgmpy calls.
warnings.filterwarnings('ignore')
logger = logging.getLogger(__name__)

# Folder containing the `<name>.bif` network files, located beside this module.
BIF_DIR = os.path.join(os.path.dirname(__file__), 'bif_files')

# Networks grouped by size tier; the tier decides how many rows are sampled
# (CPU budget management).
SMALL_NETWORKS = ['asia', 'cancer', 'earthquake', 'sachs', 'survey']
MEDIUM_NETWORKS = ['alarm', 'barley', 'child', 'insurance', 'mildew', 'water']
LARGE_NETWORKS = ['hailfinder', 'hepar2', 'win95pts']
ALL_NETWORKS = [*SMALL_NETWORKS, *MEDIUM_NETWORKS, *LARGE_NETWORKS]

# Dataset sizes (number of sampled rows) generated for each tier.
SAMPLE_SIZES = {
    'small': [250, 500, 1000, 2000, 5000, 10000],
    'medium': [500, 1000, 2000, 5000],
    'large': [500, 1000, 2000],
}

# Seeds 0 .. SEEDS_PER_CONFIG-1 are drawn for every (network, sample size) pair.
SEEDS_PER_CONFIG = 3
def get_network_tier(name):
    """Return the size tier of a network: 'small', 'medium' or 'large'.

    Any name that is in neither SMALL_NETWORKS nor MEDIUM_NETWORKS (including
    names not listed anywhere) is reported as 'large'.
    """
    for tier, members in (('small', SMALL_NETWORKS), ('medium', MEDIUM_NETWORKS)):
        if name in members:
            return tier
    return 'large'
def load_bn_model(name):
    """Read `<BIF_DIR>/<name>.bif` with pgmpy's BIFReader and return the model.

    Raises:
        FileNotFoundError: if no BIF file exists for `name`.
    """
    bif_path = os.path.join(BIF_DIR, f'{name}.bif')
    if os.path.exists(bif_path):
        return BIFReader(bif_path).get_model()
    raise FileNotFoundError(f"BIF file not found: {bif_path}")
def get_true_dag_adjmat(model):
    """Build the ground-truth DAG adjacency matrix of a Bayesian network model.

    Nodes are ordered alphabetically (via `sorted`), and that ordering defines
    both the matrix rows/columns and the returned name list.

    Returns:
        adjmat: np.ndarray of int, shape (n_nodes, n_nodes); adjmat[i, j] = 1 means i -> j
        node_names: list of node names in matrix order
    """
    node_list = sorted(model.nodes())
    size = len(node_list)
    index_of = dict(zip(node_list, range(size)))
    matrix = np.zeros((size, size), dtype=int)
    for src, dst in model.edges():
        matrix[index_of[src]][index_of[dst]] = 1
    return matrix, node_list
def dag_to_cpdag(dag_adjmat):
    """Convert a DAG adjacency matrix to its CPDAG (completed partially directed acyclic graph).

    A CPDAG represents the Markov equivalence class:
    - Compelled edges (in all DAGs of the class) remain directed
    - Reversible edges become undirected (represented as bidirectional)

    Construction (pattern + Meek's rules, cf. Meek 1995 / Chickering 2002):
    1. Orient all v-structures (i -> k <- j where i and j are not adjacent)
    2. Apply Meek's orientation rules R1-R3 iteratively until nothing changes

    Args:
        dag_adjmat: square array-like. Any non-zero entry dag_adjmat[i, j]
            (1, True, or an edge weight) means i -> j.

    Returns:
        cpdag: np.ndarray of int, cpdag[i,j]=1 and cpdag[j,i]=0 means i->j (directed)
               cpdag[i,j]=1 and cpdag[j,i]=1 means i--j (undirected)
    """
    # Binarize once so that the skeleton and the v-structure search agree on
    # what an edge is. Previously the skeleton used `> 0` while parents were
    # found with `== 1`, so e.g. a 0.5-weighted matrix got a skeleton but no
    # oriented v-structures.
    dag = (np.asarray(dag_adjmat) != 0).astype(int)
    n = dag.shape[0]

    # Start from the skeleton: every edge undirected (stored both ways).
    skeleton = ((dag + dag.T) > 0).astype(int)
    cpdag = skeleton.copy()

    def is_directed(a, b):
        """True if a -> b in the current CPDAG."""
        return cpdag[a, b] == 1 and cpdag[b, a] == 0

    def is_undirected(a, b):
        """True if a -- b in the current CPDAG."""
        return cpdag[a, b] == 1 and cpdag[b, a] == 1

    def is_adjacent(a, b):
        """True if a and b share an edge of any kind."""
        return cpdag[a, b] == 1 or cpdag[b, a] == 1

    # Step 1: orient each v-structure i -> k <- j (i, j non-adjacent) by
    # removing the k -> i and k -> j directions.
    for k in range(n):
        parents = np.flatnonzero(dag[:, k])
        for pos, i in enumerate(parents):
            for j in parents[pos + 1:]:
                if skeleton[i, j] == 0:
                    cpdag[k, i] = 0
                    cpdag[k, j] = 0

    # Step 2: apply Meek's rules until a full pass orients nothing new.
    changed = True
    while changed:
        changed = False
        for i in range(n):
            for j in range(n):
                if not is_undirected(i, j):
                    continue
                others = [k for k in range(n) if k != i and k != j]
                if any(is_directed(k, i) and not is_adjacent(k, j) for k in others):
                    # Rule 1: k -> i -- j and k not adj j  =>  i -> j
                    orient = True
                elif any(is_directed(i, k) and is_directed(k, j) for k in others):
                    # Rule 2: i -> k -> j and i -- j  =>  i -> j
                    orient = True
                else:
                    # Rule 3: i -- k1 -> j, i -- k2 -> j, k1 not adj k2  =>  i -> j
                    cands = [k for k in others if is_undirected(i, k) and is_directed(k, j)]
                    orient = any(
                        not is_adjacent(k1, k2)
                        for pos, k1 in enumerate(cands)
                        for k2 in cands[pos + 1:]
                    )
                if orient:
                    cpdag[j, i] = 0  # i -- j becomes i -> j
                    changed = True
    return cpdag
def sample_dataset(model, n_samples, seed=42):
    """Sample observational data from a Bayesian network.

    pgmpy's `BayesianModelSampling.forward_sample` is tried first; if it raises
    `TypeError` (pgmpy/pandas version incompatibilities), the manual sampler
    `_manual_forward_sample` is used instead.

    Args:
        model: pgmpy Bayesian network to sample from.
        n_samples: number of rows to draw.
        seed: seed passed to the sampler (and to NumPy's global RNG).

    Returns:
        df: pd.DataFrame with integer-encoded discrete variables, columns sorted by name
    """
    # NOTE: side effect - this reseeds NumPy's *global* RNG.
    np.random.seed(seed)
    sampler = BayesianModelSampling(model)
    try:
        df = sampler.forward_sample(size=n_samples, seed=seed)
    except TypeError:
        # Fallback for pgmpy/pandas version compatibility issues
        df = _manual_forward_sample(model, n_samples, seed)
    # Ensure consistent column ordering (sorted)
    df = df[sorted(df.columns)]
    # Encode every non-numeric column (object, category, and pandas' dedicated
    # string dtypes) as integer category codes. Checking only object/category
    # let `string`/`str` columns reach `pd.to_numeric(errors='coerce')` below,
    # which turned every value into NaN and then silently into 0.
    for col in df.columns:
        if not pd.api.types.is_numeric_dtype(df[col]):
            df[col] = df[col].astype('category').cat.codes
    # Ensure all columns are integer-typed (bool -> 0/1).
    df = df.apply(pd.to_numeric, errors='coerce').fillna(0).astype(int)
    return df
def _manual_forward_sample(model, n_samples, seed=42):
    """Manual forward sampling when pgmpy's sampler has compatibility issues.

    Nodes are visited in topological order. Each value is drawn from the CPT
    column selected by the parents' already-sampled state indices.

    Args:
        model: pgmpy Bayesian network; must work with `networkx.topological_sort`
            and expose `get_cpds()`.
        n_samples: number of rows to draw.
        seed: seed for a private `np.random.RandomState` (global RNG untouched).

    Returns:
        pd.DataFrame with one column per node (topological order) holding the
        sampled state index (row position in that node's CPT).
    """
    import networkx as nx

    rng = np.random.RandomState(seed)
    nodes = list(nx.topological_sort(model))
    cpd_dict = {cpd.variable: cpd for cpd in model.get_cpds()}

    # Precompute each node's evidence list, evidence cardinalities and
    # column-normalized CPT once, instead of once per sample.
    tables = {}
    for node in nodes:
        cpd = cpd_dict[node]
        # cpd.variables is [node, *evidence] in CPT axis order, and get_values()
        # flattens the evidence axes in C order (last evidence varies fastest).
        # NOTE: get_evidence() is not used here. In pgmpy it returns the
        # evidence reversed (variables[:0:-1]). The old stride loop over
        # reversed(get_evidence()) therefore made the *first* evidence vary
        # fastest and picked the wrong column for nodes with 2+ parents.
        evidence = list(cpd.variables[1:])
        cards = tuple(int(c) for c in cpd.cardinality[1:])
        table = np.abs(np.asarray(cpd.get_values(), dtype=float))
        table = table / table.sum(axis=0, keepdims=True)  # normalize each column
        tables[node] = (evidence, cards, table)

    samples = {node: [] for node in nodes}
    for _ in range(n_samples):
        sample = {}
        for node in nodes:
            evidence, cards, table = tables[node]
            if evidence:
                col_idx = np.ravel_multi_index(tuple(sample[p] for p in evidence), cards)
            else:
                col_idx = 0  # root node: single marginal column
            probs = table[:, col_idx]
            val = rng.choice(len(probs), p=probs)
            sample[node] = val
            samples[node].append(val)
    return pd.DataFrame(samples)
def generate_all_datasets(networks=None, output_dir=None):
    """Produce every (network, sample size, seed) dataset configuration.

    Args:
        networks: network names to process; None selects ALL_NETWORKS.
        output_dir: accepted but not used by this function.

    Returns:
        list of dicts with keys:
          - network: str
          - n_samples: int
          - seed: int
          - df: pd.DataFrame
          - true_dag: np.ndarray
          - true_cpdag: np.ndarray
          - node_names: list
        Ground truth is computed once per network and the same objects are
        shared by all of that network's configs. A failure while sampling one
        (N, seed) pair is logged and that config is skipped.
    """
    net_list = ALL_NETWORKS if networks is None else networks
    results = []
    for net_name in net_list:
        sizes = SAMPLE_SIZES[get_network_tier(net_name)]
        logger.info(f"Loading network: {net_name}")
        model = load_bn_model(net_name)
        true_dag, node_names = get_true_dag_adjmat(model)
        true_cpdag = dag_to_cpdag(true_dag)
        for n_samples in sizes:
            for seed in range(SEEDS_PER_CONFIG):
                try:
                    df = sample_dataset(model, n_samples, seed=seed)
                    results.append({
                        'network': net_name,
                        'n_samples': n_samples,
                        'seed': seed,
                        'df': df,
                        'true_dag': true_dag,
                        'true_cpdag': true_cpdag,
                        'node_names': node_names,
                    })
                    logger.info(f" {net_name} N={n_samples} seed={seed}: {df.shape}")
                except Exception as e:
                    logger.error(f" FAILED {net_name} N={n_samples} seed={seed}: {e}")
    return results
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # Smoke test on the ASIA network: ground-truth graphs plus a 1000-row sample.
    model = load_bn_model('asia')
    dag, nodes = get_true_dag_adjmat(model)
    cpdag = dag_to_cpdag(dag)
    for line in (
        f"ASIA - nodes: {nodes}",
        f"DAG adjacency:\n{dag}",
        f"CPDAG adjacency:\n{cpdag}",
    ):
        print(line)
    df = sample_dataset(model, 1000, seed=0)
    print(f"\nSampled data: {df.shape}")
    print(df.head())