| import torch |
| from flow_matching.utils import categorical |
| import math |
| import inspect |
|
|
def generate_simplex_lattice_points(num_obj: int, num_div: int) -> torch.Tensor:
    """Enumerate the simplex-lattice (Das-Dennis) weight vectors.

    Every vector has ``num_obj`` non-negative components that are multiples of
    ``1 / num_div`` and sum to 1. Rows are returned in lexicographic order of
    their integer numerators (first coordinate varies slowest).

    Args:
        num_obj: Number of objectives (vector length). Must be >= 1.
        num_div: Number of divisions of the unit interval. Must be >= 1.

    Returns:
        float32 tensor of shape (C(num_div + num_obj - 1, num_obj - 1), num_obj).

    Raises:
        ValueError: if ``num_obj < 1`` (previously infinite recursion) or
            ``num_div < 1`` (previously NaN / empty output from dividing by
            a non-positive ``num_div``).
    """
    if num_obj < 1:
        raise ValueError(f"num_obj must be >= 1, got {num_obj}")
    if num_div < 1:
        raise ValueError(f"num_div must be >= 1, got {num_div}")

    # Each partial is (prefix of integer numerators, remaining budget).
    # Expanding prefix-major keeps the same lexicographic order as the
    # original recursive formulation, without Python recursion depth limits.
    partials = [([], num_div)]
    for _ in range(num_obj - 1):
        partials = [
            (prefix + [k], remaining - k)
            for prefix, remaining in partials
            for k in range(remaining + 1)
        ]
    # The last coordinate absorbs whatever budget is left so rows sum to num_div.
    points = [prefix + [remaining] for prefix, remaining in partials]

    weight_vectors = torch.tensor(points, dtype=torch.float32) / num_div
    return weight_vectors
|
|
def select_random_weight_vector(num_obj: int, num_div: int):
    """Pick one simplex-lattice weight vector uniformly at random.

    Returns:
        ``(chosen, lattice)`` where ``lattice`` is the full output of
        ``generate_simplex_lattice_points(num_obj, num_div)`` and ``chosen`` is
        one of its rows, selected with a single ``torch.randint`` draw from the
        global torch RNG.
    """
    lattice = generate_simplex_lattice_points(num_obj, num_div)
    n_vectors = lattice.size(0)
    chosen_row = torch.randint(low=0, high=n_vectors, size=(1,)).item()
    chosen = lattice[chosen_row]
    return chosen, lattice
|
|
def z_score_norm(tensor, eps=1e-8):
    """Standardize ``tensor`` along its last dimension.

    Subtracts the per-row mean and divides by the population (biased) standard
    deviation. The standard deviation is floored at ``eps``, so constant rows
    map to zeros instead of NaN.
    """
    centered = tensor - tensor.mean(dim=-1, keepdim=True)
    spread = torch.clamp(tensor.std(dim=-1, keepdim=True, unbiased=False), min=eps)
    return centered / spread
|
|
def _score_improvements(s_models, candidates, x_t, t, importance, B, num_cand):
    """Return importance-weighted score improvements of candidates over x_t.

    Each model in ``s_models`` is called on ``candidates`` and on ``x_t``
    (passing ``t`` as a second argument when its signature has a parameter
    named ``t``). A model may return one score tensor or a tuple of score
    tensors; each returned score counts as one objective. The k-th objective
    overall is scaled by ``importance[k]``.

    Returns:
        float tensor of shape (B, num_cand, num_objectives).
    """
    columns = []
    count = 0
    with torch.no_grad():
        for s in s_models:
            sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
            if 't' in sig.parameters:
                candidate_scores = s(candidates, t)
                base_score = s(x_t, t)
            else:
                candidate_scores = s(candidates)
                base_score = s(x_t)

            if isinstance(candidate_scores, tuple):
                pairs = [(score, base_score[k]) for k, score in enumerate(candidate_scores)]
            else:
                pairs = [(candidate_scores, base_score)]

            for cand_score, base in pairs:
                improvement = (cand_score.view(B, num_cand) - base.unsqueeze(1)).float()
                improvement *= importance[count]
                columns.append(improvement.unsqueeze(2))
                count += 1
    return torch.cat(columns, dim=2)


def guided_transition_scoring(x_t, u_t, w, s_models, t, importance, args):
    """Reweight the rates at one random position per sequence toward better candidates.

    For each row of ``x_t`` a single position in ``[1, L - 2]`` (first and last
    positions excluded) is chosen at random. Every single-token substitution at
    that position is scored by all objective models. The off-diagonal rates
    ``u_t[b, pos, cand]`` are multiplied by ``exp(beta * delta_S)``, capped at
    100. The diagonal entry is then reset so the row at ``pos`` sums to zero.

    Args:
        x_t: Integer token tensor of shape (B, L).
        u_t: Rate tensor of shape (B, L, V). Not modified; a clone is returned.
        w: Weight vector with one entry per objective.
        s_models: Objective models; see ``_score_improvements`` for call rules.
        t: Time value forwarded to models whose signature has a ``t`` parameter.
        importance: Indexable per-objective scale factors.
        args: Namespace that must provide ``is_peptide``, ``lambda_`` and ``beta``.

    Returns:
        Tuple ``(guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S)``:
        guided_u_t (B, L, V); pos_indices (B,); cand_tokens (B, V-1) holding every
        token except the current one; improvement_values (B, V-1, num_objectives);
        delta_S (B, V-1) combined guidance score.
    """
    B, L, vocab_size = u_t.shape
    device = x_t.device
    num_cand = vocab_size - 1
    guided_u_t = u_t.clone()

    # randint's `high` is exclusive: high=L-1 samples positions 1..L-2, so only
    # the first and last positions are excluded. (The previous high=L-2 also
    # excluded position L-2, which could then never be updated.)
    pos_indices = torch.randint(low=1, high=L - 1, size=(B,), device=device)
    batch_idx = torch.arange(B, device=device)
    current_tokens = x_t[batch_idx, pos_indices]

    # All vocabulary tokens except the current one, per row.
    full_cand_tokens = torch.arange(vocab_size, device=device).unsqueeze(0).expand(B, vocab_size)
    mask = full_cand_tokens != current_tokens.unsqueeze(1)
    cand_tokens = torch.masked_select(full_cand_tokens, mask).view(B, num_cand)

    # Candidate c of row b is x_t[b] with a single substitution at pos_indices[b].
    new_x = x_t.unsqueeze(1).repeat(1, num_cand, 1)
    new_x[batch_idx, :, pos_indices] = cand_tokens
    new_x_flat = new_x.view(B * num_cand, L)

    improvement_values = _score_improvements(
        s_models, new_x_flat, x_t, t, importance, B, num_cand
    )
    if args.is_peptide:
        # Forbid candidate token ids 0-3 (presumably special tokens; confirm
        # against the tokenizer). Mask by token id, not column index: column j
        # is token j only when the current token is >= 4.
        special = (cand_tokens < 4).unsqueeze(-1)
        improvement_values = improvement_values.masked_fill(special, -10)

    # Normalized rank of each candidate per objective, averaged over objectives.
    ranks = torch.argsort(torch.argsort(improvement_values, dim=1), dim=1).float() + 1
    I_n = ranks / float(num_cand)
    avg_I = I_n.mean(dim=2)
    norm_avg_I = z_score_norm(avg_I)

    # Weighted sum of improvements along the preference vector w.
    D = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
    norm_D = z_score_norm(D)

    delta_S = norm_avg_I + args.lambda_ * norm_D

    # exp(...) is non-negative, so only the upper cap matters.
    factor = torch.clamp(torch.exp(args.beta * delta_S), max=100)

    rows = batch_idx.unsqueeze(1)
    cols = pos_indices.unsqueeze(1)
    guided_u_t[rows, cols, cand_tokens] = u_t[rows, cols, cand_tokens] * factor

    # Keep the modified row a valid rate row: diagonal = -(sum of off-diagonals).
    updated_vals = guided_u_t[batch_idx, pos_indices, :]
    sum_off_diag = updated_vals.sum(dim=1) - updated_vals[batch_idx, current_tokens]
    guided_u_t[batch_idx, pos_indices, current_tokens] = -sum_off_diag

    return guided_u_t, pos_indices, cand_tokens, improvement_values, delta_S
|
|
def adaptive_hypercone_filtering(improvement_values, cand_tokens, delta_S, w, Phi, args, ema_r_t=None):
    """Select a candidate per row inside an adaptive cone around ``w`` and adapt the cone.

    The angle (radians, via ``acos``) between each candidate's improvement
    vector and ``w`` is computed.
    - A candidate is *valid* if its angle is < pi/2.
    - A candidate is *accepted* if it is valid and its angle is <= ``Phi``.

    Each row picks the accepted candidate with the largest ``delta_S``. If no
    candidate is accepted, it picks the best valid one. If none is valid, it
    returns -1. Ties resolve to the first index, as ``torch.argmax`` does.

    The rejection rate ``r_t`` is the mean, over rows with at least one valid
    candidate, of the fraction of valid candidates that were not accepted. If
    any row has a valid candidate, the EMA of ``r_t`` is updated as
    ``alpha_r * ema + (1 - alpha_r) * r_t``. ``Phi`` is then scaled by
    ``exp(eta * (ema - tau))`` and clamped to ``[Phi_min, Phi_max]``.
    Otherwise ``Phi`` and the EMA are returned unchanged.

    Args:
        improvement_values: (B, C, N) per-objective improvements.
        cand_tokens: (B, C) candidate token ids.
        delta_S: (B, C) guidance scores.
        w: (N,) preference vector.
        Phi: Current cone half-angle in radians.
        args: Namespace providing ``tau``, ``alpha_r``, ``eta``, ``Phi_min``, ``Phi_max``.
        ema_r_t: Previous EMA of the rejection rate. Defaults to ``args.tau``.

    Returns:
        ``(best_candidate, accepted_mask, valid_mask, new_Phi, new_ema_r_t)``.
        ``best_candidate`` is a (B,) long tensor, with -1 meaning no valid candidate.
    """
    B, num_candidates, N = improvement_values.shape
    device = improvement_values.device
    eps = 1e-8

    imp_norm = torch.norm(improvement_values.float(), dim=2)
    dot_product = (improvement_values * w.view(1, 1, -1)).sum(dim=2)
    w_norm = torch.norm(w) + eps
    cos_angle = (dot_product / (imp_norm * w_norm + eps)).clamp(-1.0, 1.0)
    angles = torch.acos(cos_angle)

    valid_mask = angles < math.pi / 2
    accepted_mask = valid_mask & (angles <= Phi)
    has_valid = valid_mask.any(dim=1)
    has_accepted = accepted_mask.any(dim=1)

    if num_candidates == 0:
        # argmax over an empty dimension raises; there is nothing to select.
        best_candidate = torch.full((B,), -1, dtype=torch.long, device=device)
    else:
        # Per row: search the accepted set if it is non-empty, else the valid set.
        pool = torch.where(has_accepted.unsqueeze(1), accepted_mask, valid_mask)
        candidate_idx = delta_S.masked_fill(~pool, float('-inf')).argmax(dim=1, keepdim=True)
        best_candidate = cand_tokens.gather(1, candidate_idx).squeeze(1).to(dtype=torch.long, device=device)
        best_candidate = best_candidate.masked_fill(~has_valid, -1)

    if ema_r_t is None:
        ema_r_t = args.tau

    if not bool(has_valid.any()):
        return best_candidate, accepted_mask, valid_mask, Phi, ema_r_t

    # Vectorized rejection rate (accepted is a subset of valid).
    num_valid = valid_mask.sum(dim=1)
    num_rejected = num_valid - accepted_mask.sum(dim=1)
    rates = num_rejected[has_valid].double() / num_valid[has_valid].double()
    r_t = rates.mean().item()

    new_ema_r_t = args.alpha_r * ema_r_t + (1 - args.alpha_r) * r_t
    new_Phi = Phi * torch.exp(torch.tensor(args.eta * (new_ema_r_t - args.tau), device=device))
    new_Phi = new_Phi.clamp(args.Phi_min, args.Phi_max).item()

    return best_candidate, accepted_mask, valid_mask, new_Phi, new_ema_r_t
|
|
def get_best_candidate(improvement_values, cand_tokens, delta_S):
    """Return, for each row, the candidate token with the highest ``delta_S``.

    No angular filtering is applied. Ties resolve to the first maximal index,
    as ``torch.argmax`` does. ``improvement_values`` is used only for its batch
    size and device.

    Args:
        improvement_values: Tensor whose first dimension is the batch size B.
        cand_tokens: (B, C) candidate token ids.
        delta_S: (B, C) candidate scores.

    Returns:
        (B,) long tensor of selected token ids on ``improvement_values.device``.
    """
    B = improvement_values.shape[0]
    # One batched argmax + gather instead of a per-row Python loop.
    best_idx = torch.argmax(delta_S[:B], dim=1, keepdim=True)
    best_candidate = cand_tokens[:B].gather(1, best_idx).squeeze(1)
    return best_candidate.to(dtype=torch.long, device=improvement_values.device)
|
|
def euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h):
    """Stochastically jump each sequence to its selected candidate at the guided position.

    For every row ``b`` with ``best_candidate[b] != -1``, the jump intensity is
    the single rate ``guided_u_t[b, pos_indices[b], best_candidate[b]]``. The
    row jumps with probability ``1 - exp(-intensity)``. Rows marked -1 never jump.

    Args:
        x_t: Integer token tensor of shape (B, L). Modified IN PLACE.
        pos_indices: (B,) position chosen for each row.
        best_candidate: (B,) token to jump to, or -1 for "no candidate".
        guided_u_t: (B, L, V) guided rates.
        h: Step size. NOTE(review): unused. The jump probability is computed
            as for a unit-length step, unlike the usual Euler form
            ``1 - exp(-h * intensity)``. Confirm this is intentional.

    Returns:
        ``x_t`` (the same tensor object, updated in place).
    """
    B = guided_u_t.shape[0]
    device = x_t.device
    has_candidate = best_candidate != -1

    # Only the (pos, best_candidate) entry contributes to a row's intensity, so
    # read that rate directly instead of materializing a dense (B, L, V) buffer.
    intensity = torch.zeros(B, device=device)
    if has_candidate.any():
        rows = torch.nonzero(has_candidate).squeeze(-1)
        intensity[rows] = guided_u_t[rows, pos_indices[rows], best_candidate[rows]]

    p_jump = 1 - torch.exp(-intensity)
    rand_val = torch.rand(B, device=device)

    # The former debug print("Jump!") (and its per-step .tolist() host sync)
    # is intentionally gone.
    jump = (rand_val < p_jump) & has_candidate
    x_t[jump, pos_indices[jump]] = best_candidate[jump]

    return x_t
|
|