| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """Pytorch version of patched decoder.""" |
| |
|
| | import dataclasses |
| | import math |
| | from typing import List, Tuple |
| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| |
|
| |
|
| | def _create_quantiles() -> list[float]: |
| | return [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] |
| |
|
| |
|
@dataclasses.dataclass
class TimesFMConfig:
    """Hyper-parameters used to build a `PatchedTimeSeriesDecoder`."""

    # Number of decoder layers in the stacked transformer.
    num_layers: int = 20
    # Number of attention (query) heads per layer.
    num_heads: int = 16
    # Number of key/value heads per layer; must divide `num_heads`.
    num_kv_heads: int = 16
    # Model (embedding) width.
    hidden_size: int = 1280
    # Inner width of the transformer MLPs and of the input/output residual
    # blocks.
    intermediate_size: int = 1280
    # Width of each attention head.
    head_dim: int = 80
    # Epsilon of the RMSNorm layers.
    rms_norm_eps: float = 1e-6
    # Number of time steps per input patch.
    patch_len: int = 32
    # Number of future time steps the output head predicts per patch position.
    horizon_len: int = 128
    # Quantile levels; the output head emits a point forecast plus one value
    # per quantile.
    quantiles: List[float] = dataclasses.field(default_factory=_create_quantiles)
    # Sentinel value; inputs within `tolerance` of it are treated as padding.
    pad_val: float = 1123581321.0
    # Absolute tolerance for the padding/sentinel comparisons; also the
    # smallest standard deviation accepted during normalisation.
    tolerance: float = 1e-6
    # NOTE(review): not read by the code visible here, and "bfloat32" is not a
    # torch dtype name -- confirm the intended value.
    dtype: str = "bfloat32"
    # Whether sinusoidal positional embeddings are added to patch embeddings.
    use_positional_embedding: bool = True
| |
|
| |
|
| | def _masked_mean_std( |
| | inputs: torch.Tensor, |
| | padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: |
| | """Calculates mean and standard deviation of `inputs` across axis 1. |
| | |
| | It excludes values where `padding` is 1. |
| | |
| | Args: |
| | inputs: A PyTorch tensor of shape [b, n, p]. |
| | padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1. |
| | |
| | Returns: |
| | A tuple containing the mean and standard deviation. |
| | We return the statistics of the first patch with more than three non-padded |
| | values. |
| | """ |
| | |
| | pad_sum = torch.sum(1 - padding, dim=2) |
| |
|
| | def _get_patch_index(arr: torch.Tensor): |
| | indices = torch.argmax((arr >= 3).to(torch.int32), dim=1) |
| | row_sum = (arr >= 3).to(torch.int32).sum(dim=1) |
| | return torch.where(row_sum == 0, arr.shape[1] - 1, indices) |
| |
|
| | patch_indices = _get_patch_index(pad_sum) |
| | bidxs = torch.arange(inputs.shape[0]) |
| |
|
| | arr = inputs[bidxs, patch_indices, :] |
| | pad = padding[bidxs, patch_indices, :] |
| |
|
| | |
| | mask = 1 - pad |
| |
|
| | |
| | num_valid_elements = torch.sum(mask, dim=1) |
| | num_valid_elements = torch.where( |
| | num_valid_elements == 0, |
| | torch.tensor(1, |
| | dtype=num_valid_elements.dtype, |
| | device=num_valid_elements.device), |
| | num_valid_elements, |
| | ) |
| |
|
| | |
| | masked_sum = torch.sum(arr * mask, dim=1) |
| | masked_squared_sum = torch.sum((arr * mask)**2, dim=1) |
| |
|
| | |
| | masked_mean = masked_sum / num_valid_elements |
| | masked_var = masked_squared_sum / num_valid_elements - masked_mean**2 |
| | masked_var = torch.where( |
| | masked_var < 0.0, |
| | torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device), |
| | masked_var, |
| | ) |
| | masked_std = torch.sqrt(masked_var) |
| |
|
| | return masked_mean, masked_std |
| |
|
| |
|
| | def _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor: |
| | """Shifts rows of seq based on the first 0 in each row of the mask. |
| | |
| | Args: |
| | mask: mask tensor of shape [B, N] |
| | seq: seq tensor of shape [B, N, P] |
| | |
| | Returns: |
| | Returns the shifted sequence. |
| | """ |
| | batch_size, num_seq, feature_dim = seq.shape |
| |
|
| | new_mask: torch.BoolTensor = mask == 0 |
| |
|
| | |
| | indices = new_mask.to(torch.int32).argmax(dim=1) |
| |
|
| | |
| | indices[~new_mask.any(dim=1)] = -1 |
| |
|
| | |
| | idx_range = (torch.arange(num_seq).to( |
| | seq.device).unsqueeze(0).unsqueeze(-1).expand(batch_size, -1, |
| | feature_dim)) |
| |
|
| | |
| | shifted_idx = (idx_range - indices[:, None, None]) % num_seq |
| |
|
| | |
| | shifted_seq = seq.gather(1, shifted_idx) |
| |
|
| | return shifted_seq |
| |
|
| |
|
def get_large_negative_number(dtype: torch.dtype) -> torch.Tensor:
    """Returns `-0.7 * max(dtype)` as a 0-dim tensor of `dtype`.

    The tensor is created without an explicit device. The 0.7 factor keeps
    the value finite with some headroom below the dtype's largest magnitude.
    """
    type_info = (torch.finfo(dtype)
                 if dtype.is_floating_point else torch.iinfo(dtype))
    return torch.tensor(-0.7 * type_info.max, dtype=dtype)
| |
|
| |
|
def apply_mask_to_logits(logits: torch.Tensor,
                         mask: torch.Tensor) -> torch.Tensor:
    """Replaces masked-out logits with a large negative number.

    Let `floor = get_large_negative_number(logits.dtype)`. Positions where
    `mask >= 0.5 * floor` keep their logit. All other positions are set to
    `floor`.

    Args:
      logits: A torch.Tensor of logit values.
      mask: A torch.Tensor broadcastable against `logits`. Values near `floor`
        mark masked positions (the additive-mask convention used by
        `convert_paddings_to_mask` / `causal_mask`).

    Returns:
      Masked logits.
    """
    floor = get_large_negative_number(logits.dtype)
    keep = mask >= floor * 0.5
    return torch.where(keep, logits, floor)
| |
|
| |
|
def convert_paddings_to_mask(
    paddings: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Converts binary paddings to a logit mask ready to add to attention matrix.

    Padded positions (1) become `get_large_negative_number(dtype)`. Valid
    positions (0) become 0.

    Args:
      paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding
        token. May be floating point, integer or bool.
      dtype: data type whose large negative number marks padded positions.

    Returns:
      A torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits.
      Floating-point `paddings` keep their dtype. Integer/bool paddings are
      converted to `dtype`. The caller's tensor is never modified.
    """
    attention_mask = paddings.detach()
    if attention_mask.is_floating_point():
        # Copy so the in-place multiply below cannot touch the caller's tensor.
        attention_mask = attention_mask.clone()
    else:
        # In-place multiplying an integer tensor by a float scalar tensor
        # raises, so convert first. copy=True keeps the caller's tensor safe
        # even when it already has `dtype`.
        attention_mask = attention_mask.to(dtype, copy=True)
    attention_mask = attention_mask[:, None, None, :]
    attention_mask *= get_large_negative_number(dtype)
    return attention_mask
| |
|
| |
|
def causal_mask(input_t: torch.Tensor) -> torch.Tensor:
    """Builds an additive causal attention mask for `input_t`.

    Entry [0, 0, i, j] is 0 when j <= i (key j is not after query i) and
    `get_large_negative_number(input_t.dtype)` when j > i.

    Args:
      input_t: A floating-point torch.Tensor of shape [B, T, D].

    Returns:
      A torch.Tensor of shape [1, 1, T, T] in `input_t`'s dtype, on
      `input_t`'s device.
    """
    assert input_t.dtype.is_floating_point, input_t.dtype
    neg = get_large_negative_number(input_t.dtype)
    steps = torch.arange(input_t.shape[1])
    # is_future[i, j] is True when key position j comes after query i.
    is_future = steps.unsqueeze(1) < steps.unsqueeze(0)
    mask = is_future.to(input_t.dtype) * neg
    return mask[None, None, :, :].to(input_t.device)
| |
|
| |
|
def merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Combines two attention masks by their element-wise minimum.

    Works for log-scale (additive) masks as well as 0/1 masks: the smaller,
    more restrictive value wins. If exactly one mask has size 1 on axis 2, it
    is first expanded to a square [.., S, S] mask with entries
    `min(mask[j], mask[i])`, so it can be combined with a full mask.

    Args:
      a: torch.Tensor of shape [1|B, 1, 1|T, S].
      b: torch.Tensor of shape [1|B, 1, 1|T, S].

    Returns:
      torch.Tensor of shape [1|B, 1, 1|T, S].
    """

    def _to_square(key_mask):
        # [.., 1, S] -> [.., S, S] where entry (i, j) = min(key[i], key[j]).
        return torch.minimum(key_mask.transpose(-1, -2), key_mask)

    rows_a, rows_b = a.shape[2], b.shape[2]
    if rows_a != rows_b:
        if rows_a == 1:
            a = _to_square(a)
        else:
            assert rows_b == 1
            b = _to_square(b)

    assert a.shape[1:] == b.shape[1:], f"a.shape={a.shape}, b.shape={b.shape}."
    return torch.minimum(a, b)
| |
|
| |
|
class ResidualBlock(nn.Module):
    """TimesFM residual block.

    Computes `output_layer(silu(hidden_linear(x))) + residual_layer(x)`: a
    one-hidden-layer SiLU MLP plus a linear skip projection.
    """

    def __init__(
        self,
        input_dims,
        hidden_dims,
        output_dims,
    ):
        super().__init__()
        self.input_dims = input_dims
        self.hidden_dims = hidden_dims
        self.output_dims = output_dims

        # Submodule names and creation order fix the state_dict keys and the
        # random initialisation order; keep them stable.
        self.hidden_layer = nn.Sequential(
            nn.Linear(input_dims, hidden_dims),
            nn.SiLU(),
        )
        self.output_layer = nn.Linear(hidden_dims, output_dims)
        self.residual_layer = nn.Linear(input_dims, output_dims)

    def forward(self, x):
        projected = self.output_layer(self.hidden_layer(x))
        return projected + self.residual_layer(x)
| |
|
| |
|
class RMSNorm(torch.nn.Module):
    """Pax-style RMS normalisation over the last axis.

    `y = x / sqrt(mean(x**2) + eps) * g`, where `g = weight` or, with
    `add_unit_offset`, `g = 1 + weight`. The computation runs in float32 and
    the result is cast back to the input dtype. `weight` starts at zeros.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = False,
    ):
        super().__init__()
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        inv_rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms

    def forward(self, x):
        gain = self.weight.float()
        if self.add_unit_offset:
            gain = 1 + gain
        return (self._norm(x.float()) * gain).type_as(x)
| |
|
| |
|
class TransformerMLP(nn.Module):
    """Pax transformer MLP in pytorch.

    Computes `x + down_proj(relu(gate_proj(layer_norm(x))))`. When `paddings`
    is given, the MLP update is zeroed at padded positions, so those
    positions pass `x` through unchanged.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)
        self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)

    def forward(self, x, paddings=None):
        activated = F.relu(self.gate_proj(self.layer_norm(x)))
        update = self.down_proj(activated)
        if paddings is not None:
            # paddings: [B, T] with 1 = padded; broadcast over the feature axis.
            update = update * (1.0 - paddings[:, :, None])
        return update + x
| |
|
| |
|
| | class TimesFMAttention(nn.Module): |
| | """Implements the attention used in TimesFM.""" |
| |
|
| | def __init__( |
| | self, |
| | hidden_size: int, |
| | num_heads: int, |
| | num_kv_heads: int, |
| | head_dim: int, |
| | ): |
| | super().__init__() |
| |
|
| | self.num_heads = num_heads |
| | self.num_kv_heads = num_kv_heads |
| |
|
| | assert self.num_heads % self.num_kv_heads == 0 |
| | self.num_queries_per_kv = self.num_heads // self.num_kv_heads |
| |
|
| | self.hidden_size = hidden_size |
| | self.head_dim = head_dim |
| |
|
| | self.q_size = self.num_heads * self.head_dim |
| | self.kv_size = self.num_kv_heads * self.head_dim |
| | self.scaling = nn.Parameter( |
| | torch.empty((self.head_dim,), dtype=torch.float32),) |
| |
|
| | self.qkv_proj = nn.Linear( |
| | self.hidden_size, |
| | (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, |
| | ) |
| | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size) |
| |
|
| | def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor: |
| | |
| | r_softplus_0 = 1.442695041 |
| | softplus_func = torch.nn.Softplus() |
| | scale = r_softplus_0 / math.sqrt(self.head_dim) |
| | scale = scale * softplus_func(self.scaling) |
| | return query * scale[None, None, None, :] |
| |
|
| | def forward( |
| | self, |
| | hidden_states: torch.Tensor, |
| | mask: torch.Tensor, |
| | kv_write_indices: torch.Tensor | None = None, |
| | kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None, |
| | ) -> torch.Tensor: |
| | hidden_states_shape = hidden_states.shape |
| | assert len(hidden_states_shape) == 3 |
| |
|
| | batch_size, input_len, _ = hidden_states_shape |
| |
|
| | qkv = self.qkv_proj(hidden_states) |
| | xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) |
| |
|
| | xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) |
| | xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) |
| | xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) |
| | xq = self._per_dim_scaling(xq) |
| |
|
| | |
| | |
| | if kv_cache is not None and kv_write_indices is not None: |
| | k_cache, v_cache = kv_cache |
| | k_cache.index_copy_(1, kv_write_indices, xk) |
| | v_cache.index_copy_(1, kv_write_indices, xv) |
| |
|
| | key = k_cache |
| | value = v_cache |
| | else: |
| | key = xk |
| | value = xv |
| | if self.num_kv_heads != self.num_heads: |
| | |
| | key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) |
| | value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) |
| |
|
| | |
| | q = xq.transpose(1, 2) |
| | |
| | k = key.transpose(1, 2) |
| | v = value.transpose(1, 2) |
| |
|
| | |
| | scores = torch.matmul(q, k.transpose(2, 3)) |
| | scores = scores + mask |
| | scores = F.softmax(scores.float(), dim=-1).type_as(q) |
| |
|
| | |
| | output = torch.matmul(scores, v) |
| | |
| |
|
| | |
| | output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) |
| | output = self.o_proj(output) |
| | return scores, output |
| |
|
| |
|
class TimesFMDecoderLayer(nn.Module):
    """Transformer layer.

    Pre-norm self-attention block (RMSNorm -> attention -> residual add),
    followed by `TransformerMLP`, which applies its own LayerNorm and
    residual connection.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        rms_norm_eps: float = 1e-6,
    ):
        super().__init__()
        self.self_attn = TimesFMAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
        )
        self.mlp = TransformerMLP(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
        )
        self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        mask: torch.Tensor,
        paddings: torch.Tensor,
        kv_write_indices: torch.Tensor | None = None,
        kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Applies attention and the MLP to `hidden_states`.

        Args:
          hidden_states: [B, T, hidden_size] inputs.
          mask: additive attention mask passed to the attention layer.
          paddings: per-position padding indicators passed to the MLP, which
            zeroes its update where they are 1.
          kv_write_indices: forwarded to the attention layer's cache write.
          kv_cache: optional key/value cache forwarded to the attention layer.

        Returns:
          A tuple `(scores, hidden_states)` with the attention probabilities
          and the layer output.
        """
        # Self-attention block with a residual connection.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        scores, hidden_states = self.self_attn(
            hidden_states=hidden_states,
            mask=mask,
            kv_write_indices=kv_write_indices,
            kv_cache=kv_cache,
        )
        hidden_states = residual + hidden_states

        # The MLP adds its own residual connection.
        hidden_states = self.mlp(hidden_states, paddings=paddings)

        return scores, hidden_states
| |
|
| |
|
class StackedDecoder(nn.Module):
    """A stack of `num_layers` TimesFM decoder layers sharing one attention mask."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        num_layers: int,
        rms_norm_eps: float = 1e-6,
    ):
        super().__init__()

        self.layers = nn.ModuleList(
            TimesFMDecoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                head_dim=head_dim,
                rms_norm_eps=rms_norm_eps,
            ) for _ in range(num_layers))

    def forward(
        self,
        hidden_states: torch.Tensor,
        paddings: torch.Tensor,
        kv_write_indices: torch.Tensor | None = None,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None,
    ) -> torch.Tensor:
        """Runs every layer in order and returns the final hidden states.

        The padding mask and the causal mask are merged once and reused by
        all layers. When `kv_caches` is given, layer i uses `kv_caches[i]`.
        """
        combined_mask = merge_masks(
            convert_paddings_to_mask(paddings, hidden_states.dtype),
            causal_mask(hidden_states),
        )
        for layer_idx, layer in enumerate(self.layers):
            layer_cache = None if kv_caches is None else kv_caches[layer_idx]
            _, hidden_states = layer(
                hidden_states=hidden_states,
                mask=combined_mask,
                paddings=paddings,
                kv_write_indices=kv_write_indices,
                kv_cache=layer_cache,
            )
        return hidden_states
| |
|
| |
|
class PositionalEmbedding(torch.nn.Module):
    """Generates sinusoidal position embeddings for a 1-d sequence.

    Attributes:
      min_timescale: Scale of the geometric progression of frequencies; the
        inverse timescales run from `min_timescale` down to
        `min_timescale**2 / max_timescale`.
      max_timescale: Together with `min_timescale`, sets the ratio spanned
        by the geometric progression.
      embedding_dims: Dimension of the embedding to be generated. An odd
        value gets one trailing zero feature.
    """

    def __init__(
        self,
        embedding_dims: int,
        min_timescale: int = 1,
        max_timescale: int = 10_000,
    ) -> None:
        super().__init__()
        self.min_timescale = min_timescale
        self.max_timescale = max_timescale
        self.embedding_dims = embedding_dims

    def forward(self, seq_length=None, position=None):
        """Generates a Tensor of sinusoids with different frequencies.

        Args:
          seq_length: an optional Python int defining the output sequence
            length. Required when `position` is not specified.
          position: [B, seq_length], optional position for each token in the
            sequence, only required when the sequence is packed.

        Returns:
          [B, seqlen, D] if `position` is specified, else [1, seqlen, D].
          The first D // 2 features are sines, the next D // 2 cosines. For
          odd D, the last feature is zero.
        """
        if position is None:
            assert seq_length is not None
            # Positions 0 .. seq_length - 1 for a single (broadcastable) row.
            position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0)
        else:
            assert position.ndim == 2, position.shape

        num_timescales = self.embedding_dims // 2
        log_timescale_increment = math.log(
            float(self.max_timescale) / float(self.min_timescale)) / max(
                num_timescales - 1, 1)
        # Build the frequencies on the positions' device so a non-CPU
        # `position` tensor does not cause a device mismatch.
        inv_timescales = self.min_timescale * torch.exp(
            torch.arange(num_timescales, dtype=torch.float32,
                         device=position.device) * -log_timescale_increment)
        scaled_time = position.unsqueeze(2) * inv_timescales[None, None, :]
        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
        # Pad the feature (last) axis so the output has exactly
        # `embedding_dims` features. F.pad's tuple starts from the last axis.
        signal = F.pad(signal, (0, self.embedding_dims % 2))
        return signal
| |
|
| |
|
class PatchedTimeSeriesDecoder(nn.Module):
    """Patched time-series decoder.

    The input series is split into patches of `config.patch_len` steps,
    normalised with statistics from `_masked_mean_std`, embedded by a residual
    block and processed by a causal stacked transformer. Each output position
    is mapped to `config.horizon_len` future steps with
    `1 + len(config.quantiles)` values per step; index 0 is the point
    forecast.
    """

    def __init__(self, config: TimesFMConfig):
        super().__init__()
        self.config = config
        # Embeds a patch concatenated with its padding flags (2 * patch_len).
        self.input_ff_layer = ResidualBlock(
            input_dims=2 * config.patch_len,
            output_dims=config.hidden_size,
            hidden_dims=config.intermediate_size,
        )
        # One learned embedding per frequency id (3 ids).
        self.freq_emb = nn.Embedding(num_embeddings=3,
                                     embedding_dim=config.hidden_size)
        # Maps each position to horizon_len * (1 + #quantiles) values.
        self.horizon_ff_layer = ResidualBlock(
            input_dims=config.hidden_size,
            output_dims=config.horizon_len * (1 + len(config.quantiles)),
            hidden_dims=config.intermediate_size,
        )
        self.stacked_transformer = StackedDecoder(
            hidden_size=self.config.hidden_size,
            intermediate_size=self.config.intermediate_size,
            num_heads=self.config.num_heads,
            num_kv_heads=self.config.num_kv_heads,
            head_dim=self.config.head_dim,
            num_layers=self.config.num_layers,
            rms_norm_eps=self.config.rms_norm_eps,
        )
        if self.config.use_positional_embedding:
            self.position_emb = PositionalEmbedding(self.config.hidden_size)

    def _forward_transform(
        self, inputs: torch.Tensor, patched_pads: torch.Tensor
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Normalises [B, N, P] inputs with a per-series mean and std.

        A std below `config.tolerance` is replaced by 1. Values equal to
        `config.pad_val` (within tolerance) are kept as `pad_val`.

        Returns:
          The normalised inputs and the `(mu, sigma)` statistics, each [B].
        """
        mu, sigma = _masked_mean_std(inputs, patched_pads)
        sigma = torch.where(
            sigma < self.config.tolerance,
            torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
            sigma,
        )

        outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
        outputs = torch.where(
            torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
            torch.tensor(self.config.pad_val,
                         dtype=outputs.dtype,
                         device=outputs.device),
            outputs,
        )
        return outputs, (mu, sigma)

    def _reverse_transform(
            self, outputs: torch.Tensor, stats: tuple[torch.Tensor,
                                                      torch.Tensor]) -> torch.Tensor:
        """Undoes the normalisation on outputs of shape [B, N, P, Q]."""
        mu, sigma = stats
        return outputs * sigma[:, None, None, None] + mu[:, None, None, None]

    def _preprocess_input(
        self,
        input_ts: torch.Tensor,
        input_padding: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        tuple[torch.Tensor, torch.Tensor] | None,
        torch.Tensor,
    ]:
        """Preprocess input for stacked transformer.

        Args:
          input_ts: [B, C] series; C must be a multiple of `patch_len`.
          input_padding: [B, C] padding indicators (1 = padded).

        Returns:
          `(model_input, patched_padding, stats, patched_inputs)`:
          - model_input: [B, N, hidden_size] patch embeddings.
          - patched_padding: [B, N]; 1 only where a whole patch is padded.
          - stats: the `(mu, sigma)` normalisation statistics.
          - patched_inputs: [B, N, patch_len] normalised inputs with padded
            values zeroed.
        """
        bsize = input_ts.shape[0]
        # reshape (rather than view) also accepts inputs that view rejects.
        patched_inputs = input_ts.reshape(bsize, -1, self.config.patch_len)
        patched_pads = input_padding.reshape(bsize, -1, self.config.patch_len)

        # Zero padded values, then flag sentinel values as padding.
        patched_inputs = torch.where(
            torch.abs(patched_pads - 1.0) < self.config.tolerance,
            torch.tensor(0.0,
                         dtype=patched_inputs.dtype,
                         device=patched_inputs.device),
            patched_inputs,
        )
        patched_pads = torch.where(
            torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
            torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
            patched_pads,
        )
        patched_inputs, stats = self._forward_transform(patched_inputs,
                                                        patched_pads)

        # Padded positions contribute zeros to the patch embedding.
        patched_inputs = patched_inputs * (1.0 - patched_pads)
        concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
        model_input = self.input_ff_layer(concat_inputs)

        # A patch counts as padded only if every step in it is padded.
        patched_padding = torch.min(patched_pads, dim=-1).values
        if self.config.use_positional_embedding:
            pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device)
            pos_emb = pos_emb.expand(model_input.shape[0], -1, -1)
            # Start each series' positions at its first non-padded patch.
            pos_emb = _shift_padded_seq(patched_padding, pos_emb)
            model_input += pos_emb

        return model_input, patched_padding, stats, patched_inputs

    def _postprocess_output(
        self,
        model_output: torch.Tensor,
        num_outputs: int,
        stats: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Maps transformer outputs to de-normalised [B, N, horizon_len, num_outputs]."""
        output_ts = self.horizon_ff_layer(model_output)

        b, n, _ = output_ts.shape
        output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs)

        return self._reverse_transform(output_ts, stats)

    def forward(
        self,
        input_ts: torch.Tensor,
        input_padding: torch.LongTensor,
        freq: torch.Tensor,
    ) -> torch.Tensor:
        """Runs one forward pass.

        Args:
          input_ts: [B, C] series (C a multiple of `patch_len`).
          input_padding: [B, C] padding indicators (1 = padded).
          freq: integer frequency ids fed to `freq_emb`; `decode` documents
            them as shape B x 1.

        Returns:
          [B, N, horizon_len, 1 + len(quantiles)] forecasts in the input
          scale, one set per input patch.
        """
        num_outputs = len(self.config.quantiles) + 1
        model_input, patched_padding, stats, _ = self._preprocess_input(
            input_ts=input_ts,
            input_padding=input_padding,
        )
        f_emb = self.freq_emb(freq)
        model_input += f_emb
        model_output = self.stacked_transformer(model_input, patched_padding)

        output_ts = self._postprocess_output(model_output, num_outputs, stats)
        return output_ts

    def decode(
        self,
        input_ts: torch.Tensor,
        paddings: torch.Tensor,
        freq: torch.LongTensor,
        horizon_len: int,
        output_patch_len: int | None = None,
        max_len: int = 512,
        return_forecast_on_context: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Auto-regressive decoding without caching.

        Args:
          input_ts: input time-series. Time-series shape B x C.
          paddings: padding shape B x (C + H) where H is the prediction length.
          freq: frequency shape B x 1
          horizon_len: prediction length.
          output_patch_len: output length to be fetched from one step of
            auto-regressive decoding.
          max_len: maximum training context length.
          return_forecast_on_context: whether to return the model forecast on
            the context except the first input patch.

        Returns:
          Tuple of two forecasting results:
          - Point (mean) output predictions as a tensor with shape B x H'.
          - Full predictions (mean and quantiles) as a tensor with shape
            B x H' x (1 + # quantiles).
          In particular, if return_forecast_on_context is True, H' is H plus
          the forecastable context length, i.e. context_len - (first) patch_len.

        Raises:
          ValueError: if `paddings` is not C + H long.
        """
        final_out = input_ts
        context_len = final_out.shape[1]
        full_outputs = []
        if paddings.shape[1] != final_out.shape[1] + horizon_len:
            raise ValueError(
                "Length of paddings must match length of input + horizon_len:"
                f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}")
        if output_patch_len is None:
            output_patch_len = self.config.horizon_len
        num_decode_patches = (horizon_len + output_patch_len -
                              1) // output_patch_len
        for step_index in range(num_decode_patches):
            current_padding = paddings[:, 0:final_out.shape[1]]
            input_ts = final_out[:, -max_len:]
            input_padding = current_padding[:, -max_len:]
            fprop_outputs = self(input_ts, input_padding, freq)
            if return_forecast_on_context and step_index == 0:
                # On the first step, collect the model's forecast on the
                # context: the first patch_len steps predicted after each
                # context patch except the last, flattened to
                # [B, (N - 1) * patch_len, Q]. The first input patch has no
                # preceding forecast. Flatten the *sliced* tensor (reshape,
                # since the slice is not contiguous).
                new_full_ts = fprop_outputs[:, :-1, :self.config.patch_len, :]
                new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1,
                                                  new_full_ts.size(3))

                full_outputs.append(new_full_ts)

            # Forecast from the last patch extends the series auto-regressively.
            new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
            new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
            full_outputs.append(new_full_ts)
            final_out = torch.cat([final_out, new_ts], dim=-1)

        if return_forecast_on_context:
            # Context forecast followed by the horizon forecast.
            full_outputs = torch.cat(
                full_outputs,
                dim=1)[:, :(context_len - self.config.patch_len + horizon_len), :]
        else:
            # Horizon forecast only, trimmed to the requested length.
            full_outputs = torch.cat(full_outputs, dim=1)[:, 0:horizon_len, :]

        return (full_outputs[:, :, 0], full_outputs)
| | |
class TimesFM(nn.Module):
    """Wrapper mapping [B, L] histories to [B, lookahead] point forecasts.

    Wraps a `PatchedTimeSeriesDecoder` built from the default `TimesFMConfig`.
    Inputs are left-padded or truncated to `context_len` steps and decoded
    with frequency id 0.
    """

    def __init__(self, lookback: int = 512, lookahead: int = 96, context_len: int = 512):
        super().__init__()
        self.timesfm = PatchedTimeSeriesDecoder(TimesFMConfig())
        # NOTE(review): `lookback` is stored but not read by the code visible here.
        self.lookback = lookback
        self.lookahead = lookahead
        self.context_len = context_len

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Loads weights straight into the wrapped decoder (keys without a `timesfm.` prefix)."""
        return self.timesfm.load_state_dict(state_dict, *args, **kwargs)

    def state_dict(self, *args, **kwargs):
        """Returns the wrapped decoder's state dict (keys without a `timesfm.` prefix)."""
        return self.timesfm.state_dict(*args, **kwargs)

    def pad_tensor(self, x):
        """Fits `x` ([B, L]) to `context_len` and builds the decode inputs.

        Shorter series are left-padded with zeros (padding flag 1); longer
        ones keep their last `context_len` steps.

        Returns:
          `(padded_input [B, context_len], padding [B, context_len + lookahead],
          freq [B, 1] of zeros)`. The horizon part of `padding` is all zeros.
        """
        batch, length = x.shape
        like_x = {"device": x.device, "dtype": x.dtype}

        if length >= self.context_len:
            padded_input = x[:, -self.context_len:]
            padding = torch.zeros((batch, self.context_len), **like_x)
        else:
            padded_input = torch.zeros((batch, self.context_len), **like_x)
            padded_input[:, -length:] = x
            padding = torch.ones((batch, self.context_len), **like_x)
            padding[:, -length:] = 0

        freq = torch.zeros((batch, 1), device=x.device, dtype=torch.long)
        horizon_padding = torch.zeros((batch, self.lookahead), **like_x)
        return padded_input, torch.cat((padding, horizon_padding), dim=-1), freq

    def forward(self, x):
        """Returns the [B, lookahead] point forecast for histories `x` ([B, L])."""
        padded_inp, padding, freq = self.pad_tensor(x)
        point_forecast, _ = self.timesfm.decode(padded_inp, padding, freq, self.lookahead)
        return point_forecast