| import argparse |
| import os |
| import torch |
| from exp.exp_main import Exp_Main |
| import random |
| import json |
| import numpy as np |
| from torch.utils.tensorboard import SummaryWriter |
| import traceback |
| import pathlib |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from torch.fft import rfft, irfft |
|
|
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series with boundary adjustment.

    The series is padded by repeating its first and last time step before
    average-pooling, so that with ``stride=1`` the output has the same time
    length as the input for any kernel size, odd or even.

    Args:
        kernel_size: Width of the averaging window, in time steps.
        stride: Stride of the averaging window.
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        """
        Args:
            x: Rank-3 tensor with time on dim 1, e.g. (batch, time, channels).

        Returns:
            The moving average in the same layout. With ``stride == 1`` its time
            length equals that of ``x``.
        """
        # Distribute the kernel_size - 1 padding steps over both ends. Odd
        # kernels get (k-1)//2 on each side, as before. For even kernels the
        # extra step goes to the end; otherwise the output would be one step
        # shorter than the input.
        pad_front = (self.kernel_size - 1) // 2
        pad_end = self.kernel_size - 1 - pad_front
        front = x[:, 0:1, :].repeat(1, pad_front, 1)
        end = x[:, -1:, :].repeat(1, pad_end, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last axis, so move time there and back.
        x = self.avg(x.permute(0, 2, 1))
        return x.permute(0, 2, 1)
|
|
|
|
class series_decomp(nn.Module):
    """
    Enhanced series decomposition block with adaptive frequency selection.

    The input is split into a trend (moving average along the time axis) and a
    seasonal residual. The residual is then blended with a frequency-filtered
    copy of itself, in which dominant rFFT bins are kept and the remaining bins
    are attenuated.

    Args:
        kernel_size: Moving-average window used to extract the trend.
        freq_range: Half-width, in rFFT bins, of the triangular boost placed
            around each dominant bin.
        filter_strength: Attenuation strength. Bins that receive no boost are
            scaled by ``1 - filter_strength``.
        top_k: Number of dominant bins (by batch/channel-averaged power) to boost.
        enhance_ratio: Weight of the filtered residual in the returned seasonal
            part: ``seasonal * (1 - enhance_ratio) + filtered * enhance_ratio``.
            The default 0.2 reproduces the previously hard-coded blend.
    """

    def __init__(self, kernel_size, freq_range=5, filter_strength=0.5, top_k=3, enhance_ratio=0.2):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)
        self.freq_range = freq_range
        self.filter_strength = filter_strength
        self.top_k = top_k
        self.enhance_ratio = enhance_ratio

    def _enhance_seasonal(self, seasonal):
        """
        Return ``seasonal`` with its spectrum re-weighted by a frequency mask.

        ``seasonal`` is rank 3 with time on dim 1, e.g. (batch, time, channels).
        The rFFT is taken along time and the result has the same shape.

        If there are more rFFT bins than ``top_k``, the ``top_k`` highest-power
        bins get a triangular boost of half-width ``freq_range``. Boosts from
        overlapping neighbourhoods add up, so the mask can exceed 1.
        Otherwise every bin is boosted in proportion to (power share) ** 0.3,
        and an input with zero total power is returned unchanged.
        """
        # (B, L, C) -> (B, C, L) -> rFFT over time -> (B, C, L // 2 + 1)
        seasonal_fft = rfft(seasonal.permute(0, 2, 1), dim=2)
        power = torch.abs(seasonal_fft) ** 2

        # Power per bin averaged over batch and channels.
        # NOTE(review): averaging over the batch means each sample's filter
        # depends on the other samples in the same batch.
        avg_power = torch.mean(power, dim=(0, 1))
        n_freq = avg_power.shape[0]

        if n_freq > self.top_k:
            _, top_indices = torch.topk(avg_power, self.top_k)
            bins = torch.arange(n_freq, device=avg_power.device)
            # Distance of every bin to every dominant bin: (top_k, n_freq).
            distance = (bins.unsqueeze(0) - top_indices.unsqueeze(1)).abs()
            # Triangular weight: 1 at the dominant bin, then decreasing by
            # 1 / (freq_range + 1) per bin, and zero beyond freq_range.
            weight = 1.0 - distance.to(power.dtype) / (self.freq_range + 1)
            weight = torch.where(distance <= self.freq_range, weight, torch.zeros_like(weight))
            # Summing over dominant bins lets overlapping neighbourhoods
            # accumulate, matching the original per-bin accumulation.
            mask = (1 - self.filter_strength) + self.filter_strength * weight.sum(dim=0)
        else:
            total_power = torch.sum(avg_power)
            if not total_power > 0:
                return seasonal
            # Compressed power share; the 0.3 exponent flattens the distribution.
            freq_weights = (avg_power / total_power) ** 0.3
            mask = (1 - self.filter_strength) + self.filter_strength * freq_weights

        # The real (n_freq,) mask broadcasts over batch and channels.
        filtered_fft = seasonal_fft * mask
        enhanced_seasonal = irfft(filtered_fft, dim=2, n=seasonal.size(1))
        return enhanced_seasonal.permute(0, 2, 1)

    def forward(self, x):
        """
        Decompose ``x`` (time on dim 1) into (seasonal, trend).

        The trend is the moving average. The seasonal part is the residual
        ``x - trend`` blended with its frequency-enhanced version according to
        ``enhance_ratio``.
        """
        moving_mean = self.moving_avg(x)
        seasonal = x - moving_mean
        enhanced_seasonal = self._enhance_seasonal(seasonal)
        final_seasonal = seasonal * (1 - self.enhance_ratio) + enhanced_seasonal * self.enhance_ratio
        return final_seasonal, moving_mean
|
|
| |
|
|
|
|
class SimpleTrendAttention(nn.Module):
    """
    Simple learnable per-time-step gating for the trend component.

    A length-``seq_len`` parameter vector is softmax-normalised and multiplied
    element-wise into the input along its second-to-last (time) axis.

    Args:
        seq_len: Number of time steps in the input.
        preserve_scale: If True (default), multiply the softmax weights by
            ``seq_len`` so they average 1. The layer is then the identity at
            initialisation. Without this, the softmax weights sum to 1 and
            shrink the trend by a factor of ``seq_len``. Pass False to reproduce
            that legacy, unscaled behaviour.
    """

    def __init__(self, seq_len, preserve_scale=True):
        super(SimpleTrendAttention, self).__init__()
        # Softmax is shift-invariant, so any constant initialisation such as
        # this one yields uniform weights of 1 / seq_len.
        self.attention = nn.Parameter(torch.ones(seq_len) / seq_len)
        self.preserve_scale = preserve_scale

    def forward(self, x):
        """
        Gate ``x`` per time step.

        ``x`` is expected as (batch, seq_len, channels); the weights are
        broadcast as shape (1, seq_len, 1).
        """
        weights = F.softmax(self.attention, dim=0)
        if self.preserve_scale:
            # Rescale so the mean weight is 1 rather than 1 / seq_len.
            weights = weights * weights.numel()
        return x * weights.view(1, -1, 1)
|
|
|
|
class AdaptiveHybridDFTNet(nn.Module):
    """
    Refined AdaptiveHybridDFTNet with balanced components.

    DLinear-style forecaster. The input is decomposed into a seasonal part and
    a trend part. Each part is mapped from ``seq_len`` to ``pred_len`` time
    steps by a linear layer: one shared layer, or one per channel when
    ``individual`` is set. The two forecasts are blended with learnable
    weights normalised to sum to 1.

    Args:
        configs: Namespace-like object.
            Required attributes: ``seq_len``, ``pred_len``, ``enc_in`` (number
            of channels) and ``individual``.
            Optional attributes, with defaults in parentheses: ``moving_avg``
            (derived from ``seq_len``), ``freq_range`` (5),
            ``filter_strength`` (0.2), ``top_k`` (3).
    """

    # Lower bound on the blend normaliser so it never divides by zero.
    _WEIGHT_EPS = 1e-8

    def __init__(self, configs):
        super(AdaptiveHybridDFTNet, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.channels = configs.enc_in
        self.individual = configs.individual

        # Fallback trend window: about seq_len / 8, clamped to [5, 25] and
        # forced odd so the symmetric moving-average padding keeps the length.
        default_kernel = min(25, max(5, self.seq_len // 8))
        if default_kernel % 2 == 0:
            default_kernel += 1
        kernel_size = getattr(configs, 'moving_avg', default_kernel)

        freq_range = getattr(configs, 'freq_range', 5)
        filter_strength = getattr(configs, 'filter_strength', 0.2)
        top_k = getattr(configs, 'top_k', 3)

        self.decomposition = series_decomp(kernel_size, freq_range, filter_strength, top_k)
        self.trend_attention = SimpleTrendAttention(self.seq_len)

        if self.individual:
            self.Linear_Seasonal = nn.ModuleList()
            self.Linear_Trend = nn.ModuleList()
            # Interleaved creation keeps the same parameter-initialisation
            # order (and thus RNG consumption) as before.
            for _ in range(self.channels):
                self.Linear_Seasonal.append(nn.Linear(self.seq_len, self.pred_len))
                self.Linear_Trend.append(nn.Linear(self.seq_len, self.pred_len))
        else:
            self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
            self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)

        # Learnable blend weights. Their absolute values are normalised in forward().
        self.seasonal_weight = nn.Parameter(torch.tensor(0.5))
        self.trend_weight = nn.Parameter(torch.tensor(0.5))

    def forward(self, x):
        """
        Args:
            x: Input of shape (batch, seq_len, channels).

        Returns:
            Forecast of shape (batch, pred_len, channels).
        """
        seasonal, trend = self.decomposition(x)
        trend = self.trend_attention(trend)

        # (B, L, C) -> (B, C, L) so the linear layers act along time.
        seasonal = seasonal.permute(0, 2, 1)
        trend = trend.permute(0, 2, 1)

        if self.individual:
            # Per-channel projections stacked on the channel axis, giving
            # (B, C, pred_len). This matches the shared-layer branch; the
            # previous (B, pred_len, C) buffer came out transposed after the
            # permute below.
            seasonal_output = torch.stack(
                [layer(seasonal[:, i, :]) for i, layer in enumerate(self.Linear_Seasonal)], dim=1)
            trend_output = torch.stack(
                [layer(trend[:, i, :]) for i, layer in enumerate(self.Linear_Trend)], dim=1)
        else:
            seasonal_output = self.Linear_Seasonal(seasonal)
            trend_output = self.Linear_Trend(trend)

        # (B, C, pred_len) -> (B, pred_len, C)
        seasonal_output = seasonal_output.permute(0, 2, 1)
        trend_output = trend_output.permute(0, 2, 1)

        seasonal_abs = torch.abs(self.seasonal_weight)
        trend_abs = torch.abs(self.trend_weight)
        total_weight = (seasonal_abs + trend_abs).clamp_min(self._WEIGHT_EPS)
        return seasonal_output * (seasonal_abs / total_weight) + trend_output * (trend_abs / total_weight)
|
|
|
|
| |
class Model(AdaptiveHybridDFTNet):
    """
    Backward-compatible alias: ``Model(configs)`` builds an AdaptiveHybridDFTNet.
    """

    def __init__(self, configs):
        # All construction is delegated to the parent class.
        super().__init__(configs)
|
|
|
|
def str2bool(value):
    """
    Parse a boolean command-line value for argparse.

    ``type=bool`` treats every non-empty string, including "False", as True.
    This parser accepts the usual true/false spellings instead.

    Args:
        value: The raw argument string, or a bool (returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    fix_seed = 2021
    random.seed(fix_seed)
    torch.manual_seed(fix_seed)
    np.random.seed(fix_seed)

    parser = argparse.ArgumentParser(description='Autoformer & Transformer family for Time Series Forecasting')
    parser.add_argument("--out_dir", type=str, default="run_0")

    # basic config
    parser.add_argument('--is_training', type=int, required=True, default=1, help='status')
    parser.add_argument('--train_only', type=str2bool, required=False, default=False, help='perform training on full input dataset without validation and testing')

    # data loader
    parser.add_argument('--data', type=str, required=True, default='ETTm1', help='dataset type')
    parser.add_argument('--root_path', type=str, default='./data/ETT/', help='root path of the data file')
    parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
    parser.add_argument('--features', type=str, default='M',
                        help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
    parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
    parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')

    # forecasting task
    parser.add_argument('--seq_len', type=int, default=96, help='input sequence length')
    parser.add_argument('--label_len', type=int, default=48, help='start token length')
    parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')

    # model
    parser.add_argument('--individual', action='store_true', default=False, help='DLinear: a linear layer for each variate(channel) individually')
    parser.add_argument('--embed_type', type=int, default=0, help='0: default 1: value embedding + temporal embedding + positional embedding 2: value embedding + temporal embedding 3: value embedding + positional embedding 4: value embedding')
    parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
    parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
    parser.add_argument('--c_out', type=int, default=7, help='output size')
    parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
    parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
    parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
    parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
    parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')
    parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average for trend extraction')
    parser.add_argument('--freq_range', type=int, default=5, help='frequency range for adaptive DFT selection')
    parser.add_argument('--filter_strength', type=float, default=0.2, help='strength of frequency filtering (0-1)')
    parser.add_argument('--top_k', type=int, default=3, help='number of top frequencies to enhance')
    parser.add_argument('--factor', type=int, default=1, help='attn factor')
    parser.add_argument('--distil', action='store_false',
                        help='whether to use distilling in encoder, using this argument means not using distilling',
                        default=True)
    parser.add_argument('--dropout', type=float, default=0.05, help='dropout')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--activation', type=str, default='gelu', help='activation')
    parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
    parser.add_argument('--do_predict', action='store_true', help='whether to predict unseen future data')

    # optimization
    parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
    parser.add_argument('--itr', type=int, default=2, help='experiments times')
    parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data')
    parser.add_argument('--patience', type=int, default=3, help='early stopping patience')
    parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
    parser.add_argument('--des', type=str, default='test', help='exp description')
    parser.add_argument('--loss', type=str, default='mse', help='loss function')
    parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')
    parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)

    # GPU
    parser.add_argument('--use_gpu', type=str2bool, default=True, help='use gpu')
    parser.add_argument('--gpu', type=int, default=0, help='gpu')
    parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
    parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus')
    parser.add_argument('--test_flop', action='store_true', default=False, help='See utils/tools for usage')

    args = parser.parse_args()
    writer = None
    try:
        log_dir = os.path.join(args.out_dir, 'logs')
        pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)
        writer = SummaryWriter(log_dir)
        args.use_gpu = bool(torch.cuda.is_available() and args.use_gpu)

        if args.use_gpu and args.use_multi_gpu:
            # Store the space-stripped list back on args.devices, so e.g.
            # "0, 1" is normalised to "0,1" before being parsed into ids.
            args.devices = args.devices.replace(' ', '')
            args.device_ids = [int(id_) for id_ in args.devices.split(',')]
            args.gpu = args.device_ids[0]

        print('Args in experiment:')
        print(args)
        mse, mae = [], []
        pred_lens = [96, 192, 336, 720] if args.data_path != 'illness.csv' else [24, 36, 48, 60]
        for pred_len in pred_lens:
            args.pred_len = pred_len
            model = Model(args)
            setting = '{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}'.format(
                args.data,
                args.features,
                args.seq_len,
                args.label_len,
                pred_len,
                args.d_model,
                args.n_heads,
                args.e_layers,
                args.d_layers,
                args.d_ff,
                args.factor,
                args.embed,
                args.distil,
                args.des)

            exp = Exp_Main(args, model)
            print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
            exp.train(setting, writer)
            print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
            # NOTE(review): assumes Exp_Main.test returns (mae, mse) in that
            # order - confirm against exp/exp_main.py.
            single_mae, single_mse = exp.test(setting)
            print('mse:{}, mae:{}'.format(single_mse, single_mae))
            mae.append(single_mae)
            mse.append(single_mse)
            torch.cuda.empty_cache()

        mean_mae = sum(mae) / len(mae)
        mean_mse = sum(mse) / len(mse)
        final_infos = {
            args.data: {
                "means": {
                    "mae": mean_mae,
                    "mse": mean_mse,
                }
            }
        }
        pathlib.Path(args.out_dir).mkdir(parents=True, exist_ok=True)
        with open(os.path.join(args.out_dir, "final_info.json"), "w") as f:
            json.dump(final_infos, f)

    except Exception:
        print("Original error in subprocess:", flush=True)
        # Make sure the directory exists and close the log file
        # deterministically before re-raising the original error.
        pathlib.Path(args.out_dir).mkdir(parents=True, exist_ok=True)
        with open(os.path.join(args.out_dir, "traceback.log"), "w") as tb_file:
            traceback.print_exc(file=tb_file)
        raise
    finally:
        # Flush pending TensorBoard events whether the run succeeded or not.
        if writer is not None:
            writer.close()
|
|