End-to-end walkthrough for Neural Keyword Spotting (KWS) on the LibriBrain MEG corpus: load data, frame the task, train a compact baseline, and evaluate with precision–recall metrics tailored to extreme class imbalance.
Note: This tutorial is released in conjunction with our DBM workshop paper “Elementary, My Dear Watson: Non-Invasive Neural Keyword Spotting in the LibriBrain Dataset”. It provides a comprehensive introduction as well as a hands-on, pedagogical walkthrough of the methods and concepts presented in the paper.
This tutorial demonstrates how to build and evaluate a neural keyword spotting system using the LibriBrain dataset.
Neural Keyword Spotting (KWS) from brain signals presents a promising direction for non-invasive brain–computer interfaces (BCIs), with potential applications in assistive communication technologies for individuals with speech impairments. While invasive BCIs have achieved remarkable success in speech decoding, full speech decoding from non-invasive brain signals remains an open challenge. However, keyword spotting—detecting specific words of interest—offers a more tractable goal that could still enable meaningful communication. Even detecting a single keyword reliably (a “1-bit channel”) could significantly improve quality of life for individuals with severe communication disabilities, allowing them to:
Keyword spotting from MEG presents two fundamental challenges:
Extreme Class Imbalance: Even short, common words like “the” represent only ~5.5% of all words in naturalistic speech. Target keywords like “Watson” appear in just 0.12% of word windows, creating a severe imbalance.
Low Signal-to-Noise Ratio: Unlike invasive recordings with electrode arrays placed directly on the cortex, non-invasive MEG/EEG sensors sit outside the skull, capturing attenuated and spatially blurred neural signals mixed with physiological and environmental noise.
These challenges require specialized techniques, which we cover in this tutorial.
The LibriBrain dataset comprises over 50 hours of 306-channel MEG recorded while a single participant listened to Sherlock Holmes audiobooks across many sessions, with word-level annotations that we use to define keyword events; it is accessed through the pnpl library.
We frame keyword detection as event-referenced binary classification: each example is a fixed-length MEG window aligned to a word onset, labeled positive if that word is the target keyword and negative otherwise (see the sketch below).
This differs from continuous detection, which would have to scan the raw signal without access to word-onset times; anchoring windows to word events keeps the task tractable while preserving the realistic class imbalance.
Data Splits: We use multiple training sessions and dynamically select validation/test sessions based on keyword prevalence to ensure sufficient positive examples in held-out sets.
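As a rough illustration of this framing, the sketch below slices labeled word windows out of a continuous recording. The helper name, the input arrays (meg, word_onsets, words), and the 0.2 s / 0.6 s window bounds are assumptions for illustration only; the notebook obtains its word-aligned windows through the pnpl library.

import torch

def extract_word_windows(meg, word_onsets, words, keyword, sfreq, t_before=0.2, t_after=0.6):
    """meg: (channels, time) array; word_onsets: onset times in seconds; words: aligned word strings."""
    length = int(round((t_before + t_after) * sfreq))
    windows, labels = [], []
    for onset, word in zip(word_onsets, words):
        start = int(round((onset - t_before) * sfreq))
        if start < 0 or start + length > meg.shape[1]:
            continue  # skip words too close to the recording edges
        windows.append(torch.as_tensor(meg[:, start:start + length], dtype=torch.float32))
        labels.append(1.0 if word.lower() == keyword.lower() else 0.0)
    return torch.stack(windows), torch.tensor(labels)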
The tutorial's baseline model addresses these challenges through three components:
Note: The notebook first demonstrates individual components with simplified examples (e.g., ConvTrunk with stride-2), then presents the full training architecture below.
The model begins with a Conv1D layer projecting the 306 MEG channels to 128 dimensions, followed by a residual block, a strided convolution that downsamples the time axis by a factor of 25, and a final convolutional layer:
self.trunk = nn.Sequential(
    nn.Conv1d(306, 128, 7, 1, padding='same'),  # project 306 MEG channels to 128 features
    ResNetBlock1D(128),                         # residual block over the 128 feature channels
    nn.ELU(),
    nn.Conv1d(128, 128, 50, 25, 0),             # stride-25 temporal downsampling
    nn.ELU(),
    nn.Conv1d(128, 128, 7, 1, padding='same'),  # local temporal mixing at the reduced rate
    nn.ELU(),
)
The trunk output is projected to 512 dimensions before splitting into two parallel 1×1 convolution heads: one producing per-timepoint logits, the other producing attention scores. The attention mechanism softmax-normalizes the scores over time and pools the per-timepoint logits into a single window-level logit:
self.head = nn.Sequential(
    nn.Conv1d(128, 512, 4, 1, 0),  # project trunk features to 512 dimensions
    nn.ReLU(),
    nn.Dropout(0.5),
)
self.logits_t = nn.Conv1d(512, 1, 1, 1, 0)  # per-timepoint keyword logits
self.attn_t = nn.Conv1d(512, 1, 1, 1, 0)    # per-timepoint attention scores

def forward(self, x):
    h = self.head(self.trunk(x))                    # (batch, 512, time)
    logit_t = self.logits_t(h)                      # (batch, 1, time)
    attn = torch.softmax(self.attn_t(h), dim=-1)    # attention weights over time
    return (logit_t * attn).sum(dim=-1).squeeze(1)  # attention-pooled logit per window
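A quick shape check, assuming the trunk and heads above are wrapped in a module (KeywordSpotter is a hypothetical name) and that windows span 200 time samples (an illustrative length):

model = KeywordSpotter()        # hypothetical wrapper around the trunk/heads above
x = torch.randn(8, 306, 200)    # batch of 8 windows: 306 MEG channels x 200 samples
scores = model(x)               # -> shape (8,): one keyword logit per window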
Standard cross-entropy fails under extreme class imbalance. We employ two complementary losses:
Focal Loss: down-weights the loss contribution of easy, confidently classified examples (mostly negatives), so training focuses on the rare, hard positives.
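A minimal sketch of a binary focal loss on raw logits; the alpha and gamma values are common defaults, not necessarily the notebook's settings:

import torch
import torch.nn.functional as F

def binary_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    # targets are 0/1 float labels; compute per-example BCE, then down-weight easy examples
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p_t = torch.exp(-bce)                                    # model probability of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # class-balancing factor
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()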
Pairwise Ranking Loss: directly encourages positive windows to score higher than negatives, matching the ranking behavior measured by precision-recall metrics:
def pairwise_logistic_loss(scores, targets, num_neg=256):
    # num_neg is an illustrative sample size; the notebook's pair-sampling scheme may differ
    pos_idx = (targets == 1).nonzero(as_tuple=True)[0]
    neg_idx = (targets == 0).nonzero(as_tuple=True)[0]
    if pos_idx.numel() == 0 or neg_idx.numel() == 0:
        return scores.new_zeros(())
    # Sample negatives, form all positive-negative pairs, and penalize inversions
    sampled_neg_idx = neg_idx[torch.randint(neg_idx.numel(), (num_neg,))]
    margins = scores[pos_idx].unsqueeze(1) - scores[sampled_neg_idx].unsqueeze(0)
    return torch.log1p(torch.exp(-margins)).mean()
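During training, the two objectives can be combined into a single loss; the weighting below is an illustrative assumption rather than the notebook's value:

# scores: model outputs for a batch of windows; targets: 0/1 keyword labels (float)
loss = binary_focal_loss(scores, targets) + 0.5 * pairwise_logistic_loss(scores, targets)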
Balanced Sampling: We construct training batches with ~10% positive rate (vs. the natural <1%) by oversampling keyword windows when assembling each batch (see the sketch below).
This ensures gradients aren’t starved by all-negative batches while keeping evaluation on natural class priors for realistic metrics.
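One way to implement such sampling (a sketch assuming a window-level dataset with known labels; the notebook's scheme may differ) is PyTorch's WeightedRandomSampler:

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

def make_balanced_loader(dataset, labels, batch_size=64, target_pos_rate=0.10):
    # Weight each example so positives make up roughly target_pos_rate of the draws
    labels = torch.as_tensor(labels, dtype=torch.float32)
    w_pos = target_pos_rate / labels.sum()
    w_neg = (1.0 - target_pos_rate) / (1.0 - labels).sum()
    weights = torch.where(labels == 1, w_pos, w_neg)
    sampler = WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)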
Preprocessing: The dataset applies per-channel z-score normalization and clips outliers beyond ±10σ before feeding data to the model.
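A minimal sketch of that standardization step; the function and statistic names are placeholders, since the dataset object applies this internally:

import torch

def standardize(x, channel_mean, channel_std, clip_sigma=10.0):
    # x: (channels, time) MEG window; channel_mean/channel_std: per-channel statistics
    z = (x - channel_mean[:, None]) / (channel_std[:, None] + 1e-8)
    return z.clamp(-clip_sigma, clip_sigma)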
Data Augmentation
Regularization: Dropout (p=0.5) and weight decay.
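Weight decay is typically applied through the optimizer; the learning rate and decay value below are illustrative, not the notebook's settings:

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)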
Traditional accuracy is meaningless under extreme imbalance (always predicting “no keyword” yields >99% accuracy). We employ metrics that reflect real-world BCI deployment:
Area Under the Precision-Recall Curve (AUPRC): a threshold-free summary of how well the model ranks keyword windows above non-keyword windows; unlike ROC-AUC, its chance level equals the positive prevalence, which keeps it informative under extreme imbalance.
Precision-Recall Trade-off: precision and recall at specific decision thresholds, since any deployed system must commit to an operating point.
False Alarms per Hour (FA/h): a deployment-oriented measure of how often the system would falsely trigger for the user:
FA/h = (false positives / total_seconds) × 3600
Operating Point Selection: choose a threshold on the validation set to meet FA/h or precision targets, then report test results at that fixed threshold (see the sketch below).
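The sketch below shows how AUPRC, FA/h, and a validation-based operating point can be computed; the variable names and the FA/h target are illustrative assumptions:

import numpy as np
from sklearn.metrics import average_precision_score

def false_alarms_per_hour(scores, labels, threshold, total_seconds):
    """False positives per hour of evaluated recording at a given decision threshold."""
    fp = np.sum((scores >= threshold) & (labels == 0))
    return fp / total_seconds * 3600

def pick_threshold(val_scores, val_labels, val_seconds, max_fa_per_hour=1.0):
    """Lowest threshold whose validation FA/h stays under the target rate."""
    for t in np.sort(np.unique(val_scores)):
        if false_alarms_per_hour(val_scores, val_labels, t, val_seconds) <= max_fa_per_hour:
            return t
    return np.max(val_scores)

# Usage (assuming score/label arrays from the validation and test splits):
# threshold = pick_threshold(val_scores, val_labels, val_seconds)
# auprc = average_precision_score(test_labels, test_scores)
# fa_h = false_alarms_per_hour(test_scores, test_labels, threshold, test_seconds)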
The tutorial requires the pnpl library and disk space for the dataset (~50 GB for the full set, ~5 GB for the default subset). It is designed to run on consumer hardware by training on a subset of the data; to scale to the full 50+ hours, increase the number of training sessions in the configuration and use a higher-tier GPU (V100/A100).
By working through this tutorial, you will:
The accompanying Jupyter notebook provides a complete, executable walkthrough:
Access the full interactive tutorial:
Links:
Requirements: a Google account for Colab, or a local Jupyter Notebook installation with Python 3.10+.
Besides the accompanying workshop paper, the LibriBrain dataset and the pnpl library introduced above are natural starting points for further reading.
For attribution in academic contexts, please cite this work as
PLACEHOLDER FOR ACADEMIC ATTRIBUTION
BibTeX citation
PLACEHOLDER FOR BIBTEX