TensorFlow EEG Classification Reproduction
This notebook runs the TensorFlow/Keras implementation. By default it performs a small smoke run (FAST_DEV_RUN = True, which caps each split at 128 samples) so the workflow is quick to verify. For the full reproduction, set FAST_DEV_RUN = False; DEV_SAMPLE_LIMIT then resolves to None and every sample is used.
from pathlib import Path
from eegclassify.config import DataSplitConfig, ModelConfig, PreprocessingConfig, TrainingConfig
from eegclassify.data import load_processed_arrays
from eegclassify.models.tensorflow import TensorFlowClassifierFactory, TensorFlowGANAugmenter, TensorFlowTrainer
from eegclassify.preprocessing import prepare_splits
Configuration
MODEL_NAME = "gan_cnn" is a convenience alias for MODEL_NAME = "cnn" plus USE_GAN_AUGMENTATION = True. GAN augmentation can also be combined with lstm or cnn_lstm, but GAN+CNN is the only GAN-augmented configuration that has a reference target.
DATA_DIR = Path("../data/processed")
if not DATA_DIR.exists():
    DATA_DIR = Path("../data_temp")
FAST_DEV_RUN = True
DEV_SAMPLE_LIMIT = 128 if FAST_DEV_RUN else None
MODEL_NAME = "cnn" # choices: cnn, lstm, cnn_lstm, gan_cnn
USE_GAN_AUGMENTATION = False
# Full reproduction settings:
# FAST_DEV_RUN = False
# DEV_SAMPLE_LIMIT = None
# MODEL_NAME = "cnn"
# USE_GAN_AUGMENTATION = False
if MODEL_NAME == "gan_cnn":
    USE_GAN_AUGMENTATION = True
    CLASSIFIER_NAME = "cnn"
else:
    CLASSIFIER_NAME = MODEL_NAME
preprocess = PreprocessingConfig(seed=1)
split = DataSplitConfig(validation_ratio=0.17, seed=1)
training = TrainingConfig(
    fast_dev_run=FAST_DEV_RUN,
    use_gan_augmentation=USE_GAN_AUGMENTATION,
    gan_samples_per_class=10,
    seed=1,
)
model_config = ModelConfig(max_time_step=preprocess.max_time_step)
REFERENCE_TARGETS = {
    "cnn": 0.7049,
    "cnn_lstm": 0.6095,
    "lstm": 0.3962,
    "gan_cnn": 0.6823,
}
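The alias handling described above can be sketched as a small helper. This is an illustrative sketch only: `resolve_model_name` and `VALID_MODEL_NAMES` are hypothetical names that are not part of `eegclassify`, and the list of choices is taken from the comment next to MODEL_NAME.

```python
from typing import Tuple

# Choices listed in the notebook's MODEL_NAME comment.
VALID_MODEL_NAMES = ("cnn", "lstm", "cnn_lstm", "gan_cnn")


def resolve_model_name(model_name: str, use_gan: bool = False) -> Tuple[str, bool]:
    """Return (classifier_name, use_gan_augmentation) for a MODEL_NAME value.

    Hypothetical helper; it mirrors the if/else in the configuration cell.
    """
    if model_name not in VALID_MODEL_NAMES:
        raise ValueError(
            f"unknown MODEL_NAME {model_name!r}; expected one of {VALID_MODEL_NAMES}"
        )
    if model_name == "gan_cnn":
        # The alias always means the CNN classifier with GAN augmentation on.
        return "cnn", True
    return model_name, use_gan


print(resolve_model_name("gan_cnn"))             # ('cnn', True)
print(resolve_model_name("lstm", use_gan=True))  # ('lstm', True)
```

Rejecting unknown names up front is a design choice of this sketch; the notebook itself passes whatever name is set straight to the classifier factory.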
Load And Prepare Data
Preprocessing maps the cue labels 769-772 to class indices 0-3 and one-hot encodes them, creates the validation split (validation_ratio=0.17, seed 1), applies the original max-pooling / averaging / subsampling augmentation, and reshapes the arrays to (trials, time, 1, channels).
bundle = load_processed_arrays(DATA_DIR)
prepared = prepare_splits(bundle, preprocess, split)
x_train, y_train = prepared.x_train, prepared.y_train
x_valid, y_valid = prepared.x_valid, prepared.y_valid
x_test, y_test = prepared.x_test, prepared.y_test
if DEV_SAMPLE_LIMIT is not None:
    x_train, y_train = x_train[:DEV_SAMPLE_LIMIT], y_train[:DEV_SAMPLE_LIMIT]
    x_valid, y_valid = x_valid[:DEV_SAMPLE_LIMIT], y_valid[:DEV_SAMPLE_LIMIT]
    x_test, y_test = x_test[:DEV_SAMPLE_LIMIT], y_test[:DEV_SAMPLE_LIMIT]
print("train:", x_train.shape, y_train.shape)
print("valid:", x_valid.shape, y_valid.shape)
print("test:", x_test.shape, y_test.shape)
print("labels are one-hot with columns 0-3 for original cues 769-772")
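To make the label mapping and the final array layout concrete, here is a minimal NumPy sketch of those two steps. It is not the `prepare_splits` implementation: the raw layout (trials, channels, time) and the sizes (6 trials, 22 channels, 250 time steps) are assumptions chosen for illustration, and the pooling / averaging / subsampling augmentation is left out.

```python
import numpy as np

# Assumed raw layout (not confirmed by the notebook): (trials, channels, time).
rng = np.random.default_rng(1)
x_raw = rng.normal(size=(6, 22, 250)).astype("float32")
y_raw = np.array([769, 770, 771, 772, 769, 772])

# Shift cue codes 769-772 down to class indices 0-3.
class_index = y_raw - 769

# One-hot encode by selecting rows of a 4x4 identity matrix.
y_one_hot = np.eye(4, dtype="float32")[class_index]

# Move time ahead of channels, then insert a singleton axis:
# (trials, channels, time) -> (trials, time, channels) -> (trials, time, 1, channels).
x_ready = np.transpose(x_raw, (0, 2, 1))[:, :, np.newaxis, :]

print(x_ready.shape)  # (6, 250, 1, 22)
print(y_one_hot[1])   # [0. 1. 0. 0.]
```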
Optional GAN Augmentation
When enabled, a conditional GAN is trained on the training split only. Synthetic samples are appended to x_train and y_train; validation and test arrays stay untouched so evaluation remains comparable.
gan_result = None
if USE_GAN_AUGMENTATION:
    gan_result = TensorFlowGANAugmenter(training, model_config).augment_training_data(x_train, y_train)
    x_train, y_train = gan_result.x_train, gan_result.y_train
    print("synthetic samples:", gan_result.x_synthetic.shape[0])
    print("augmented train:", x_train.shape, y_train.shape)
else:
    print("GAN augmentation disabled")
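The append step can be pictured with plain NumPy. The sketch below uses random noise as a stand-in for generator output and assumes the augmenter produces exactly gan_samples_per_class samples for each of the four classes; neither detail is confirmed by the notebook, and the array sizes are illustrative.

```python
import numpy as np

num_classes, samples_per_class = 4, 10

# Placeholder training split shaped like the smoke run (sizes are assumptions).
x_train = np.zeros((128, 250, 1, 22), dtype="float32")
y_train = np.eye(num_classes, dtype="float32")[np.arange(128) % num_classes]

# Stand-in for generator output: noise, with samples_per_class items per class.
rng = np.random.default_rng(1)
synthetic_classes = np.repeat(np.arange(num_classes), samples_per_class)
x_synthetic = rng.normal(size=(len(synthetic_classes), 250, 1, 22)).astype("float32")
y_synthetic = np.eye(num_classes, dtype="float32")[synthetic_classes]

# Append along the trial axis; validation and test arrays are never touched.
x_augmented = np.concatenate([x_train, x_synthetic], axis=0)
y_augmented = np.concatenate([y_train, y_synthetic], axis=0)

print(x_augmented.shape, y_augmented.shape)  # (168, 250, 1, 22) (168, 4)
```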
Build, Train, And Evaluate
The classifier architecture is selected after the optional augmentation step. The TensorFlow implementation uses categorical cross-entropy, so both one-hot labels and the soft labels produced by GAN interpolation are supported.
factory = TensorFlowClassifierFactory(model_config)
model = factory.build(CLASSIFIER_NAME, learning_rate=training.learning_rate)
model.summary()
trainer = TensorFlowTrainer(training)
history = trainer.fit(model, x_train, y_train, x_valid, y_valid)
metrics = trainer.evaluate(model, x_test, y_test)
metrics
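The reason categorical cross-entropy accepts both label types is that it weights the log-probability of every class by the target value, so a target row only needs to be a probability distribution. The NumPy sketch below illustrates the formula; it is not the Keras loss itself, and the `categorical_cross_entropy` helper and the example vectors are made up for illustration.

```python
import numpy as np


def categorical_cross_entropy(y_true, y_pred, eps=1e-7):
    """Per-sample cross-entropy: -sum_k y_true[k] * log(y_pred[k])."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return -np.sum(y_true * np.log(y_pred), axis=-1)


y_pred = np.array([[0.7, 0.1, 0.1, 0.1]])

# One-hot target: only the true class contributes to the loss.
hard_label = np.array([[1.0, 0.0, 0.0, 0.0]])
# Soft target, e.g. an interpolation between classes 0 and 1.
soft_label = np.array([[0.6, 0.4, 0.0, 0.0]])

hard_loss = categorical_cross_entropy(hard_label, y_pred)
soft_loss = categorical_cross_entropy(soft_label, y_pred)
print(hard_loss, soft_loss)
```

A sparse variant of the loss, which takes integer class indices, could not represent the soft target in this example.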
Compare With Reference Targets
FAST_DEV_RUN metrics are smoke-test numbers. Use the full settings before interpreting the target delta.
# Prefix GAN runs so that GAN+LSTM or GAN+CNN-LSTM are not compared against a non-GAN target.
target_key = f"gan_{CLASSIFIER_NAME}" if USE_GAN_AUGMENTATION else CLASSIFIER_NAME
target = REFERENCE_TARGETS.get(target_key)
print(f"experiment: {target_key}")
print(f"test accuracy: {metrics['accuracy']:.4f}")
if target is not None:
    delta = metrics["accuracy"] - target
    print(f"reference target: {target:.4f}")
    print(f"delta from target: {delta:+.4f}")
else:
    print("No reference target exists for this GAN/classifier combination.")