PyTorch Models
PyTorch classifiers, conditional GAN augmentation, and a minimal trainer for EEG classification.
Classes:

| Name | Description |
|---|---|
| `PyTorchCNN` | CNN classifier implemented with PyTorch. |
| `PyTorchLSTM` | Stacked bidirectional LSTM classifier implemented with PyTorch. |
| `PyTorchCNNLSTM` | Hybrid CNN-LSTM classifier implemented with PyTorch. |
| `PyTorchGenerator` | Conditional GAN generator implemented with PyTorch. |
| `PyTorchDiscriminator` | Conditional GAN discriminator implemented with PyTorch. |
| `PyTorchGANAugmenter` | Train a PyTorch conditional GAN and append generated samples. |
| `PyTorchTrainer` | Minimal PyTorch trainer for notebook reproduction. |
Functions:

| Name | Description |
|---|---|
| `build_classifier` | Build a named PyTorch classifier. |
PyTorchCNN
PyTorchCNN(config: ModelConfig | None = None)
Bases: Module
CNN classifier implemented with PyTorch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `ModelConfig \| None` | Optional architecture configuration. | `None` |
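The time axis of an EEG window shrinks as it passes through convolution and pooling layers, which fixes the input width of any fully connected head. As a sketch, the helper below applies the output-length formula from the `torch.nn.Conv1d` and `torch.nn.MaxPool1d` documentation; the layer stack and the 256-sample window are hypothetical values, not settings read from `ModelConfig`, and nothing here describes the actual layers of `PyTorchCNN`.

```python
def conv1d_out_len(length: int, kernel: int, stride: int = 1,
                   padding: int = 0, dilation: int = 1) -> int:
    """Output length of Conv1d or MaxPool1d (ceil_mode=False), per the PyTorch docs formula."""
    return (length + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Hypothetical layer stack as (kernel, stride, padding); not read from ModelConfig.
layers = [
    (5, 1, 2),  # Conv1d with "same" padding keeps the length
    (2, 2, 0),  # MaxPool1d halves the length
    (3, 1, 1),  # Conv1d with "same" padding
    (2, 2, 0),  # MaxPool1d halves it again
]

length = 256  # hypothetical number of time samples per EEG window
for kernel, stride, padding in layers:
    length = conv1d_out_len(length, kernel, stride, padding)

print(length)  # 64
```

Multiplying the final length by the last convolution's output channel count gives the flattened feature size for a linear head.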
PyTorchLSTM
PyTorchLSTM(config: ModelConfig | None = None)
Bases: Module
Stacked bidirectional LSTM classifier implemented with PyTorch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `ModelConfig \| None` | Optional architecture configuration. | `None` |
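In `torch.nn.LSTM`, setting `bidirectional=True` doubles both the parameter count and the width of each output step, and every stacked layer after the first receives `2 * hidden_size` input features. The hypothetical helpers below reproduce that arithmetic for sizing a stacked bidirectional model, assuming the defaults `bias=True` and `proj_size=0`; the sizes (22 input channels, 64 hidden units, 2 layers) are illustrative, and this reference does not state which values `PyTorchLSTM` uses.

```python
def lstm_param_count(input_size: int, hidden_size: int,
                     num_layers: int = 1, bidirectional: bool = True) -> int:
    """Trainable parameters of torch.nn.LSTM, assuming bias=True and proj_size=0."""
    directions = 2 if bidirectional else 1
    total = 0
    layer_input = input_size
    for _ in range(num_layers):
        # Four gates: W_ih is (4h x in), W_hh is (4h x h), plus biases b_ih and b_hh (4h each).
        per_direction = 4 * hidden_size * (layer_input + hidden_size) + 8 * hidden_size
        total += directions * per_direction
        # Deeper layers receive the concatenated outputs of all directions.
        layer_input = directions * hidden_size
    return total


def lstm_output_features(hidden_size: int, bidirectional: bool = True) -> int:
    """Width of each output time step (forward and backward states concatenated)."""
    return (2 if bidirectional else 1) * hidden_size


# Hypothetical sizes: 22 EEG channels as input features, 64 hidden units, 2 layers.
print(lstm_param_count(22, 64, num_layers=2))  # 144384
print(lstm_output_features(64))  # 128
```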
PyTorchCNNLSTM
PyTorchCNNLSTM(config: ModelConfig | None = None)
Bases: Module
Hybrid CNN-LSTM classifier implemented with PyTorch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `ModelConfig \| None` | Optional architecture configuration. | `None` |
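A common way to build a hybrid CNN-LSTM, assumed here rather than confirmed for `PyTorchCNNLSTM`, is to extract local features with `Conv1d` layers and then run an LSTM over the pooled feature sequence. The key step is the layout change: `Conv1d` works on `(batch, channels, time)`, while an LSTM created with `batch_first=True` expects `(batch, time, features)`. The sketch below only traces shapes; the function name and all sizes are hypothetical.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]


def trace_cnn_lstm(batch: int, eeg_channels: int, time_steps: int,
                   conv_channels: int, pool: int, hidden: int,
                   bidirectional: bool = True) -> List[Tuple[str, Shape]]:
    """Shape flow through a hypothetical CNN-LSTM; no tensors are created."""
    steps: List[Tuple[str, Shape]] = [("input", (batch, eeg_channels, time_steps))]  # (N, C, L)
    steps.append(("conv", (batch, conv_channels, time_steps)))  # Conv1d with "same" padding
    pooled = time_steps // pool  # MaxPool1d(kernel_size=pool) shortens the time axis
    steps.append(("pool", (batch, conv_channels, pooled)))
    steps.append(("permute", (batch, pooled, conv_channels)))  # x.permute(0, 2, 1) -> (N, L, C)
    out_features = (2 if bidirectional else 1) * hidden
    steps.append(("lstm", (batch, pooled, out_features)))  # LSTM(batch_first=True) output
    steps.append(("head", (batch, out_features)))  # e.g. keep only the last time step
    return steps


for stage, shape in trace_cnn_lstm(batch=32, eeg_channels=22, time_steps=256,
                                   conv_channels=64, pool=2, hidden=64):
    print(stage, shape)
```

Keeping only the last LSTM step before the classifier head is one design choice among several (mean pooling over time is another); the source does not say which `PyTorchCNNLSTM` uses.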
PyTorchGenerator
PyTorchGenerator(config: ModelConfig | None = None)
Bases: Module
Conditional GAN generator implemented with PyTorch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `ModelConfig \| None` | Optional architecture configuration. | `None` |
PyTorchDiscriminator
PyTorchDiscriminator(config: ModelConfig | None = None)
Bases: Module
Conditional GAN discriminator implemented with PyTorch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `ModelConfig \| None` | Optional architecture configuration. | `None` |
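Both GAN classes are described as conditional, meaning the class label is supplied to the generator and to the discriminator. One common conditioning scheme, assumed here (the actual classes may use label embeddings or another method), is to append a one-hot label to the generator's noise vector and to the discriminator's flattened input. The helper names, the 100-dimensional latent space, and the four classes are hypothetical.

```python
import random
from typing import List


def one_hot(label: int, num_classes: int) -> List[float]:
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def generator_input(label: int, num_classes: int, latent_dim: int,
                    rng: random.Random) -> List[float]:
    """Noise vector z ~ N(0, 1) with the one-hot class label appended."""
    z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
    return z + one_hot(label, num_classes)


def discriminator_input(sample: List[float], label: int, num_classes: int) -> List[float]:
    """Flattened real or generated sample with the same one-hot label appended."""
    return list(sample) + one_hot(label, num_classes)


rng = random.Random(0)
g_in = generator_input(label=2, num_classes=4, latent_dim=100, rng=rng)
print(len(g_in))  # 104
d_in = discriminator_input([0.1] * 10, label=2, num_classes=4)
print(d_in[-4:])  # [0.0, 0.0, 1.0, 0.0]
```

Giving the discriminator the same label forces it to judge whether a sample is realistic for that class, which is what lets the generator produce class-specific samples on demand.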
PyTorchGANAugmenter (dataclass)
PyTorchGANAugmenter(
training: TrainingConfig,
model_config: ModelConfig | None = None,
device: str | None = None,
)
Train a PyTorch conditional GAN and append its generated samples to the real training data.
The returned arrays use the same classifier layout as `prepare_splits`, so the augmentation can be applied before CNN, LSTM, or CNN-LSTM training.
Attributes:

| Name | Type | Description |
|---|---|---|
| `training` | `TrainingConfig` | Training and GAN augmentation configuration. |
| `model_config` | `ModelConfig \| None` | Optional architecture configuration. |
| `device` | `str \| None` | Optional device string; defaults to CUDA when available, otherwise CPU. |
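`PyTorchGANAugmenter` and `PyTorchTrainer` document the same device fallback. The rule can be sketched with a hypothetical `resolve_device` helper, which is not part of this module; unlike the module itself, the sketch falls back to CPU when PyTorch cannot be imported so that it runs anywhere.

```python
from typing import Optional


def resolve_device(device: Optional[str] = None) -> str:
    """Return the explicit device, else "cuda" when available, else "cpu"."""
    if device is not None:
        return device  # an explicit choice always wins
    try:
        import torch
    except ImportError:
        return "cpu"  # sketch-only fallback so the example runs without PyTorch
    return "cuda" if torch.cuda.is_available() else "cpu"


print(resolve_device("cpu"))  # cpu
print(resolve_device())
```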
Methods:

| Name | Description |
|---|---|
| `augment_training_data` | Create GAN-augmented training arrays. |
augment_training_data
augment_training_data(
x_train: ndarray, y_train: ndarray
) -> GANAugmentationResult
Create GAN-augmented training arrays.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x_train` | `ndarray` | Real training inputs shaped | required |
| `y_train` | `ndarray` | One-hot real training labels. | required |
Returns:

| Type | Description |
|---|---|
| `GANAugmentationResult` | Augmented training arrays plus synthetic sample metadata. |
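Appending generated samples amounts to concatenating the real and synthetic rows while keeping inputs and labels aligned. The sketch below uses plain lists and a hypothetical `AugmentedSplit` dataclass in place of `GANAugmentationResult`, whose fields are not documented here; its `is_synthetic` mask is an assumed stand-in for the synthetic sample metadata.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class AugmentedSplit:
    """Hypothetical stand-in for GANAugmentationResult; real field names may differ."""
    x_train: List[Sequence[float]]
    y_train: List[Sequence[float]]
    is_synthetic: List[bool]


def append_synthetic(x_real, y_real, x_fake, y_fake) -> AugmentedSplit:
    """Real rows first, generated rows after, with a mask marking the generated ones."""
    if len(x_real) != len(y_real) or len(x_fake) != len(y_fake):
        raise ValueError("inputs and labels must have the same number of rows")
    return AugmentedSplit(
        x_train=list(x_real) + list(x_fake),
        y_train=list(y_real) + list(y_fake),
        is_synthetic=[False] * len(x_real) + [True] * len(x_fake),
    )


result = append_synthetic([[0.0], [1.0]], [[1, 0], [0, 1]], [[0.5]], [[0, 1]])
print(result.is_synthetic)  # [False, False, True]
```

Keeping such a mask makes it possible to report metrics on real samples only or to cap the synthetic-to-real ratio per class.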
PyTorchTrainer (dataclass)
PyTorchTrainer(
training: TrainingConfig, device: str | None = None
)
Minimal PyTorch trainer for notebook reproduction.
Attributes:

| Name | Type | Description |
|---|---|---|
| `training` | `TrainingConfig` | Training configuration controlling epochs and batch size. |
| `device` | `str \| None` | Optional device string; defaults to CUDA when available, otherwise CPU. |
Methods:

| Name | Description |
|---|---|
| `fit` | Train a PyTorch classifier. |
| `evaluate` | Evaluate a PyTorch classifier. |
fit
fit(
model: Module,
x_train: Any,
y_train: Any,
x_valid: Any | None = None,
y_valid: Any | None = None,
) -> list[dict[str, float]]
Train a PyTorch classifier.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | PyTorch classifier. | required |
| `x_train` | `Any` | Training inputs. | required |
| `y_train` | `Any` | Training labels as class IDs or one-hot rows. | required |
| `x_valid` | `Any \| None` | Reserved for API parity; currently unused. | `None` |
| `y_valid` | `Any \| None` | Reserved for API parity; currently unused. | `None` |
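`fit` accepts labels either as class IDs or as one-hot rows. A straightforward way to support both, assumed here rather than taken from the implementation, is to collapse one-hot rows to class indices before computing a loss such as `torch.nn.CrossEntropyLoss`. The hypothetical `to_class_ids` helper below works on plain Python sequences; with NumPy arrays the same step would normally be `argmax(axis=1)`.

```python
from typing import List, Sequence, Union

Label = Union[int, Sequence[float]]


def to_class_ids(labels: Sequence[Label]) -> List[int]:
    """Convert class IDs or one-hot rows to a flat list of integer class IDs."""
    ids: List[int] = []
    for row in labels:
        if isinstance(row, int):
            ids.append(row)  # already a class ID
        else:
            values = list(row)
            # argmax of the row; for a valid one-hot row this is the index of the 1.
            ids.append(max(range(len(values)), key=values.__getitem__))
    return ids


print(to_class_ids([0, 2, 1]))  # [0, 2, 1]
print(to_class_ids([[1, 0, 0], [0, 0, 1], [0, 1, 0]]))  # [0, 2, 1]
```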
Returns:

| Type | Description |
|---|---|
| `list[dict[str, float]]` | Per-epoch metrics with loss and accuracy. |
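The return value holds one dictionary per epoch. Because `x_valid` and `y_valid` are currently unused, these are presumably training-set metrics. The sketch below picks the best epoch from such a history; the key names `"loss"` and `"accuracy"` are assumptions, and the numbers are illustrative, not real results.

```python
from typing import Dict, List


def best_epoch(history: List[Dict[str, float]], key: str = "accuracy",
               higher_is_better: bool = True) -> int:
    """Return the 1-based epoch with the best value for `key` (first one on ties)."""
    if not history:
        raise ValueError("history is empty")
    values = [epoch[key] for epoch in history]
    target = max(values) if higher_is_better else min(values)
    return values.index(target) + 1


# Illustrative numbers only, not real training results.
history = [
    {"loss": 1.20, "accuracy": 0.41},
    {"loss": 0.95, "accuracy": 0.55},
    {"loss": 0.90, "accuracy": 0.53},
]
print(best_epoch(history))  # 2
print(best_epoch(history, key="loss", higher_is_better=False))  # 3
```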
evaluate
build_classifier
build_classifier(
name: str, config: ModelConfig | None = None
) -> Module
Build a named PyTorch classifier.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | Classifier name: | required |
| `config` | `ModelConfig \| None` | Optional architecture configuration. | `None` |
Returns:

| Type | Description |
|---|---|
| `Module` | PyTorch module for the selected classifier. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If the classifier name is unknown. |
| `ImportError` | If PyTorch is not installed. |
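`build_classifier` maps a name to one of the classifier classes and raises `ValueError` for unknown names. A registry-based sketch of that dispatch follows; the keys `"cnn"`, `"lstm"`, and `"cnn_lstm"` are hypothetical because the accepted names are not listed in this reference, and the builders return class names as strings so the example runs without PyTorch (the real function returns a `Module` and raises `ImportError` when PyTorch is missing).

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical registry: name -> builder. Real builders would instantiate nn.Module subclasses.
_BUILDERS: Dict[str, Callable[[Optional[Any]], str]] = {
    "cnn": lambda config: "PyTorchCNN",
    "lstm": lambda config: "PyTorchLSTM",
    "cnn_lstm": lambda config: "PyTorchCNNLSTM",
}


def build_classifier_sketch(name: str, config: Optional[Any] = None) -> str:
    """Look up `name` in the registry and call its builder with `config`."""
    builder = _BUILDERS.get(name)
    if builder is None:
        known = ", ".join(sorted(_BUILDERS))
        raise ValueError(f"Unknown classifier {name!r}; expected one of: {known}")
    return builder(config)


print(build_classifier_sketch("lstm"))  # PyTorchLSTM
try:
    build_classifier_sketch("transformer")
except ValueError as err:
    print(err)  # Unknown classifier 'transformer'; expected one of: cnn, cnn_lstm, lstm
```

A dictionary registry keeps the list of valid names and the error message in one place, so adding a classifier means adding one entry.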