# File size: 10,539 Bytes
# 9efaa0b
"""
Data generation module: load bnlearn networks, sample datasets, extract ground truth.
"""
import os
import numpy as np
import pandas as pd
from pgmpy.readwrite import BIFReader
from pgmpy.sampling import BayesianModelSampling
import warnings
import logging
# Suppress every Python warning process-wide; this takes effect on import.
warnings.filterwarnings('ignore')

logger = logging.getLogger(__name__)

# The ``<name>.bif`` network files live in ``bif_files`` next to this module.
BIF_DIR = os.path.join(os.path.dirname(__file__), 'bif_files')

# Network tiers for CPU budget management
SMALL_NETWORKS = [
    'asia',
    'cancer',
    'earthquake',
    'sachs',
    'survey',
]
MEDIUM_NETWORKS = [
    'alarm',
    'barley',
    'child',
    'insurance',
    'mildew',
    'water',
]
LARGE_NETWORKS = [
    'hailfinder',
    'hepar2',
    'win95pts',
]
ALL_NETWORKS = [*SMALL_NETWORKS, *MEDIUM_NETWORKS, *LARGE_NETWORKS]

# Sample sizes per tier (keys match the values returned by get_network_tier)
SAMPLE_SIZES = {
    'small': [250, 500, 1000, 2000, 5000, 10000],
    'medium': [500, 1000, 2000, 5000],
    'large': [500, 1000, 2000],
}

# Number of seeds (0 .. SEEDS_PER_CONFIG - 1) drawn per (network, sample size)
SEEDS_PER_CONFIG = 3
def get_network_tier(name):
    """Return the size tier of a network: 'small', 'medium' or 'large'.

    Any name that is listed in neither SMALL_NETWORKS nor MEDIUM_NETWORKS
    (including names that are not known at all) is reported as 'large'.
    """
    tier_members = (
        ('small', SMALL_NETWORKS),
        ('medium', MEDIUM_NETWORKS),
    )
    for tier, members in tier_members:
        if name in members:
            return tier
    return 'large'
def load_bn_model(name):
    """Read ``<BIF_DIR>/<name>.bif`` and return the model built by BIFReader.

    Raises:
        FileNotFoundError: if no such BIF file exists.
    """
    bif_path = os.path.join(BIF_DIR, f'{name}.bif')
    if os.path.exists(bif_path):
        return BIFReader(bif_path).get_model()
    raise FileNotFoundError(f"BIF file not found: {bif_path}")
def get_true_dag_adjmat(model):
    """Build the ground-truth adjacency matrix of a model's DAG.

    Nodes are indexed by their position in the sorted list of node names.

    Returns:
        adjmat: np.ndarray of shape (n_nodes, n_nodes), adjmat[i,j]=1 means i->j
        node_names: sorted list of node names (row/column ordering of adjmat)
    """
    ordering = sorted(model.nodes())
    size = len(ordering)
    position = dict(zip(ordering, range(size)))

    matrix = np.zeros((size, size), dtype=int)
    for source, target in model.edges():
        matrix[position[source], position[target]] = 1
    return matrix, ordering
def _meek_rule_applies(cpdag, i, j):
    """Return True if Meek rule R1, R2 or R3 forces the undirected edge i -- j to i -> j."""
    others = [k for k in range(cpdag.shape[0]) if k != i and k != j]

    def directed(a, b):
        return cpdag[a, b] == 1 and cpdag[b, a] == 0

    def undirected(a, b):
        return cpdag[a, b] == 1 and cpdag[b, a] == 1

    def adjacent(a, b):
        return cpdag[a, b] == 1 or cpdag[b, a] == 1

    # Rule 1: k -> i -- j with k and j non-adjacent (else a new v-structure at i)
    if any(directed(k, i) and not adjacent(k, j) for k in others):
        return True

    # Rule 2: i -> k -> j with i -- j (else a directed cycle)
    if any(directed(i, k) and directed(k, j) for k in others):
        return True

    # Rule 3: i -- k1 -> j and i -- k2 -> j with k1 and k2 non-adjacent
    candidates = [k for k in others if undirected(i, k) and directed(k, j)]
    for pos, k1 in enumerate(candidates):
        for k2 in candidates[pos + 1:]:
            if not adjacent(k1, k2):
                return True
    return False


def dag_to_cpdag(dag_adjmat):
    """Convert a DAG adjacency matrix to its CPDAG (completed partially directed acyclic graph).

    A CPDAG represents the Markov equivalence class:
    - Compelled edges (in all DAGs of the class) remain directed
    - Reversible edges become undirected (represented as bidirectional)

    Procedure:
    1. Start from the undirected skeleton and orient every v-structure
       (i -> k <- j where i and j are not adjacent)
    2. Apply Meek's orientation rules R1-R3 repeatedly until nothing changes

    Args:
        dag_adjmat: square array-like; a non-zero entry [i,j] means i->j.
            The input is not modified.

    Returns:
        cpdag: np.ndarray of ints,
            cpdag[i,j]=1 and cpdag[j,i]=0 means i->j (directed)
            cpdag[i,j]=1 and cpdag[j,i]=1 means i--j (undirected)
    """
    # Binarise once so the skeleton and the parent lookup agree on what counts
    # as an edge, and so plain nested lists are accepted as well as arrays.
    dag = (np.asarray(dag_adjmat) != 0).astype(int)
    n = dag.shape[0]

    # Start with skeleton (undirected)
    skeleton = ((dag + dag.T) > 0).astype(int)
    cpdag = skeleton.copy()

    # Step 1: orient v-structures i -> k <- j (i and j NOT adjacent in skeleton)
    for k in range(n):
        parents_of_k = np.flatnonzero(dag[:, k])
        for pos, i in enumerate(parents_of_k):
            for j in parents_of_k[pos + 1:]:
                if skeleton[i, j] == 0:
                    cpdag[i, k] = 1
                    cpdag[k, i] = 0
                    cpdag[j, k] = 1
                    cpdag[k, j] = 0

    # Step 2: apply Meek's rules until a full pass orients nothing.
    # Both (i, j) and (j, i) are visited, so either direction can be chosen.
    changed = True
    while changed:
        changed = False
        for i in range(n):
            for j in range(n):
                still_undirected = cpdag[i, j] == 1 and cpdag[j, i] == 1
                if still_undirected and _meek_rule_applies(cpdag, i, j):
                    cpdag[j, i] = 0  # orient i -> j
                    changed = True
    return cpdag
def _declared_states(model, node):
    """Return the state labels the model declares for ``node``, or None if unavailable."""
    try:
        return list(model.get_cpds(node).state_names[node])
    except (AttributeError, KeyError, TypeError, ValueError):
        return None


def _encode_state_column(column, declared_states):
    """Integer-encode one column of state labels.

    Codes are ranks within the sorted set of *declared* states, so a state keeps
    the same code even when it never occurs in this particular sample. When the
    declared states are unavailable, unsortable, or do not cover the observed
    labels, the codes are ranks within the observed labels instead.
    """
    observed = column.astype('category')
    if declared_states:
        try:
            universe = sorted(set(declared_states))
        except TypeError:
            universe = None
        if universe is not None and set(observed.cat.categories) <= set(universe):
            return pd.Categorical(column, categories=universe).codes
    return observed.cat.codes


def sample_dataset(model, n_samples, seed=42):
    """Sample observational data from a Bayesian network.

    Args:
        model: network to sample from (as returned by load_bn_model)
        n_samples: number of rows to draw
        seed: seed passed to the sampler; also used to seed NumPy's global RNG

    Returns:
        df: pd.DataFrame with integer-encoded discrete variables, columns
            sorted by name
    """
    # Seeds the process-wide NumPy RNG in addition to the sampler's own seed.
    np.random.seed(seed)
    sampler = BayesianModelSampling(model)
    try:
        df = sampler.forward_sample(size=n_samples, seed=seed)
    except TypeError:
        # Fallback for pgmpy/pandas version compatibility issues
        df = _manual_forward_sample(model, n_samples, seed)

    # Ensure consistent column ordering (sorted); copy so the column
    # assignments below write to an independent frame.
    df = df[sorted(df.columns)].copy()

    # Encode every non-numeric column (object, category, string, ...) as integers.
    # Ranking against the declared state set keeps codes stable across samples
    # in which some state happens not to be drawn.
    for col in df.columns:
        if not pd.api.types.is_numeric_dtype(df[col]):
            df[col] = _encode_state_column(df[col], _declared_states(model, col))

    # Ensure all columns are numeric
    df = df.apply(pd.to_numeric, errors='coerce').fillna(0).astype(int)
    return df
def _manual_forward_sample(model, n_samples, seed=42):
    """Manual forward sampling when pgmpy's sampler has compatibility issues.

    Nodes are visited in topological order so every parent is sampled before
    its children. Each value is the index of the drawn state along the node's
    own axis of its CPD table.

    Args:
        model: network whose CPDs are sampled
        n_samples: number of rows to draw
        seed: seed for the private RandomState used for all draws

    Returns:
        pd.DataFrame with one column per node (topological order) holding
        state indices

    Raises:
        ValueError: if a selected conditional distribution has no probability mass
    """
    import networkx as nx

    rng = np.random.RandomState(seed)
    nodes = list(nx.topological_sort(model))

    # Get CPDs
    cpd_dict = {cpd.variable: cpd for cpd in model.get_cpds()}

    # For each node keep its table with the node's own axis first, plus the
    # parents in the order of the remaining axes. Indexing the value array by
    # its own axis order avoids rebuilding a flat column index from
    # get_evidence(), whose ordering does not follow the axis order.
    # NOTE(review): pgmpy returns get_evidence() reversed relative to
    # cpd.variables[1:] -- confirm for the installed version.
    tables = {}
    for node in nodes:
        cpd = cpd_dict[node]
        variables = list(cpd.variables)
        values = np.asarray(cpd.values, dtype=float)
        table = np.moveaxis(values, variables.index(node), 0)
        parents = [var for var in variables if var != node]
        tables[node] = (parents, table)

    samples = {node: [] for node in nodes}
    for _ in range(n_samples):
        sample = {}
        for node in nodes:
            parents, table = tables[node]
            # Root nodes have no parent axes, so this selects the whole marginal.
            index = (slice(None),) + tuple(sample[p] for p in parents)
            probs = np.abs(table[index])
            total = probs.sum()
            if not total > 0:
                raise ValueError(
                    f"CPD of '{node}' has no probability mass for the sampled parent states"
                )
            probs = probs / total  # normalize
            val = rng.choice(len(probs), p=probs)
            sample[node] = val
            samples[node].append(val)
    return pd.DataFrame(samples)
def generate_all_datasets(networks=None, output_dir=None):
    """Sample every (network, sample size, seed) configuration.

    For each network the ground-truth DAG and CPDAG are computed once and the
    same objects are shared by all of that network's configurations. A
    configuration whose sampling fails is logged and skipped.

    Args:
        networks: network names to process; defaults to ALL_NETWORKS
        output_dir: accepted but not used by this function

    Returns list of dicts with:
        - network: str
        - n_samples: int
        - seed: int
        - df: pd.DataFrame
        - true_dag: np.ndarray
        - true_cpdag: np.ndarray
        - node_names: list
    """
    selected = ALL_NETWORKS if networks is None else networks

    configs = []
    for net_name in selected:
        sample_sizes = SAMPLE_SIZES[get_network_tier(net_name)]

        logger.info(f"Loading network: {net_name}")
        model = load_bn_model(net_name)
        true_dag, node_names = get_true_dag_adjmat(model)
        true_cpdag = dag_to_cpdag(true_dag)

        # Every sample size is combined with seeds 0 .. SEEDS_PER_CONFIG - 1
        runs = [
            (n_samples, seed)
            for n_samples in sample_sizes
            for seed in range(SEEDS_PER_CONFIG)
        ]
        for n_samples, seed in runs:
            try:
                df = sample_dataset(model, n_samples, seed=seed)
                configs.append({
                    'network': net_name,
                    'n_samples': n_samples,
                    'seed': seed,
                    'df': df,
                    'true_dag': true_dag,
                    'true_cpdag': true_cpdag,
                    'node_names': node_names,
                })
                logger.info(f" {net_name} N={n_samples} seed={seed}: {df.shape}")
            except Exception as e:
                # Best effort: one failed configuration must not stop the rest
                logger.error(f" FAILED {net_name} N={n_samples} seed={seed}: {e}")
    return configs
if __name__ == '__main__':
    # Quick smoke test on the 'asia' network: print its ground-truth graphs
    # and the head of a 1000-row sample.
    logging.basicConfig(level=logging.INFO)

    asia_model = load_bn_model('asia')
    true_dag, node_names = get_true_dag_adjmat(asia_model)
    true_cpdag = dag_to_cpdag(true_dag)

    print(f"ASIA - nodes: {node_names}")
    print(f"DAG adjacency:\n{true_dag}")
    print(f"CPDAG adjacency:\n{true_cpdag}")

    sampled = sample_dataset(asia_model, 1000, seed=0)
    print(f"\nSampled data: {sampled.shape}")
    print(sampled.head())
|