| import yaml |
| import string |
| import secrets |
| import os |
|
|
| import torch |
| import wandb |
| from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint |
| from torchdyn.core import NeuralODE |
|
|
| import torch |
|
|
@torch.no_grad()
def gather_local_starts(x0s, X0_pool, N, k=64):
    """Sample ``N`` pool points from each anchor's nearest neighbours.

    For every row of ``x0s`` the ``min(k, len(X0_pool))`` closest rows of
    ``X0_pool`` (Euclidean distance via ``torch.cdist``) are located and ``N``
    of them are picked at random.

    Args:
        x0s: 2-D tensor of anchors, one per row (``B`` rows, ``G`` columns).
        X0_pool: 2-D tensor of candidate points with the same column count.
        N: Number of points to return per anchor.
        k: Neighbourhood size the picks are restricted to.

    Returns:
        Tensor of shape ``(B, N, G)`` holding the picked pool rows. The first
        ``min(N, neighbours)`` picks of each anchor are distinct; when ``N``
        exceeds the neighbourhood size the remaining slots are filled by
        drawing from the same neighbourhood with replacement, so the result
        always has ``N`` rows per anchor. An empty neighbourhood (empty pool
        or ``k == 0``) yields ``(B, 0, G)``.
    """
    num_anchors, _ = x0s.shape  # unpacking also enforces a 2-D anchor batch
    num_neighbors = min(k, X0_pool.size(0))

    # Distance and squared distance rank identically, so no squaring is needed.
    knn_idx = torch.cdist(x0s, X0_pool).topk(k=num_neighbors, largest=False).indices

    # Slots that cannot be filled with distinct neighbours.
    num_extra = N - num_neighbors if 0 < num_neighbors < N else 0

    clusters = []
    for b in range(num_anchors):
        choices = knn_idx[b]
        slots = torch.randperm(num_neighbors, device=choices.device)[:N]
        if num_extra:
            refill = torch.randint(num_neighbors, (num_extra,), device=choices.device)
            slots = torch.cat([slots, refill])
        clusters.append(X0_pool[choices[slots]])
    return torch.stack(clusters, dim=0)
|
|
def _sample_plan_row(row, num_samples, replace):
    """Draw ``num_samples`` target indices from one row of a coupling matrix."""
    probs = row.clamp_min(0)
    probs = probs / probs.sum().clamp_min(1e-12)
    if replace:
        return torch.multinomial(probs, num_samples=num_samples, replacement=True)

    # Without replacement only targets with positive mass can be drawn.
    num_unique = min(num_samples, int((probs > 0).sum().item()))
    picked = torch.multinomial(probs, num_samples=num_unique, replacement=False)
    if num_unique < num_samples:
        # Pad by repeating the last drawn index so the result has a fixed length.
        picked = torch.cat([picked, picked[-1:].expand(num_samples - num_unique)], dim=0)
    return picked


def _nearest_row_indices(points, pool):
    """Return, for each row of ``points``, the index of the closest row of ``pool``."""
    return torch.cdist(points, pool).argmin(dim=1)


@torch.no_grad()
def make_aligned_clusters(ot_sampler, x0s, x1s, N, replace=True, k_local=128):
    """Build, for every source point, a local source cluster and matched targets.

    Source clusters come from ``gather_local_starts`` using ``x0s`` as both the
    anchors and the pool. Target clusters are drawn per source point:

    * If ``ot_sampler`` exposes ``coupling`` or ``plan`` (checked in that
      order), it is called once with ``(x0s, x1s)`` and row ``b`` of the
      result is used as sampling weights over the rows of ``x1s``.
      NOTE(review): the result is assumed to be a ``(B, M)`` array-like of
      non-negative weights -- confirm against the sampler implementation.
    * Otherwise ``ot_sampler.sample_plan`` is used. A call with ``n_pairs=N``
      is tried first; if it raises ``TypeError`` (taken to mean the keyword
      is unsupported) the sampler is called ``N`` times per source point
      instead, and the keyword is not tried again for later rows.

    Args:
        ot_sampler: Object providing ``coupling``/``plan`` and/or ``sample_plan``.
        x0s: 2-D tensor of source points, ``(B, G)``.
        x1s: 2-D tensor of target points with ``G`` columns.
        N: Number of points per cluster.
        replace: Whether targets are drawn with replacement; also forwarded to
            ``sample_plan``.
        k_local: Neighbourhood size passed to ``gather_local_starts``.

    Returns:
        Tuple ``(x0_clusters, x1_clusters, idx1)``. ``x0_clusters`` is whatever
        ``gather_local_starts`` returns, ``x1_clusters`` has shape
        ``(B, N, G)`` and ``idx1`` has shape ``(B, N)`` with ``long`` indices
        into ``x1s``. On the ``sample_plan`` paths the indices are recovered
        by nearest-row lookup, so duplicate rows in ``x1s`` make them
        ambiguous.
    """
    device, dtype = x0s.device, x0s.dtype
    B, G = x0s.shape

    x0_clusters = gather_local_starts(x0s, x0s, N, k=k_local).to(device=device, dtype=dtype)
    x1_clusters = torch.empty((B, N, G), device=device, dtype=dtype)
    idx1 = torch.empty((B, N), device=device, dtype=torch.long)

    P = None
    if hasattr(ot_sampler, "coupling"):
        P = ot_sampler.coupling(x0s, x1s)
    elif hasattr(ot_sampler, "plan"):
        P = ot_sampler.plan(x0s, x1s)
    if P is not None:
        # Accept array-like plans as well as tensors; a tensor already on
        # ``device`` is passed through unchanged.
        P = torch.as_tensor(P, device=device)

    # Whether sample_plan(..., n_pairs=N) is still worth attempting.
    probe_n_pairs = P is None and hasattr(ot_sampler, "sample_plan")

    for b in range(B):
        if P is not None:
            j = _sample_plan_row(P[b], N, replace)
            x1_match = x1s[j]
        else:
            x0_b = x0s[b:b + 1]
            x1_match = None
            if probe_n_pairs:
                try:
                    _, x1_match = ot_sampler.sample_plan(
                        x0_b, x1s, replace=replace, n_pairs=N
                    )
                except TypeError:
                    # Keyword rejected: fall back for this and all later rows.
                    probe_n_pairs = False
            if x1_match is None:
                draws = []
                for _ in range(N):
                    _, x1_one = ot_sampler.sample_plan(x0_b, x1s, replace=replace)
                    draws.append(x1_one.reshape(1, G))
                x1_match = torch.cat(draws, dim=0)
            else:
                x1_match = x1_match.reshape(N, G)
            # The sampler returns points, not indices; map them back to x1s rows.
            j = _nearest_row_indices(x1_match, x1s)

        x1_clusters[b] = x1_match
        idx1[b] = j

    return x0_clusters, x1_clusters, idx1
|
|
|
|
def load_config(path):
    """Parse the YAML file at ``path`` and return the loaded object.

    Args:
        path: Location of the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's content.
    """
    # Binary mode hands raw bytes to PyYAML so it applies its own encoding
    # detection (BOM, else UTF-8) rather than the platform's locale encoding.
    with open(path, "rb") as stream:
        return yaml.safe_load(stream)
|
|
|
|
def merge_config(args, config_updates):
    """Overwrite attributes of ``args`` with the entries of ``config_updates``.

    Every key must already exist as an attribute on ``args``. All keys are
    checked before anything is assigned, so a rejected update leaves ``args``
    untouched.

    Args:
        args: Object whose attributes are updated in place.
        config_updates: Mapping of attribute name to new value. ``None`` or an
            empty mapping means there is nothing to merge.

    Returns:
        The same ``args`` object, after the updates were applied.

    Raises:
        ValueError: If a key does not name an existing attribute of ``args``.
    """
    updates = config_updates or {}

    # Validate first so a bad key cannot leave ``args`` half-updated.
    for key in updates:
        if not hasattr(args, key):
            raise ValueError(
                f"Unknown configuration parameter '{key}' found in the config file."
            )

    for key, value in updates.items():
        setattr(args, key, value)
    return args
|
|
|
|
def generate_group_string(length=16):
    """Return a random string of ``length`` ASCII letters and digits.

    Characters are drawn independently with ``secrets.choice``. A ``length``
    of zero or less gives an empty string.
    """
    alphabet = string.ascii_letters + string.digits
    chars = []
    for _ in range(length):
        chars.append(secrets.choice(alphabet))
    return "".join(chars)