---
license: mit
---

# Usage

**Instantiate the Base Model**

```python
import torch
from braindecode.models import SignalJEPA
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(
    repo_id="braindecode/SignalJEPA", filename="signal-jepa_16s-60_adeuwv4s.pth"
)
model_state_dict = torch.load(weights_path)

# Signal-related arguments
# raw: mne.io.BaseRaw
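# (Assumption, not part of the original example:) `raw` can be any MNE Raw
# object carrying channel location information, e.g.:
#   import mne
#   raw = mne.io.read_raw_edf("my_recording.edf")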
chs_info = raw.info["chs"]
sfreq = raw.info["sfreq"]

model = SignalJEPA(
    sfreq=sfreq,
    input_window_seconds=2,
    chs_info=chs_info,
)
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
assert unexpected_keys == []
# The spatial positional encoder is initialized using `chs_info`,
# so its weights are expected to be missing from the checkpoint:
assert set(missing_keys) == {"pos_encoder.pos_encoder_spat.weight"}
```
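With the weights loaded, the base model can serve as a feature extractor. Below is a minimal sketch, assuming a batch of 2-second EEG windows shaped `(batch, n_channels, n_times)`; the tensor `x` and its batch size are made up for illustration, and the exact output shape depends on the model configuration:

```python
import torch

# Hypothetical batch of eight 2-second windows: (batch, n_channels, n_times)
x = torch.randn(8, len(chs_info), int(2 * sfreq))

model.eval()
with torch.no_grad():
    embeddings = model(x)  # pre-trained feature embeddings
print(embeddings.shape)
```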

**Instantiate the Downstream Architectures**

Unlike the base model, the downstream architectures are equipped with a classification head, which is not pre-trained.
Guetschel et al. (2024) [arXiv:2403.11772](https://arxiv.org/abs/2403.11772) introduce three downstream architectures:
- a) Contextual downstream architecture
- b) Post-local downstream architecture
- c) Pre-local downstream architecture

```python
import torch
from braindecode.models import (
    SignalJEPA_Contextual,
    SignalJEPA_PostLocal,
    SignalJEPA_PreLocal,
)
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(
    repo_id="braindecode/SignalJEPA", filename="signal-jepa_16s-60_adeuwv4s.pth"
)
model_state_dict = torch.load(weights_path)

# Signal-related arguments
# raw: mne.io.BaseRaw
chs_info = raw.info["chs"]
sfreq = raw.info["sfreq"]

# The downstream architectures are equipped with an additional classification head
# which was not pre-trained. It has the following new parameters:
final_layer_keys = {
    "final_layer.spat_conv.weight",
    "final_layer.spat_conv.bias",
    "final_layer.linear.weight",
    "final_layer.linear.bias",
}


# a) Contextual downstream architecture
# --------------------------------------
model = SignalJEPA_Contextual(
    sfreq=sfreq,
    input_window_seconds=2,
    chs_info=chs_info,
    n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
assert unexpected_keys == []
# The spatial positional encoder is initialized using `chs_info`:
assert set(missing_keys) == final_layer_keys | {"pos_encoder.pos_encoder_spat.weight"}

# In the post-local (b) and pre-local (c) architectures, the transformer and
# positional encoder are discarded:
filtered_model_state_dict = {
    k: v
    for k, v in model_state_dict.items()
    if not any(k.startswith(prefix) for prefix in ("transformer.", "pos_encoder."))
}


# b) Post-local downstream architecture
# --------------------------------------
model = SignalJEPA_PostLocal(
    sfreq=sfreq,
    input_window_seconds=2,
    n_chans=len(chs_info),  # detailed channel info is not needed for this model
    n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(filtered_model_state_dict, strict=False)
assert unexpected_keys == []
assert set(missing_keys) == final_layer_keys


# c) Pre-local downstream architecture
# --------------------------------------
model = SignalJEPA_PreLocal(
    sfreq=sfreq,
    input_window_seconds=2,
    n_chans=len(chs_info),  # detailed channel info is not needed for this model
    n_outputs=1,
)
missing_keys, unexpected_keys = model.load_state_dict(filtered_model_state_dict, strict=False)
assert unexpected_keys == []
assert set(missing_keys) == {
    "spatial_conv.1.weight",
    "spatial_conv.1.bias",
    "final_layer.1.weight",
    "final_layer.1.bias",
}
```
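
Once loaded, only the randomly initialized head parameters (those reported in `missing_keys`) need training from scratch. Below is a minimal fine-tuning sketch, assuming a standard PyTorch setup; freezing the pre-trained backbone is one possible choice, not a procedure prescribed by the checkpoint:

```python
import torch

# Freeze everything restored from the checkpoint; train only the new head.
for name, param in model.named_parameters():
    param.requires_grad = name in missing_keys

optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)
criterion = torch.nn.BCEWithLogitsLoss()  # n_outputs=1 -> binary classification
```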