
PyTorch Models

PyTorch models for EEG classification.

Classes:

PyTorchCNN: CNN classifier implemented with PyTorch.

PyTorchLSTM: Stacked bidirectional LSTM classifier implemented with PyTorch.

PyTorchCNNLSTM: Hybrid CNN-LSTM classifier implemented with PyTorch.

PyTorchGenerator: Conditional GAN generator implemented with PyTorch.

PyTorchDiscriminator: Conditional GAN discriminator implemented with PyTorch.

PyTorchGANAugmenter: Train a PyTorch conditional GAN and append generated samples.

PyTorchTrainer: Minimal PyTorch trainer for notebook reproduction.

Functions:

build_classifier: Build a named PyTorch classifier.

PyTorchCNN

PyTorchCNN(config: ModelConfig | None = None)

Bases: Module

CNN classifier implemented with PyTorch.

Parameters:

config (ModelConfig | None, default: None): Optional architecture configuration.
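This page does not document the layer stack behind PyTorchCNN. As a rough illustration only, the sketch below shows one common way a PyTorch CNN can consume inputs in the (n, time, 1, channels) layout documented for augment_training_data: drop the singleton axis, move EEG channels into the convolution's channel dimension, and convolve along time. The class name TinyEEGCNN, the layer sizes, and the defaults of 22 channels and 4 classes are assumptions, not values taken from ModelConfig.

```python
import torch
from torch import nn


class TinyEEGCNN(nn.Module):
    """Hypothetical CNN sketch; not the library's PyTorchCNN architecture."""

    def __init__(self, n_channels: int = 22, n_classes: int = 4, n_filters: int = 16) -> None:
        super().__init__()
        self.features = nn.Sequential(
            # Conv1d treats EEG channels as input channels and slides along time.
            nn.Conv1d(n_channels, n_filters, kernel_size=7, padding=3),
            nn.BatchNorm1d(n_filters),
            nn.ReLU(),
            nn.MaxPool1d(2),  # halve the time axis
            nn.Conv1d(n_filters, 2 * n_filters, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),  # average over time, so any trial length works
        )
        self.head = nn.Linear(2 * n_filters, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (n, time, 1, channels) -> (n, time, channels) -> (n, channels, time)
        x = x.squeeze(2).permute(0, 2, 1)
        x = self.features(x).squeeze(-1)  # (n, 2 * n_filters)
        return self.head(x)  # raw logits, shape (n, n_classes)


model = TinyEEGCNN()
logits = model(torch.randn(8, 250, 1, 22))
print(logits.shape)  # torch.Size([8, 4])
```

Returning raw logits (no softmax) pairs naturally with nn.CrossEntropyLoss, which applies log-softmax internally.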

PyTorchLSTM

PyTorchLSTM(config: ModelConfig | None = None)

Bases: Module

Stacked bidirectional LSTM classifier implemented with PyTorch.

Parameters:

config (ModelConfig | None, default: None): Optional architecture configuration.
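A stacked bidirectional LSTM is a standard nn.LSTM with num_layers greater than 1 and bidirectional=True. The sketch below is a hypothetical stand-in (TinyBiLSTM and all sizes are assumptions): it treats each time step's channel vector as one sequence element and classifies from the top layer's final forward and backward hidden states. How PyTorchLSTM actually pools its output is not documented on this page.

```python
import torch
from torch import nn


class TinyBiLSTM(nn.Module):
    """Hypothetical stacked bidirectional LSTM sketch; not PyTorchLSTM itself."""

    def __init__(self, n_channels: int = 22, n_classes: int = 4,
                 hidden_size: int = 32, num_layers: int = 2, dropout: float = 0.2) -> None:
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=n_channels,   # one feature per EEG channel at each time step
            hidden_size=hidden_size,
            num_layers=num_layers,   # stacked layers
            batch_first=True,        # inputs are (n, time, features)
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0.0,  # applied between stacked layers only
        )
        self.head = nn.Linear(2 * hidden_size, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.squeeze(2)  # (n, time, 1, channels) -> (n, time, channels)
        _, (h_n, _) = self.lstm(x)
        # h_n is (num_layers * 2, n, hidden); the last two rows are the top
        # layer's forward and backward final states.
        last = torch.cat([h_n[-2], h_n[-1]], dim=1)
        return self.head(last)


model = TinyBiLSTM()
logits = model(torch.randn(8, 250, 1, 22))
print(logits.shape)  # torch.Size([8, 4])
```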

PyTorchCNNLSTM

PyTorchCNNLSTM(config: ModelConfig | None = None)

Bases: Module

Hybrid CNN-LSTM classifier implemented with PyTorch.

Parameters:

config (ModelConfig | None, default: None): Optional architecture configuration.
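One common CNN-LSTM pattern convolves along time to extract local features, downsamples, and then feeds the shortened feature sequence to an LSTM. The sketch below illustrates that pattern under assumed sizes; TinyCNNLSTM is hypothetical and is not the library's PyTorchCNNLSTM.

```python
import torch
from torch import nn


class TinyCNNLSTM(nn.Module):
    """Hypothetical CNN-LSTM sketch; not the library's PyTorchCNNLSTM."""

    def __init__(self, n_channels: int = 22, n_classes: int = 4,
                 n_filters: int = 16, hidden_size: int = 32) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(n_channels, n_filters, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(4),  # shorten the sequence 4x before the recurrent stage
        )
        self.lstm = nn.LSTM(n_filters, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.squeeze(2).permute(0, 2, 1)  # (n, channels, time)
        x = self.conv(x)                   # (n, n_filters, time // 4)
        x = x.permute(0, 2, 1)             # (n, time // 4, n_filters) for the LSTM
        _, (h_n, _) = self.lstm(x)
        return self.head(h_n[-1])          # classify from the final hidden state


model = TinyCNNLSTM()
logits = model(torch.randn(8, 250, 1, 22))
print(logits.shape)  # torch.Size([8, 4])
```

Pooling before the LSTM cuts the number of recurrent steps (250 to 62 here), which is usually the point of the hybrid: the convolution handles local patterns cheaply and the LSTM models longer-range order.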

PyTorchGenerator

PyTorchGenerator(config: ModelConfig | None = None)

Bases: Module

Conditional GAN generator implemented with PyTorch.

Parameters:

config (ModelConfig | None, default: None): Optional architecture configuration.
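A conditional generator maps a noise vector plus a class label to a synthetic sample. The sketch below conditions by concatenating a one-hot label to the noise and reshapes the output into the (n, time, 1, channels) classifier layout. The MLP body, the latent size of 64, and the Tanh output (which assumes real EEG is scaled to [-1, 1]) are assumptions rather than PyTorchGenerator's documented design.

```python
import torch
from torch import nn


class TinyConditionalGenerator(nn.Module):
    """Hypothetical conditional generator sketch; not PyTorchGenerator itself."""

    def __init__(self, latent_dim: int = 64, n_classes: int = 4,
                 time_steps: int = 250, n_channels: int = 22) -> None:
        super().__init__()
        self.time_steps = time_steps
        self.n_channels = n_channels
        self.net = nn.Sequential(
            nn.Linear(latent_dim + n_classes, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, time_steps * n_channels),
            nn.Tanh(),  # assumes real EEG has been scaled to [-1, 1]
        )

    def forward(self, z: torch.Tensor, y_onehot: torch.Tensor) -> torch.Tensor:
        # Condition on the class by concatenating its one-hot code to the noise.
        h = self.net(torch.cat([z, y_onehot], dim=1))
        return h.view(-1, self.time_steps, 1, self.n_channels)  # classifier layout


gen = TinyConditionalGenerator()
labels = torch.randint(0, 4, (8,))
y_onehot = nn.functional.one_hot(labels, num_classes=4).float()
fake = gen(torch.randn(8, 64), y_onehot)
print(fake.shape)  # torch.Size([8, 250, 1, 22])
```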

PyTorchDiscriminator

PyTorchDiscriminator(config: ModelConfig | None = None)

Bases: Module

Conditional GAN discriminator implemented with PyTorch.

Parameters:

config (ModelConfig | None, default: None): Optional architecture configuration.
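A matching conditional discriminator scores (sample, label) pairs with a single real/fake logit. This hypothetical sketch flattens the sample, concatenates the one-hot label, and pairs the output with nn.BCEWithLogitsLoss; the architecture is assumed, not taken from PyTorchDiscriminator.

```python
import torch
from torch import nn


class TinyConditionalDiscriminator(nn.Module):
    """Hypothetical conditional discriminator sketch; not PyTorchDiscriminator itself."""

    def __init__(self, n_classes: int = 4, time_steps: int = 250, n_channels: int = 22) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(time_steps * n_channels + n_classes, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),  # one real/fake logit per sample
        )

    def forward(self, x: torch.Tensor, y_onehot: torch.Tensor) -> torch.Tensor:
        flat = x.flatten(start_dim=1)  # (n, time, 1, channels) -> (n, time * channels)
        return self.net(torch.cat([flat, y_onehot], dim=1))


disc = TinyConditionalDiscriminator()
x_real = torch.randn(8, 250, 1, 22)
y_onehot = nn.functional.one_hot(torch.randint(0, 4, (8,)), num_classes=4).float()
scores = disc(x_real, y_onehot)  # (8, 1) logits
loss_real = nn.BCEWithLogitsLoss()(scores, torch.ones_like(scores))  # real pairs target 1
print(scores.shape, loss_real.item())
```

Emitting logits and using BCEWithLogitsLoss, rather than a Sigmoid followed by BCELoss, is the numerically stable choice PyTorch recommends.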

PyTorchGANAugmenter (dataclass)

PyTorchGANAugmenter(
    training: TrainingConfig,
    model_config: ModelConfig | None = None,
    device: str | None = None,
)

Train a PyTorch conditional GAN and append generated samples.

The returned arrays use the same classifier layout as prepare_splits, so the augmentation can be applied before CNN, LSTM, or CNN-LSTM training.
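This page does not say which GAN objective or optimizer settings PyTorchGANAugmenter uses. For orientation, the sketch below shows one standard conditional-GAN update (non-saturating BCE loss, Adam with betas=(0.5, 0.999)) on toy MLPs, then samples synthetic trials for chosen labels and reshapes them to the classifier layout. Every size, the learning rate, and the helper names gan_step and sample are assumptions.

```python
from __future__ import annotations

import torch
from torch import nn

# Hypothetical toy sizes; real values would come from ModelConfig / TrainingConfig.
LATENT, N_CLASSES, TIME, CHANNELS = 16, 3, 32, 4
FLAT = TIME * CHANNELS

generator = nn.Sequential(
    nn.Linear(LATENT + N_CLASSES, 64), nn.ReLU(), nn.Linear(64, FLAT), nn.Tanh()
)
discriminator = nn.Sequential(
    nn.Linear(FLAT + N_CLASSES, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1)
)
opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
bce = nn.BCEWithLogitsLoss()


def gan_step(x_real: torch.Tensor, y_onehot: torch.Tensor) -> tuple[float, float]:
    """One conditional-GAN update on a batch shaped (n, time, 1, channels)."""
    real_flat = x_real.flatten(start_dim=1)
    z = torch.randn(x_real.shape[0], LATENT)
    fake_flat = generator(torch.cat([z, y_onehot], dim=1))

    # Discriminator: real (sample, label) pairs -> 1, fake pairs -> 0.
    # detach() keeps this update from back-propagating into the generator.
    opt_d.zero_grad()
    d_real = discriminator(torch.cat([real_flat, y_onehot], dim=1))
    d_fake = discriminator(torch.cat([fake_flat.detach(), y_onehot], dim=1))
    d_loss = bce(d_real, torch.ones_like(d_real)) + bce(d_fake, torch.zeros_like(d_fake))
    d_loss.backward()
    opt_d.step()

    # Generator (non-saturating loss): make D call fakes real for the same labels.
    opt_g.zero_grad()
    d_fake = discriminator(torch.cat([fake_flat, y_onehot], dim=1))
    g_loss = bce(d_fake, torch.ones_like(d_fake))
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()


@torch.no_grad()
def sample(labels: torch.Tensor) -> torch.Tensor:
    """Generate synthetic trials for integer labels, in classifier layout."""
    y = nn.functional.one_hot(labels, N_CLASSES).float()
    z = torch.randn(len(labels), LATENT)
    return generator(torch.cat([z, y], dim=1)).view(-1, TIME, 1, CHANNELS)


torch.manual_seed(0)
x_real = torch.rand(16, TIME, 1, CHANNELS) * 2 - 1  # stand-in data in [-1, 1]
y_real = nn.functional.one_hot(torch.randint(0, N_CLASSES, (16,)), N_CLASSES).float()
for _ in range(3):
    d_loss, g_loss = gan_step(x_real, y_real)
fake = sample(torch.tensor([0, 1, 2, 2]))
print(fake.shape)  # torch.Size([4, 32, 1, 4])
```

Because sample returns the same (n, time, 1, channels) layout as the real data, its output can be appended directly to the classifier's training arrays.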

Attributes:

training (TrainingConfig): Training and GAN augmentation configuration.

model_config (ModelConfig | None): Optional architecture configuration.

device (str | None): Optional device string; defaults to CUDA when available, otherwise CPU.

Methods:

augment_training_data: Create GAN-augmented training arrays.

augment_training_data

augment_training_data(
    x_train: ndarray, y_train: ndarray
) -> GANAugmentationResult

Create GAN-augmented training arrays.

Parameters:

x_train (ndarray, required): Real training inputs shaped (n, time, 1, channels).

y_train (ndarray, required): One-hot real training labels.

Returns:

GANAugmentationResult: Augmented training arrays plus synthetic sample metadata.
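The fields of GANAugmentationResult are not listed on this page. A minimal NumPy sketch of the "append generated samples" step, assuming balanced synthetic labels and a boolean mask as the synthetic-sample metadata, could look like this; append_synthetic and the mask are hypothetical, and random arrays stand in for generator output.

```python
import numpy as np


def append_synthetic(x_train, y_train, x_synth, y_synth):
    """Stack real and synthetic trials; returns (x_aug, y_aug, is_synthetic)."""
    if x_synth.shape[1:] != x_train.shape[1:]:
        raise ValueError(f"layout mismatch: {x_synth.shape[1:]} vs {x_train.shape[1:]}")
    x_aug = np.concatenate([x_train, x_synth], axis=0)
    y_aug = np.concatenate([y_train, y_synth], axis=0)
    # Hypothetical metadata: True marks rows that came from the generator.
    is_synthetic = np.concatenate(
        [np.zeros(len(x_train), dtype=bool), np.ones(len(x_synth), dtype=bool)]
    )
    return x_aug, y_aug, is_synthetic


n_classes, per_class = 4, 5
rng = np.random.default_rng(0)
x_train = rng.standard_normal((40, 250, 1, 22)).astype(np.float32)
y_train = np.eye(n_classes)[rng.integers(0, n_classes, size=40)]  # one-hot rows
# Balanced synthetic labels: per_class samples of each class, as one-hot rows.
synth_ids = np.repeat(np.arange(n_classes), per_class)
y_synth = np.eye(n_classes)[synth_ids]
# Stand-in for generator output; a trained GAN would produce these trials.
x_synth = rng.standard_normal((len(synth_ids), 250, 1, 22)).astype(np.float32)
x_aug, y_aug, is_synthetic = append_synthetic(x_train, y_train, x_synth, y_synth)
print(x_aug.shape, y_aug.shape, int(is_synthetic.sum()))  # (60, 250, 1, 22) (60, 4) 20
```

Keeping a mask like is_synthetic makes it easy to report metrics on real samples only or to drop the synthetic rows later.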

PyTorchTrainer (dataclass)

PyTorchTrainer(
    training: TrainingConfig, device: str | None = None
)

Minimal PyTorch trainer for notebook reproduction.

Attributes:

training (TrainingConfig): Training configuration controlling epochs and batch size.

device (str | None): Optional device string; defaults to CUDA when available, otherwise CPU.

Methods:

fit: Train a PyTorch classifier.

evaluate: Evaluate a PyTorch classifier.

fit

fit(
    model: Module,
    x_train: Any,
    y_train: Any,
    x_valid: Any | None = None,
    y_valid: Any | None = None,
) -> list[dict[str, float]]

Train a PyTorch classifier.

Parameters:

model (Module, required): PyTorch classifier.

x_train (Any, required): Training inputs.

y_train (Any, required): Training labels as class IDs or one-hot rows.

x_valid (Any | None, default: None): Reserved for API parity; currently unused.

y_valid (Any | None, default: None): Reserved for API parity; currently unused.

Returns:

list[dict[str, float]]: Per-epoch metrics with loss and accuracy.
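A minimal training loop that follows the documented contract (labels as class IDs or one-hot rows, one metrics dict per epoch with loss and accuracy, CUDA when available and otherwise CPU) might look like the sketch below. The Adam optimizer, learning rate, cross-entropy loss, and the names fit_sketch and to_class_ids are assumptions; the sketch also omits x_valid and y_valid, which the documented fit currently ignores.

```python
from __future__ import annotations

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def to_class_ids(y) -> torch.Tensor:
    """Accept class IDs shaped (n,) or one-hot rows shaped (n, k); return int64 IDs."""
    y = torch.as_tensor(np.asarray(y))
    if y.ndim == 2:
        y = y.argmax(dim=1)
    return y.long()


def fit_sketch(model: nn.Module, x_train, y_train, epochs: int = 3,
               batch_size: int = 32, lr: float = 1e-3, device: str | None = None):
    # Mirror the documented default: CUDA when available, otherwise CPU.
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    model.to(device)
    x = torch.as_tensor(np.asarray(x_train), dtype=torch.float32)
    y = to_class_ids(y_train)
    loader = DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    history: list[dict[str, float]] = []
    for _ in range(epochs):
        model.train()
        total_loss, correct, seen = 0.0, 0, 0
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            logits = model(xb)
            loss = loss_fn(logits, yb)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch mean is exact for a ragged last batch.
            total_loss += loss.item() * len(yb)
            correct += (logits.argmax(dim=1) == yb).sum().item()
            seen += len(yb)
        history.append({"loss": total_loss / seen, "accuracy": correct / seen})
    return history


rng = np.random.default_rng(0)
x_demo = rng.standard_normal((64, 50, 1, 4)).astype(np.float32)  # (n, time, 1, channels)
y_demo = np.eye(3)[rng.integers(0, 3, size=64)]                  # one-hot rows
toy = nn.Sequential(nn.Flatten(), nn.Linear(50 * 4, 3))
history = fit_sketch(toy, x_demo, y_demo, epochs=2)
print(history)
```

The per-epoch accuracy here is accumulated while the weights are still changing, so it lags slightly behind a clean pass over the data; that is a common shortcut, not necessarily what fit does.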

evaluate

evaluate(model: Module, x: Any, y: Any) -> dict[str, float]

Evaluate a PyTorch classifier.

Parameters:

model (Module, required): Trained PyTorch classifier.

x (Any, required): Evaluation inputs.

y (Any, required): Evaluation labels as class IDs or one-hot rows.

Returns:

dict[str, float]: Dictionary with loss and accuracy.
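An evaluation counterpart can run the model once in eval mode without gradient tracking and report mean cross-entropy loss and accuracy. The sketch below assumes cross-entropy and full-batch evaluation, and evaluate_sketch is a hypothetical name. The usage builds a model that always predicts class 0, so the accuracy is easy to check by hand.

```python
from __future__ import annotations

import numpy as np
import torch
from torch import nn


def to_class_ids(y) -> torch.Tensor:
    """Accept class IDs shaped (n,) or one-hot rows shaped (n, k); return int64 IDs."""
    y = torch.as_tensor(np.asarray(y))
    return (y.argmax(dim=1) if y.ndim == 2 else y).long()


@torch.no_grad()  # no autograd graph is needed for evaluation
def evaluate_sketch(model: nn.Module, x, y, device: str | None = None) -> dict[str, float]:
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    model.to(device).eval()  # eval() switches dropout/batch norm to inference behavior
    xb = torch.as_tensor(np.asarray(x), dtype=torch.float32, device=device)
    yb = to_class_ids(y).to(device)
    logits = model(xb)
    loss = nn.functional.cross_entropy(logits, yb).item()  # mean over samples
    accuracy = (logits.argmax(dim=1) == yb).float().mean().item()
    return {"loss": loss, "accuracy": accuracy}


# A model that always predicts class 0: zero weights, bias favoring class 0.
toy = nn.Sequential(nn.Flatten(), nn.Linear(8, 3))
with torch.no_grad():
    toy[1].weight.zero_()
    toy[1].bias.copy_(torch.tensor([1.0, 0.0, 0.0]))

x_demo = np.zeros((4, 2, 1, 4), dtype=np.float32)  # (n, time, 1, channels)
y_demo = np.array([0, 0, 1, 2])                     # class IDs
metrics = evaluate_sketch(toy, x_demo, y_demo, device="cpu")
print(metrics)  # accuracy 0.5: two of the four labels are class 0
```

Pairing torch.no_grad() with model.eval() is the usual combination: the first saves memory, the second makes dropout and batch normalization deterministic.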

build_classifier

build_classifier(
    name: str, config: ModelConfig | None = None
) -> Module

Build a named PyTorch classifier.

Parameters:

name (str, required): Classifier name; one of cnn, lstm, or cnn_lstm.

config (ModelConfig | None, default: None): Optional architecture configuration.

Returns:

Module: PyTorch module for the selected classifier.

Raises:

ValueError: If the classifier name is unknown.

ImportError: If PyTorch is not installed.
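The documented behavior amounts to a name-to-builder lookup with two failure modes. The sketch below shows only that dispatch and error contract: the builders are hypothetical placeholders rather than the library's models, torch is imported lazily so the ImportError surfaces at build time, and checking the name before the import is an assumption about ordering.

```python
import importlib
from typing import Any, Callable, Dict, Optional


class PlaceholderClassifier:
    """Hypothetical stand-in for PyTorchCNN / PyTorchLSTM / PyTorchCNNLSTM."""

    def __init__(self, kind: str, config: Optional[Any] = None) -> None:
        self.kind = kind
        self.config = config


# Hypothetical registry: maps each documented name to a builder callable.
BUILDERS: Dict[str, Callable[[Optional[Any]], Any]] = {
    "cnn": lambda cfg: PlaceholderClassifier("cnn", cfg),
    "lstm": lambda cfg: PlaceholderClassifier("lstm", cfg),
    "cnn_lstm": lambda cfg: PlaceholderClassifier("cnn_lstm", cfg),
}


def build_classifier_sketch(name: str, config: Optional[Any] = None) -> Any:
    """Look up a builder by name, mirroring the documented ValueError/ImportError contract."""
    builder = BUILDERS.get(name)
    if builder is None:
        raise ValueError(f"Unknown classifier {name!r}; expected one of {sorted(BUILDERS)}")
    try:
        importlib.import_module("torch")  # deferred so importing this module never needs torch
    except ImportError as exc:
        raise ImportError("PyTorch is required to build classifiers; install 'torch'.") from exc
    return builder(config)


model = build_classifier_sketch("cnn_lstm")
print(model.kind)  # cnn_lstm
try:
    build_classifier_sketch("transformer")
except ValueError as err:
    print(err)  # Unknown classifier 'transformer'; expected one of ['cnn', 'cnn_lstm', 'lstm']
```

Keeping the torch import inside the function lets a package expose build_classifier, and raise a clear ImportError, even where PyTorch is an optional dependency.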