The MRL model is a Matryoshka-distilled version of EPI-250k. It maps a 30-second multi-channel EEG epoch to an L2-normalized 768d vector. The full output can be truncated to its first 384, 192, 48, or 16 dimensions without re-running the model.

Loading the model

from neuroencoder import MRL

model = MRL.from_pretrained(device="cuda")
The model auto-downloads from Hugging Face on first call (~215 MB, cached locally). Pass token=... for explicit authentication, or run huggingface-cli login once. To load from a local Lightning checkpoint instead:
model = MRL.from_checkpoint("path/to/last.ckpt")

One-call usage

embeddings = model.embed(eeg, sfreq=256, channel_names=ch_names, dim=192)
# -> numpy array, shape [N, 192], L2-normalized
model.embed chains preprocessing and prediction and returns a NumPy array, the format most downstream sklearn and FAISS code expects.

Step-by-step

import neuroencoder as ne

images = ne.preprocess(eeg, sfreq=256, channel_names=ch_names)  # [N, 8, 224, 224]
embeddings = model.predict(images, dim=192)                      # [N, 192]
model.predict() returns L2-normalized embeddings as a torch tensor on the model's device, so cosine similarity reduces to a dot product. Input is moved to the model's device automatically. Valid dimensions are 768, 384, 192, 48, and 16; see Benchmarks for accuracy at each dimension.
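Because every row has unit L2 norm, the dot product of two embeddings equals their cosine similarity. A minimal NumPy check, using random vectors as stand-ins for real embeddings (no model is involved; `l2_normalize` and `cosine` are illustrative helpers, not part of neuroencoder):

```python
import numpy as np

def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row to unit L2 norm (the same property predict() guarantees)."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Textbook cosine similarity of two 1-D vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(0)
emb = l2_normalize(rng.standard_normal((2, 192)))  # stand-ins for two 192d embeddings

# For unit-norm rows the plain dot product already is the cosine similarity.
dot = float(emb[0] @ emb[1])
print(dot, cosine(emb[0], emb[1]))
```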

Matryoshka

Compute once at full resolution, truncate later:
import torch.nn.functional as F

full = model.predict(images, dim=768)              # store this
compact = F.normalize(full[:, :48], dim=-1)        # for retrieval
tiny = F.normalize(full[:, :16], dim=-1)           # for storage-constrained settings
model.predict() normalizes at the requested dimension for you. A manually truncated prefix has norm below 1, so re-normalize it before treating dot products as cosine similarity.
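To see why re-normalization matters: slicing a unit vector keeps only part of its energy, so the prefix's norm drops below 1 and raw dot products shrink accordingly. The NumPy sketch below uses a random unit vector in place of a real 768d embedding and also prints float32 storage per vector at each valid dimension (4 bytes per value).

```python
import numpy as np

rng = np.random.default_rng(0)
full = rng.standard_normal(768)
full /= np.linalg.norm(full)              # stand-in for one 768d model output (unit norm)

prefix = full[:48]                        # manual Matryoshka truncation
print(f"norm of 48d prefix: {np.linalg.norm(prefix):.3f}")  # < 1

compact = prefix / np.linalg.norm(prefix)  # re-normalize, as F.normalize does above

# float32 storage per embedding at each valid dimension
for dim in (768, 384, 192, 48, 16):
    print(f"{dim:>3}d: {dim * 4:>4} bytes")
```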

Classification

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import balanced_accuracy_score

emb = model.predict(images, dim=192).cpu().numpy()
# y: per-epoch labels, groups: per-epoch subject IDs (NumPy arrays of length N)
scores = []
for tr, te in StratifiedGroupKFold(5).split(emb, y, groups):
    clf = LogisticRegression(max_iter=200).fit(emb[tr], y[tr])
    scores.append(balanced_accuracy_score(y[te], clf.predict(emb[te])))
print(f"5-fold BAcc: {np.mean(scores)*100:.1f}%")
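Grouping by subject matters twice over here: with the default 1-second stride, neighbouring epochs share 29 s of signal, so an epoch-level random split would place near-duplicates in both train and test. A small NumPy check that every subject stays on one side of each split; `assert_no_group_leakage` is a hypothetical helper, not part of neuroencoder or scikit-learn.

```python
import numpy as np

def assert_no_group_leakage(splits, groups):
    """Raise ValueError if any group (e.g. subject ID) is in both train and test of a split."""
    groups = np.asarray(groups)
    for fold, (tr, te) in enumerate(splits):
        shared = set(groups[tr]) & set(groups[te])
        if shared:
            raise ValueError(f"fold {fold}: subjects in train and test: {sorted(map(str, shared))}")

# Toy example: 6 epochs from 3 subjects.
groups = np.array(["s1", "s1", "s2", "s2", "s3", "s3"])
good = [(np.array([0, 1, 2, 3]), np.array([4, 5]))]
bad = [(np.array([0, 1, 2]), np.array([3, 4, 5]))]   # s2 appears on both sides

assert_no_group_leakage(good, groups)  # passes silently
```

With real data, pass `list(StratifiedGroupKFold(5).split(emb, y, groups))` as `splits`.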

Similarity / retrieval

query = model.predict(query_images, dim=48)    # [Nq, 48]
corpus = model.predict(corpus_images, dim=48)  # [Nc, 48]

sim = query @ corpus.T                         # [Nq, Nc]
top5 = sim.topk(5, dim=1).indices              # [Nq, 5]
For large corpora, the 48d embeddings work well with FAISS or HNSW indexes. Because they are L2-normalized, inner-product search ranks results by cosine similarity.
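Before adding an approximate index, exact search is often fast enough. A minimal NumPy sketch of exact top-k by dot product, processing queries in chunks so that only a [chunk, Nc] similarity matrix is held in memory at once; `chunked_topk` is a hypothetical helper and assumes normalized NumPy arrays (e.g. the output of model.embed).

```python
import numpy as np

def chunked_topk(query: np.ndarray, corpus: np.ndarray, k: int = 5, chunk: int = 1024) -> np.ndarray:
    """Exact top-k corpus indices per query row, best match first, ranked by dot product."""
    out = []
    for start in range(0, len(query), chunk):
        sim = query[start:start + chunk] @ corpus.T              # [chunk, Nc]
        idx = np.argpartition(-sim, k - 1, axis=1)[:, :k]        # top-k, unordered
        order = np.argsort(-np.take_along_axis(sim, idx, axis=1), axis=1)
        out.append(np.take_along_axis(idx, order, axis=1))       # sort best-first
    return np.concatenate(out, axis=0)

rng = np.random.default_rng(0)
corpus = rng.standard_normal((1000, 48))
corpus /= np.linalg.norm(corpus, axis=1, keepdims=True)  # stand-in 48d embeddings
query = corpus[:3]                                        # each query's best match is itself
top5 = chunked_topk(query, corpus, k=5, chunk=2)          # [3, 5]
```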

Batch processing

import torch

all_emb = []
for i in range(0, len(images), 64):
    emb = model.predict(images[i:i+64], dim=192)
    all_emb.append(emb.cpu())  # keep results on the CPU to limit GPU memory use

embeddings = torch.cat(all_emb)

Preprocessing details

ne.preprocess matches the exact pipeline used to train the model:
  1. Bandpass filter (1-100 Hz, 4th-order Butterworth)
  2. Notch filter (50 Hz and 100 Hz)
  3. Resample to 250 Hz
  4. Segment into 30-second epochs (pad if shorter)
  5. Average channels into 8 canonical brain regions: Frontal, Central, Temporal Left, Temporal Right, Parietal, Occipital, EOG, ECG
  6. Convert to the temporal matrix representation (one 224x224 image per region, giving [N, 8, 224, 224])
Channels are matched to regions by their 10-20 names (Fp1, C3, O2, …). Unmatched channels are silently dropped and missing regions are zero-filled, so any montage works: 4-channel frontal, 19-channel 10-20, 64-channel HD-EEG, intracranial.
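The region step (step 5) can be pictured as a name lookup followed by a per-region mean. This NumPy sketch is illustrative only: the `REGION_OF` table is hypothetical and covers just a few 10-20 names, since the exact channel-to-region mapping used by ne.preprocess is not listed here, and `average_into_regions` is not a library function.

```python
import numpy as np

REGIONS = ["Frontal", "Central", "Temporal Left", "Temporal Right",
           "Parietal", "Occipital", "EOG", "ECG"]

# Hypothetical, partial mapping for illustration. EOG/ECG names are omitted,
# so those rows stay zero in this sketch.
REGION_OF = {
    "Fp1": "Frontal", "Fp2": "Frontal", "F3": "Frontal", "F4": "Frontal",
    "C3": "Central", "C4": "Central", "Cz": "Central",
    "T7": "Temporal Left", "T8": "Temporal Right",
    "P3": "Parietal", "P4": "Parietal",
    "O1": "Occipital", "O2": "Occipital",
}

def average_into_regions(data: np.ndarray, channel_names: list) -> np.ndarray:
    """[C, T] -> [8, T]: mean per region; unmatched channels dropped, empty regions zero."""
    out = np.zeros((len(REGIONS), data.shape[-1]), dtype=float)
    for r, region in enumerate(REGIONS):
        rows = [i for i, name in enumerate(channel_names) if REGION_OF.get(name) == region]
        if rows:
            out[r] = data[rows].mean(axis=0)
    return out

data = np.arange(12, dtype=float).reshape(4, 3)            # 4 channels, 3 samples
regions = average_into_regions(data, ["Fp1", "Fp2", "O1", "X9"])  # "X9" is dropped
```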
images = ne.preprocess(
    data,                    # [C, T] or [N, C, T]
    sfreq=256,               # any sampling rate
    channel_names=ch_names,  # 10-20 names
    filter=True,             # set False if already filtered
    stride_seconds=None,     # None = 1.0s sliding window (default)
)
By default, 30 s epochs are taken with a 1-second stride. This yields 30x as many epochs as non-overlapping segmentation and gives smooth trajectories in the embedding space. Pass stride_seconds=30.0 for non-overlapping epochs, the typical choice when each 30 s window carries its own classification label.
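The epoch counts follow from simple arithmetic: a recording of T seconds yields floor((T - 30) / stride) + 1 windows. A NumPy sketch of the segmentation step under stated assumptions: input already at 250 Hz, zero padding when the recording is shorter than 30 s, and any trailing partial window dropped. ne.preprocess may pad or handle the tail differently; `segment_epochs` is a hypothetical helper.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def segment_epochs(x: np.ndarray, sfreq: float = 250.0,
                   epoch_seconds: float = 30.0, stride_seconds: float = 1.0) -> np.ndarray:
    """[C, T] -> [N, C, W]: windows of epoch_seconds, hopping by stride_seconds."""
    win = int(round(epoch_seconds * sfreq))
    hop = int(round(stride_seconds * sfreq))
    if x.shape[-1] < win:                                      # assumed: zero-pad short input
        x = np.pad(x, ((0, 0), (0, win - x.shape[-1])))
    windows = sliding_window_view(x, win, axis=-1)[:, ::hop]  # [C, N, W], a read-only view
    return np.moveaxis(windows, 1, 0)                          # [N, C, W]

x = np.zeros((2, 60 * 250))                          # 60 s, 2 channels at 250 Hz
dense = segment_epochs(x, stride_seconds=1.0)        # floor(30 / 1) + 1 windows
sparse = segment_epochs(x, stride_seconds=30.0)      # floor(30 / 30) + 1 windows
short = segment_epochs(np.ones((2, 10 * 250)))       # 10 s recording, padded to one epoch
```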