import neuroencoder as ne
from neuroencoder import MRL

MRL

model = MRL.from_pretrained(device="cuda")
embeddings = model.predict(images, dim=192)  # images: [N, 8, 224, 224] tensor from ne.preprocess
The MRL model class. MRL.from_pretrained auto-downloads the checkpoint from Neuroencoder/epi-embedding on HuggingFace and caches it locally; pass token=... for explicit authentication.

MRL.from_pretrained

MRL.from_pretrained(repo_id="Neuroencoder/epi-embedding", filename="mrl.pt", device=None, **kwargs)
repo_id
str
default: "Neuroencoder/epi-embedding"
HuggingFace repo identifier.
filename
str
default: "mrl.pt"
Checkpoint filename in the repo.
device
str
Torch device. If None (the default), uses CUDA when available.
**kwargs
Forwarded to huggingface_hub.hf_hub_download (e.g. token="hf_...").

MRL.from_checkpoint

MRL.from_checkpoint(path, device=None)
Load from a local file. Handles raw state dicts and PyTorch Lightning checkpoint formats.
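A rough sketch of what handling both formats can involve, using plain dicts in place of tensors. It assumes the Lightning checkpoint keeps its weights under a "state_dict" key (the standard Lightning layout) and that the wrapped network's keys carry a hypothetical "model." prefix; the library's actual key handling is not documented here, and extract_state_dict is an illustrative helper, not part of neuroencoder.

```python
from typing import Any, Dict


def extract_state_dict(checkpoint: Dict[str, Any], prefix: str = "model.") -> Dict[str, Any]:
    """Return a flat state dict from either a raw state dict or a Lightning checkpoint.

    Hypothetical helper: `prefix` is an assumed attribute name under which the
    LightningModule stored the network, not something the library documents.
    """
    # Lightning checkpoints wrap the weights in a "state_dict" entry next to
    # training metadata such as "epoch" and "global_step"; a raw dict is used as-is.
    state = checkpoint.get("state_dict", checkpoint)
    # Strip the assumed module prefix so keys match the bare model's parameter names.
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }


raw = {"encoder.weight": 1.0, "projector.bias": 2.0}
lightning = {
    "epoch": 3,
    "global_step": 900,
    "state_dict": {"model.encoder.weight": 1.0, "model.projector.bias": 2.0},
}

print(extract_state_dict(raw) == extract_state_dict(lightning))  # True
```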

model.embed

model.embed(eeg, sfreq, channel_names=None, dim=192, filter=True, stride_seconds=None)
One-call convenience: raw EEG -> L2-normalized embeddings as a numpy array. Equivalent to model.predict(ne.preprocess(eeg, ...), dim=dim).cpu().numpy().
eeg
np.ndarray | torch.Tensor
required
Raw EEG [C, T] continuous, or [N, C, T] pre-epoched.
sfreq
float
required
Sampling frequency in Hz.
channel_names
list[str]
10-20 electrode names (required if C != 8).
dim
int
default: 192
Embedding dimension; one of 768, 384, 192, 48, 16.

model.predict

model.predict(x, dim=192)
Returns [N, dim] L2-normalized embeddings on the model’s device. Runs in torch.no_grad(). Input is auto-moved to the model device.
x
torch.Tensor
required
[N, 8, 224, 224] tensor from ne.preprocess.
dim
int
default: 192
Embedding dimension; one of 768, 384, 192, 48, 16.
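Because predict returns unit-length rows, comparing two embeddings reduces to a dot product, which equals their cosine similarity. Output from model.forward is not normalized, so it needs an explicit L2 normalization before the same comparison. A pure-Python sketch with toy vectors standing in for embedding rows; l2_normalize, dot, and cosine are illustrative helpers, not library functions:

```python
import math
from typing import List, Sequence


def l2_normalize(vec: Sequence[float]) -> List[float]:
    """Scale a vector to unit Euclidean length (rows returned by predict already are)."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two arbitrary (unnormalized) vectors."""
    return dot(a, b) / math.sqrt(dot(a, a) * dot(b, b))


# Toy 4-d vectors standing in for raw, unnormalized outputs (e.g. from model.forward).
a = [3.0, 4.0, 0.0, 0.0]
b = [1.0, 2.0, 2.0, 0.0]

a_hat, b_hat = l2_normalize(a), l2_normalize(b)
# After normalization, a plain dot product gives the cosine similarity.
print(math.isclose(dot(a_hat, b_hat), cosine(a, b)))  # True
```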

model.forward

model(x, dim=192)
Forward pass with gradients (for fine-tuning). Output is not L2-normalized.

Attributes

model.encoder    # underlying encoder
model.projector  # MRL projector head

ne.preprocess

ne.preprocess(data, sfreq, channel_names=None, filter=True, stride_seconds=None, device=None)
Converts raw EEG into temporal-matrix images. Handles any channel count: channels are averaged into 8 brain regions, and regions with no recorded channels are zero-filled. Returns a [N, 8, 224, 224] tensor.
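To make the channel handling concrete, here is a pure-Python sketch of averaging a [C, T] recording into 8 regions with zero-filling. The REGIONS table and the average_into_regions helper are hypothetical: the reference does not say which 10-20 electrodes map to which region, so treat the grouping (including leaving out midline electrodes such as Cz) as an illustration only.

```python
from typing import Dict, List, Sequence

# Hypothetical region map for illustration only; the library's actual assignment
# of 10-20 electrodes to its 8 regions is not documented. Midline sites (Fz, Cz, Pz)
# are simply not listed here, so they are ignored by this sketch.
REGIONS: Dict[str, List[str]] = {
    "left_frontal": ["FP1", "F3", "F7"],
    "right_frontal": ["FP2", "F4", "F8"],
    "left_central": ["C3"],
    "right_central": ["C4"],
    "left_temporal": ["T3", "T5", "T7", "P7"],
    "right_temporal": ["T4", "T6", "T8", "P8"],
    "left_posterior": ["P3", "O1"],
    "right_posterior": ["P4", "O2"],
}


def average_into_regions(data: Sequence[Sequence[float]],
                         channel_names: Sequence[str]) -> List[List[float]]:
    """Average [C, T] channels into 8 region rows; regions with no channels are zero-filled."""
    n_samples = len(data[0])
    # Case-insensitive lookup from electrode name to row index.
    index = {name.upper(): i for i, name in enumerate(channel_names)}
    out = []
    for members in REGIONS.values():
        rows = [data[index[m]] for m in members if m in index]
        if rows:
            # Sample-wise mean over the electrodes present in this region.
            out.append([sum(col) / len(rows) for col in zip(*rows)])
        else:
            out.append([0.0] * n_samples)  # zero-fill a region with no electrodes
    return out


eeg = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 channels x 2 samples
regions = average_into_regions(eeg, ["Fp1", "F3", "O2"])
print(len(regions))  # 8
print(regions[0])    # [2.0, 3.0]  (mean of Fp1 and F3)
print(regions[2])    # [0.0, 0.0]  (no left-central electrode present)
```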
data
np.ndarray | torch.Tensor
required
[C, T] continuous, or [N, C, T] pre-epoched.
sfreq
float
required
Sampling frequency in Hz.
channel_names
list[str]
10-20 electrode names.
filter
bool
default: True
Apply a 1-100 Hz bandpass filter and notch filters at 50 and 100 Hz.
stride_seconds
float
default: None (1.0 s hop)
Hop between successive 30 s epochs. When left as None, the hop is 1.0 s (a continuous sliding window). Pass 30.0 for non-overlapping epochs, which is typical for classification with epoch-level labels.
device
str
Torch device.
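How many epochs a continuous recording yields depends on stride_seconds. A minimal sketch with a hypothetical num_epochs helper, assuming windows start at t = 0 and a trailing window shorter than 30 s is dropped (neither detail is stated in this reference):

```python
def num_epochs(n_samples: int, sfreq: float, stride_seconds: float = 1.0,
               epoch_seconds: float = 30.0) -> int:
    """Count full 30 s windows in a [C, T] recording (hypothetical helper).

    Assumes the first window starts at sample 0 and incomplete trailing
    windows are discarded.
    """
    window = int(round(epoch_seconds * sfreq))  # window length in samples
    hop = int(round(stride_seconds * sfreq))    # hop length in samples
    if n_samples < window:
        return 0
    return (n_samples - window) // hop + 1


sfreq = 256.0
n_samples = int(10 * 60 * sfreq)  # a 10-minute recording

print(num_epochs(n_samples, sfreq, stride_seconds=1.0))   # 571
print(num_epochs(n_samples, sfreq, stride_seconds=30.0))  # 20
```

Under those assumptions, a 10-minute recording gives 571 overlapping windows at a 1.0 s hop versus 20 non-overlapping ones at 30.0, which is a quick sanity check on the N returned by ne.preprocess.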

ne.explore

ne.explore(embeddings, filename=None, method="umap", show_charts=False, show_table=False)
Opens an interactive Apple Embedding Atlas widget in Jupyter. By default, points are colored by time and the side panels are hidden.
embeddings
np.ndarray | torch.Tensor
required
[N, D].
filename
str | list[str]
Source-file label for each point: a single string applies to all points, or pass a list with one entry per point.
method
str
default: "umap"
Dimensionality-reduction method: "umap", "tsne", or "pca".
n_neighbors
int
default: 15
Number of neighbors for UMAP / KNN.
point_size
float
default: 3.0
Scatter point size.
epoch_seconds
float
default: 30.0
Seconds per embedding (used to compute the time axis).
show_charts
bool
default: False
Show right-side charts panel.
show_table
bool
default: False
Show bottom data-table panel.
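ne.explore accepts either one filename string or one entry per point, and uses epoch_seconds to place points on a time axis. The sketch below, built around a hypothetical point_metadata helper, shows one plausible way to expand those arguments into per-point labels and times. It assumes time restarts at zero for each source file and advances by epoch_seconds per point; the reference does not specify either detail.

```python
from typing import List, Optional, Sequence, Tuple, Union


def point_metadata(n_points: int,
                   filename: Union[str, Sequence[str], None] = None,
                   epoch_seconds: float = 30.0) -> Tuple[List[Optional[str]], List[float]]:
    """Hypothetical helper: per-point source labels and time offsets in seconds."""
    if filename is None or isinstance(filename, str):
        labels = [filename] * n_points  # one string (or None) applies to every point
    else:
        labels = list(filename)
        if len(labels) != n_points:
            raise ValueError(f"expected {n_points} filenames, got {len(labels)}")

    # Assumed time axis: a point's index within its own source file times epoch_seconds.
    counts = {}
    times = []
    for label in labels:
        k = counts.get(label, 0)
        times.append(k * epoch_seconds)
        counts[label] = k + 1
    return labels, times


labels, times = point_metadata(4, ["a.edf", "a.edf", "b.edf", "b.edf"])
print(labels)  # ['a.edf', 'a.edf', 'b.edf', 'b.edf']
print(times)   # [0.0, 30.0, 0.0, 30.0]
```

If the time axis is computed this way, embeddings produced with a 1.0 s stride would be spaced by epoch_seconds rather than by the hop, so passing epoch_seconds equal to the stride may give a more faithful axis; this is a reading of the parameter description, not documented behavior.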

ne.serve

ne.serve(embeddings, filename=None, method="umap", host="127.0.0.1", port=5055, open_browser=True)
Standalone Atlas web server (no Jupyter). Same defaults as ne.explore. Blocks until Ctrl-C.
host
str
default: "127.0.0.1"
Bind address.
port
int
default: 5055
Server port.
open_browser
bool
default: True
Open the URL in the default browser.
(Other args identical to ne.explore.)

ne.plot

ne.plot(embeddings, color=None, **kwargs)
Static matplotlib scatter plot of a UMAP projection. Points are colored by time (viridis colormap) by default.
embeddings
np.ndarray | torch.Tensor
required
[N, D].
color
str | array
"time" (default), "cluster", or an array with one value per point.

Constants

from neuroencoder import MRL_DIMS
# [768, 384, 192, 48, 16]
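MRL_DIMS lists the nested embedding sizes. In Matryoshka Representation Learning, smaller embeddings are usually leading slices of the largest one. Assuming this model follows that scheme (the reference does not say how the projector produces each dim), a 768-d vector could be reduced to every size as below; matryoshka_views is a hypothetical helper, and the local MRL_DIMS list simply mirrors the constant above.

```python
import math
from typing import Dict, List, Sequence

MRL_DIMS = [768, 384, 192, 48, 16]  # mirrors neuroencoder.MRL_DIMS


def matryoshka_views(full: Sequence[float],
                     dims: Sequence[int] = MRL_DIMS) -> Dict[int, List[float]]:
    """Assumed Matryoshka scheme: each smaller embedding is the L2-renormalized prefix of the full one."""
    if len(full) != max(dims):
        raise ValueError(f"expected a {max(dims)}-d vector, got {len(full)}")
    views = {}
    for d in dims:
        prefix = list(full[:d])
        # Renormalize the prefix; fall back to 1.0 to avoid dividing by zero.
        norm = math.sqrt(sum(v * v for v in prefix)) or 1.0
        views[d] = [v / norm for v in prefix]
    return views


full = [1.0] * 768  # toy full-size embedding
views = matryoshka_views(full)
print(sorted(len(v) for v in views.values()))  # [16, 48, 192, 384, 768]
print(views[16][0])  # 0.25
```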