| | import math |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from transformers import PreTrainedModel |
| | from transformers.modeling_outputs import CausalLMOutput |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
| | |
| | |
| |
|
class GlobalConv1D(nn.Module):
    """Causal per-channel long convolution evaluated with FFT overlap-save.

    Channel ``c`` of the input is convolved with its own learned kernel
    ``kernel[c]``; output position ``t`` depends only on inputs at positions
    ``<= t``.

    Args:
        d_model: Number of channels (one kernel per channel).
        kernel_size: Maximum kernel length. For inputs shorter than this, only
            the first ``T`` taps are used.
        fft_size: FFT length of each overlap-save segment. Must be at least the
            effective kernel length used in ``forward``.
    """

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        # Scaled by 0.01 so the initial kernel taps are small.
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        """Apply the causal per-channel convolution.

        Args:
            x: Channels-first tensor of shape ``(B, C, T)`` with ``C == d_model``.

        Returns:
            Tensor of shape ``(B, C, T)``.

        Raises:
            ValueError: if the effective kernel length ``min(kernel_size, T)``
                exceeds ``fft_size``; no segment could then produce a valid
                output sample (the previous loop spun forever in that case).
        """
        B, C, T = x.shape
        if T == 0:
            return x.new_zeros((B, C, 0))

        K = min(self.kernel_size, T)
        overlap = K - 1                   # history each segment must carry
        block = self.fft_size - overlap   # valid output samples per segment
        if block < 1:
            raise ValueError(
                f"fft_size ({self.fft_size}) must be >= the effective kernel "
                f"length ({K})"
            )

        n_blocks = (T + block - 1) // block
        # Left-pad with `overlap` zeros for causality; right-pad so the final
        # segment is a full fft_size window.
        x = F.pad(x, (overlap, n_blocks * block - T))
        # (B, C, n_blocks, fft_size): window i starts at i * block.
        segs = x.unfold(-1, self.fft_size, block)

        # rfft zero-pads the K taps up to fft_size.
        k_f = torch.fft.rfft(self.kernel[:, :K], n=self.fft_size)
        y = torch.fft.irfft(
            torch.fft.rfft(segs, n=self.fft_size) * k_f.unsqueeze(1),
            n=self.fft_size,
        )
        # The first `overlap` samples of every window contain circular
        # wrap-around; the remaining `block` samples are the linear convolution.
        y = y[..., overlap:].reshape(B, C, n_blocks * block)
        return y[..., :T]
| |
|
| |
|
| | |
| | |
| | |
| |
|
class LocalConv1D(nn.Module):
    """Causal short-range mixer: depthwise conv, ReLU, then pointwise conv.

    Args:
        d_model: Number of channels.
        k: Length of the depthwise kernel.
    """

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        # groups == channels: each channel gets its own length-k filter.
        self.dw = nn.Conv1d(d_model, d_model, kernel_size=k, groups=d_model)
        # 1x1 convolution mixes information across channels.
        self.pw = nn.Conv1d(d_model, d_model, kernel_size=1)

    def forward(self, x):
        """Map channels-first input to output of the same length.

        Output position ``t`` sees only inputs ``t-k+1 .. t``, because the
        padding is applied on the left side only.
        """
        left_context = self.k - 1
        padded = F.pad(x, (left_context, 0))
        hidden = F.relu(self.dw(padded))
        return self.pw(hidden)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class GCLMBlock(nn.Module):
    """Pre-norm residual block: local conv, optional global conv, then MLP.

    Args:
        config: Model config. Reads ``d_model``, ``layer_norm_eps`` and
            ``local_kernel_size``; additionally ``global_kernel_size`` and
            ``fft_size`` when ``use_global`` is true.
        use_global: Whether this block includes the GlobalConv1D branch
            (and its LayerNorm ``ln2``).
    """

    def __init__(self, config, use_global):
        super().__init__()
        self.use_global = use_global
        dim = config.d_model
        eps = config.layer_norm_eps

        self.ln1 = nn.LayerNorm(dim, eps=eps)
        self.local = LocalConv1D(dim, config.local_kernel_size)

        if use_global:
            self.ln2 = nn.LayerNorm(dim, eps=eps)
            self.global_conv = GlobalConv1D(
                dim, config.global_kernel_size, config.fft_size
            )

        hidden = dim * 4
        self.ln3 = nn.LayerNorm(dim, eps=eps)
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )

    @staticmethod
    def _conv_branch(conv, norm, h):
        """Normalize ``h``, run a channels-first conv, return to ``h``'s layout."""
        return conv(norm(h).transpose(1, 2)).transpose(1, 2)

    def forward(self, x):
        """Apply the block to ``x`` with features on the last axis.

        LayerNorm runs over the last dimension (``d_model``); axes 1 and 2 are
        swapped around each conv so the convs see channels-first data.
        """
        x = x + self._conv_branch(self.local, self.ln1, x)
        if self.use_global:
            x = x + self._conv_branch(self.global_conv, self.ln2, x)
        mlp_out = self.ff(self.ln3(x))
        return x + mlp_out
| |
|
| |
|
| | |
| | |
| | |
| |
|
class GCLMModel(PreTrainedModel):
    """GCLM backbone: token + learned positional embeddings, GCLMBlocks, final LayerNorm.

    Layer ``i`` includes the global-convolution branch when
    ``i % config.use_global_every_n_layers == 0`` (so layer 0 always does).
    """

    # NOTE(review): no config class is attached, so AutoModel/from_pretrained
    # cannot resolve the config type from this class — confirm it is set elsewhere.
    config_class = None
    base_model_prefix = "gclm"

    def __init__(self, config):
        super().__init__(config)

        self.emb = nn.Embedding(config.vocab_size, config.d_model)
        # Learned absolute positions: sequences longer than max_seq_len
        # have no positional embedding.
        self.pos = nn.Embedding(config.max_seq_len, config.d_model)

        self.layers = nn.ModuleList([
            GCLMBlock(
                config,
                use_global=(i % config.use_global_every_n_layers == 0),
            )
            for i in range(config.n_layers)
        ])

        self.ln = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)

        self.post_init()

    def forward(self, input_ids):
        """Encode token ids into final hidden states.

        Args:
            input_ids: 2-D tensor of token ids, shape ``(B, T)``.

        Returns:
            Hidden states of shape ``(B, T, d_model)`` after the final LayerNorm.

        Raises:
            ValueError: if ``T`` exceeds the positional-embedding table size
                (``config.max_seq_len``). Without this check the embedding
                lookup fails with an opaque index error (a device-side assert
                on CUDA).
        """
        _, seq_len = input_ids.shape
        max_len = self.pos.num_embeddings
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {max_len}"
            )

        positions = torch.arange(seq_len, device=input_ids.device)
        h = self.emb(input_ids) + self.pos(positions)

        for layer in self.layers:
            h = layer(h)

        return self.ln(h)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class GCLMForCausalLM(PreTrainedModel):
    """GCLM backbone followed by a bias-free linear language-modeling head."""

    # NOTE(review): config_class is unset here as well — confirm it is
    # assigned elsewhere if from_pretrained/AutoModel loading is needed.
    config_class = None
    base_model_prefix = "gclm"

    def __init__(self, config):
        super().__init__(config)

        self.gclm = GCLMModel(config)
        # Separate output projection; this class sets up no weight tying.
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.post_init()

    def forward(
        self,
        input_ids,
        labels=None,
        **kwargs
    ):
        """Compute vocabulary logits and, optionally, a cross-entropy loss.

        Args:
            input_ids: Token ids of shape ``(B, T)``.
            labels: Optional integer targets, same shape as ``input_ids``.
                Entries equal to ``-100`` are ignored. Labels are compared
                position-by-position with ``logits``; this method does NOT
                shift them, so callers must pass already-shifted targets.
                NOTE(review): HF ``*ForCausalLM`` models conventionally shift
                labels internally — confirm the training pipeline pre-shifts.
            **kwargs: Accepted and ignored (for example a mask passed by a
                generic caller), so padding positions are not masked here.

        Returns:
            ``CausalLMOutput`` with ``logits`` of shape ``(B, T, vocab_size)``
            and ``loss`` (``None`` when ``labels`` is ``None``).
        """
        hidden = self.gclm(input_ids)
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            # reshape (not view) accepts non-contiguous label tensors; the
            # float32 upcast keeps log-softmax stable with half-precision weights.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                labels.reshape(-1),
                ignore_index=-100,
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )
| |
|