JAX/Flax EEG Classification Reproduction
This notebook runs the JAX/Flax implementation. It follows the same data preparation and configuration pattern as the TensorFlow and PyTorch examples.
from pathlib import Path
from eegclassify.config import DataSplitConfig, ModelConfig, PreprocessingConfig, TrainingConfig
from eegclassify.data import load_processed_arrays
from eegclassify.models.jax import JAXGANAugmenter, JAXTrainer, build_classifier
from eegclassify.preprocessing import prepare_splits
Configuration
JAX compiles jitted functions the first time they are called, so the first cell that trains or evaluates a model may take longer than later cells. GAN augmentation is a training-data step, independent of the classifier architecture, so it can be combined with any classifier.
DATA_DIR = Path("../data/processed")
if not DATA_DIR.exists():
    DATA_DIR = Path("../data_temp")
FAST_DEV_RUN = True
DEV_SAMPLE_LIMIT = 128 if FAST_DEV_RUN else None
MODEL_NAME = "cnn" # choices: cnn, lstm, cnn_lstm, gan_cnn
USE_GAN_AUGMENTATION = False
# Full reproduction settings:
# FAST_DEV_RUN = False
# DEV_SAMPLE_LIMIT = None
# MODEL_NAME = "cnn"
# USE_GAN_AUGMENTATION = False
if MODEL_NAME == "gan_cnn":
    USE_GAN_AUGMENTATION = True
    CLASSIFIER_NAME = "cnn"
else:
    CLASSIFIER_NAME = MODEL_NAME
preprocess = PreprocessingConfig(seed=1)
split = DataSplitConfig(validation_ratio=0.17, seed=1)
training = TrainingConfig(
    fast_dev_run=FAST_DEV_RUN,
    use_gan_augmentation=USE_GAN_AUGMENTATION,
    gan_samples_per_class=10,
    seed=1,
)
model_config = ModelConfig(max_time_step=preprocess.max_time_step)
REFERENCE_TARGETS = {
    "cnn": 0.7049,
    "cnn_lstm": 0.6095,
    "lstm": 0.3962,
    "gan_cnn": 0.6823,
}
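The first-call compilation cost mentioned at the top of this section can be observed with a small timing sketch. The function `dense_forward` and the array shapes are illustrative stand-ins, not part of `eegclassify`; only `jax.jit` and `block_until_ready` are real JAX APIs here.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def dense_forward(w, x):
    # Stand-in for a model forward pass; not the notebook's classifier.
    return jnp.tanh(x @ w).sum()


def timed_call(fn, *args):
    start = time.perf_counter()
    # block_until_ready waits for JAX's asynchronous dispatch to finish,
    # so the measurement covers the actual computation.
    result = fn(*args).block_until_ready()
    return result, time.perf_counter() - start


w = jnp.ones((64, 32))
x = jnp.ones((16, 64))

# The first call traces and compiles the function for these input shapes.
first_result, first_seconds = timed_call(dense_forward, w, x)
# The second call reuses the cached compiled function.
second_result, second_seconds = timed_call(dense_forward, w, x)

print(f"first call:  {first_seconds:.4f}s")
print(f"second call: {second_seconds:.4f}s")
```

Calling the function with a new input shape or dtype triggers another compilation, which is one reason the first training and the first evaluation cell can each appear slow.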
Load And Prepare Data
The prepared arrays use one-hot labels. GAN interpolation labels are soft distributions and are supported by the JAX trainer's soft-label cross-entropy.
bundle = load_processed_arrays(DATA_DIR)
prepared = prepare_splits(bundle, preprocess, split)
x_train, y_train = prepared.x_train, prepared.y_train
x_valid, y_valid = prepared.x_valid, prepared.y_valid
x_test, y_test = prepared.x_test, prepared.y_test
if DEV_SAMPLE_LIMIT is not None:
    x_train, y_train = x_train[:DEV_SAMPLE_LIMIT], y_train[:DEV_SAMPLE_LIMIT]
    x_valid, y_valid = x_valid[:DEV_SAMPLE_LIMIT], y_valid[:DEV_SAMPLE_LIMIT]
    x_test, y_test = x_test[:DEV_SAMPLE_LIMIT], y_test[:DEV_SAMPLE_LIMIT]
print("train:", x_train.shape, y_train.shape)
print("valid:", x_valid.shape, y_valid.shape)
print("test:", x_test.shape, y_test.shape)
print("labels are one-hot with columns 0-3 for original cues 769-772")
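To make the soft-label cross-entropy mentioned above concrete, here is a minimal sketch of such a loss. The function name `soft_label_cross_entropy` and the example values are assumptions for illustration; the actual loss inside `JAXTrainer` may differ in reduction or numerical details.

```python
import jax
import jax.numpy as jnp


def soft_label_cross_entropy(logits, target_probs):
    """Mean cross-entropy between logits and target probability distributions."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Sum over classes weights each log-probability by its target mass,
    # so one-hot targets reduce to the usual negative log-likelihood.
    return -jnp.mean(jnp.sum(target_probs * log_probs, axis=-1))


logits = jnp.array([[2.0, 0.5, -1.0, 0.0]])
one_hot = jnp.array([[1.0, 0.0, 0.0, 0.0]])
# Hypothetical interpolated label: 70% class 0, 30% class 1.
soft = jnp.array([[0.7, 0.3, 0.0, 0.0]])

hard_loss = soft_label_cross_entropy(logits, one_hot)
soft_loss = soft_label_cross_entropy(logits, soft)
print(float(hard_loss), float(soft_loss))
```

Because the same expression handles both cases, one-hot labels from real trials and soft labels from synthetic samples can share a batch without special-casing.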
Optional GAN Augmentation
Synthetic samples are generated from the training split only. The validation and test splits contain only real recordings, so their metrics reflect performance on real held-out data.
gan_result = None
if USE_GAN_AUGMENTATION:
    gan_result = JAXGANAugmenter(training, model_config).augment_training_data(x_train, y_train)
    x_train, y_train = gan_result.x_train, gan_result.y_train
    print("synthetic samples:", gan_result.x_synthetic.shape[0])
    print("augmented train:", x_train.shape, y_train.shape)
else:
    print("GAN augmentation disabled")
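The sketch below illustrates only the bookkeeping around augmentation: synthetic samples with soft labels are appended to the training arrays, and nothing else is touched. It does not reproduce what `JAXGANAugmenter` does internally. The helper `append_synthetic`, the array shapes, and the linear interpolation used as a placeholder for generator output are all assumptions.

```python
import numpy as np


def append_synthetic(x_train, y_train, x_synthetic, y_synthetic):
    """Return training arrays with synthetic samples appended at the end."""
    x_aug = np.concatenate([x_train, x_synthetic], axis=0)
    y_aug = np.concatenate([y_train, y_synthetic], axis=0)
    return x_aug, y_aug


rng = np.random.default_rng(1)
# Hypothetical layout: (samples, channels, time steps) with 4 classes.
x_real = rng.normal(size=(8, 22, 100)).astype(np.float32)
y_real = np.eye(4, dtype=np.float32)[rng.integers(0, 4, size=8)]

# Placeholder for generator output: a blend of two real samples.
# The blended labels are soft distributions that still sum to 1.
lam = 0.6
x_fake = lam * x_real[:2] + (1 - lam) * x_real[2:4]
y_fake = lam * y_real[:2] + (1 - lam) * y_real[2:4]

x_aug, y_aug = append_synthetic(x_real, y_real, x_fake, y_fake)
print("augmented train:", x_aug.shape, y_aug.shape)
```

Keeping the helper restricted to the training arrays makes it hard to leak synthetic data into validation or test by accident.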
Build, Train, And Evaluate
The JAX trainer keeps the final train state internally so evaluation uses the trained parameters and batch-normalization statistics.
model = build_classifier(CLASSIFIER_NAME, model_config)
trainer = JAXTrainer(training)
history = trainer.fit(model, x_train, y_train)
metrics = trainer.evaluate(model, x_test, y_test)
metrics
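Compare With Reference Targets
Metrics produced with FAST_DEV_RUN = True are smoke-test numbers computed on at most 128 samples per split. Switch to the full reproduction settings before interpreting the delta from the reference target.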
target_key = "gan_cnn" if USE_GAN_AUGMENTATION and CLASSIFIER_NAME == "cnn" else CLASSIFIER_NAME
target = REFERENCE_TARGETS.get(target_key)
print(f"experiment: {target_key}")
print(f"test accuracy: {metrics['accuracy']:.4f}")
if target is not None:
    delta = metrics["accuracy"] - target
    print(f"reference target: {target:.4f}")
    print(f"delta from target: {delta:+.4f}")
else:
    print("No reference target exists for this GAN/classifier combination.")