PyTorch EEG Classification Reproduction¶
This notebook runs the PyTorch implementation with the same preprocessing and experiment defaults as the TensorFlow notebook.
from pathlib import Path
from eegclassify.config import DataSplitConfig, ModelConfig, PreprocessingConfig, TrainingConfig
from eegclassify.data import load_processed_arrays
from eegclassify.models.pytorch import PyTorchGANAugmenter, PyTorchTrainer, build_classifier
from eegclassify.preprocessing import prepare_splits
Configuration¶
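Set USE_GAN_AUGMENTATION = True to add synthetic GAN samples to the training set before fitting any classifier. Setting MODEL_NAME = "gan_cnn" turns augmentation on automatically and selects the CNN classifier. The original parity comparison only applies to CNN with GAN augmentation.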
# --- Data location ---
# Prefer the processed dataset directory; fall back to the temporary data
# directory when the processed one does not exist.
DATA_DIR = Path("../data/processed")
DATA_DIR = DATA_DIR if DATA_DIR.exists() else Path("../data_temp")

# --- Run mode ---
# A fast dev run caps every split at DEV_SAMPLE_LIMIT samples; a full run
# leaves the limit unset.
FAST_DEV_RUN = True
DEV_SAMPLE_LIMIT = None if not FAST_DEV_RUN else 128
MODEL_NAME = "cnn"  # choices: cnn, lstm, cnn_lstm, gan_cnn
USE_GAN_AUGMENTATION = False

# Full reproduction settings:
# FAST_DEV_RUN = False
# DEV_SAMPLE_LIMIT = None
# MODEL_NAME = "cnn"
# USE_GAN_AUGMENTATION = False

# "gan_cnn" is shorthand for "CNN classifier with GAN augmentation forced on";
# every other model name is used directly as the classifier name.
CLASSIFIER_NAME = "cnn" if MODEL_NAME == "gan_cnn" else MODEL_NAME
USE_GAN_AUGMENTATION = USE_GAN_AUGMENTATION or MODEL_NAME == "gan_cnn"

# --- Experiment configuration objects (all seeded with 1) ---
preprocess = PreprocessingConfig(seed=1)
split = DataSplitConfig(validation_ratio=0.17, seed=1)
training = TrainingConfig(
    fast_dev_run=FAST_DEV_RUN,
    use_gan_augmentation=USE_GAN_AUGMENTATION,
    gan_samples_per_class=10,
    seed=1,
)
# The model's time dimension follows the preprocessing configuration.
model_config = ModelConfig(max_time_step=preprocess.max_time_step)

# Reference test-accuracy targets keyed by experiment name; used by the
# final comparison cell.
REFERENCE_TARGETS = {
    "cnn": 0.7049,
    "cnn_lstm": 0.6095,
    "lstm": 0.3962,
    "gan_cnn": 0.6823,
}
Load And Prepare Data¶
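The package preprocessing applies the default label conversion, seeded validation split, temporal augmentation, and channels-last classifier layout. When DEV_SAMPLE_LIMIT is set (fast dev runs), each split is then truncated to its first DEV_SAMPLE_LIMIT samples.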
# Load the stored arrays and run the package preprocessing pipeline on them.
bundle = load_processed_arrays(DATA_DIR)
prepared = prepare_splits(bundle, preprocess, split)

# Pull features and labels for the three splits out of the prepared result.
x_train, x_valid, x_test = prepared.x_train, prepared.x_valid, prepared.x_test
y_train, y_valid, y_test = prepared.y_train, prepared.y_valid, prepared.y_test

# In a fast dev run, keep only the leading DEV_SAMPLE_LIMIT samples of each
# split so the whole notebook executes quickly.
if DEV_SAMPLE_LIMIT is not None:
    x_train, x_valid, x_test = (
        features[:DEV_SAMPLE_LIMIT] for features in (x_train, x_valid, x_test)
    )
    y_train, y_valid, y_test = (
        labels[:DEV_SAMPLE_LIMIT] for labels in (y_train, y_valid, y_test)
    )

# Report the final shapes regardless of whether truncation happened.
print("train:", x_train.shape, y_train.shape)
print("valid:", x_valid.shape, y_valid.shape)
print("test:", x_test.shape, y_test.shape)
print("labels are one-hot with columns 0-3 for original cues 769-772")
Optional GAN Augmentation¶
PyTorch uses the same conditional GAN semantics as the TensorFlow implementation. Soft labels generated by interpolation are handled by a distribution cross-entropy loss in the trainer.
# gan_result stays None unless augmentation actually runs, so later cells can
# tell whether synthetic data was added.
gan_result = None
if not USE_GAN_AUGMENTATION:
    print("GAN augmentation disabled")
else:
    # Run the augmenter on the current training set and replace the training
    # arrays with the augmented ones it returns.
    gan_result = PyTorchGANAugmenter(training, model_config).augment_training_data(
        x_train, y_train
    )
    x_train = gan_result.x_train
    y_train = gan_result.y_train
    print("synthetic samples:", gan_result.x_synthetic.shape[0])
    print("augmented train:", x_train.shape, y_train.shape)
Build, Train, And Evaluate¶
The trainer automatically uses CUDA when PyTorch can see a compatible GPU; otherwise it runs on CPU.
# Build the classifier named by CLASSIFIER_NAME using the shared model config.
# NOTE(review): the model is built before the trainer; keep this order unless
# it is confirmed that neither constructor touches global RNG state.
model = build_classifier(CLASSIFIER_NAME, model_config)
trainer = PyTorchTrainer(training)
# Fit on the training arrays (GAN-augmented if the augmentation cell ran),
# passing the validation split alongside them.
history = trainer.fit(model, x_train, y_train, x_valid, y_valid)
# Evaluate on the held-out test split; the comparison cell reads
# metrics["accuracy"] from this result.
metrics = trainer.evaluate(model, x_test, y_test)
# Bare expression so the notebook displays the metrics as the cell output.
metrics
Compare With Reference Targets¶
FAST_DEV_RUN metrics are smoke-test numbers. Use the full settings before interpreting the target delta.
# Work out which reference experiment this run corresponds to.
#
# Reference targets exist for the plain classifiers and for the CNN trained
# with GAN augmentation ("gan_cnn"). A GAN-augmented run of any other
# classifier has no reference, so it gets a distinct "gan_<classifier>" key.
# Falling back to the plain classifier key here would compare an augmented run
# against a non-augmented target and make the "no reference" branch below
# unreachable.
if not USE_GAN_AUGMENTATION:
    target_key = CLASSIFIER_NAME
elif CLASSIFIER_NAME == "cnn":
    target_key = "gan_cnn"
else:
    target_key = f"gan_{CLASSIFIER_NAME}"

# dict.get returns None when no target is recorded for this experiment.
target = REFERENCE_TARGETS.get(target_key)
accuracy = metrics["accuracy"]

print(f"experiment: {target_key}")
print(f"test accuracy: {accuracy:.4f}")
if target is not None:
    # Signed difference: positive means this run beat the reference.
    delta = accuracy - target
    print(f"reference target: {target:.4f}")
    print(f"delta from target: {delta:+.4f}")
else:
    print("No reference target exists for this GAN/classifier combination.")