# TensorFlow Models
TensorFlow/Keras models for EEG classification.

Classes:

| Name | Description |
|---|---|
| `TensorFlowClassifierFactory` | Build Keras classifiers for the package model families. |
| `TensorFlowConditionalGAN` | Conditional GAN model for synthetic EEG training samples. |
| `TensorFlowGANAugmenter` | Train a TensorFlow conditional GAN and append generated samples. |
| `TensorFlowTrainer` | Thin TensorFlow trainer wrapper used by notebooks and examples. |

Functions:

| Name | Description |
|---|---|
| `build_generator` | Build the conditional GAN generator. |
| `build_discriminator` | Build the conditional GAN discriminator. |

## TensorFlowClassifierFactory

`TensorFlowClassifierFactory(model_config: ModelConfig | None = None)`

Build Keras classifiers for the package model families.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_config` | `ModelConfig \| None` | Optional architecture configuration. | `None` |

Methods:

| Name | Description |
|---|---|
| `build_cnn` | Build and compile the CNN classifier. |
| `build_lstm` | Build and compile the naive LSTM classifier. |
| `build_cnn_lstm` | Build and compile the hybrid CNN-LSTM classifier. |
| `build` | Build a named TensorFlow classifier. |

### build_cnn

`build_cnn(learning_rate: float = 0.001) -> Model`

Build and compile the CNN classifier.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `learning_rate` | `float` | Adam optimizer learning rate. | `0.001` |

Returns:

| Type | Description |
|---|---|
| `Model` | Compiled Keras model. |

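
All three builders pass `learning_rate` to the Adam optimizer. To show what that value controls, here is a minimal pure-Python sketch of one textbook Adam update (Kingma and Ba) on a single scalar parameter. `adam_step` is a hypothetical helper, not part of this package, and the Keras implementation works on tensors and differs in small numerical details; the defaults below mirror Keras' documented Adam defaults (`beta_1=0.9`, `beta_2=0.999`, `epsilon=1e-7`).

```python
import math


def adam_step(param, grad, m, v, t, learning_rate=0.001,
              beta1=0.9, beta2=0.999, epsilon=1e-7):
    """Apply one textbook Adam update to a scalar parameter.

    Returns the updated parameter and the new first/second moment estimates.
    `t` is the 1-based step count.
    """
    # Exponential moving averages of the gradient and the squared gradient.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction compensates for initializing m and v at zero.
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # learning_rate scales the size of every step.
    param = param - learning_rate * m_hat / (math.sqrt(v_hat) + epsilon)
    return param, m, v


p, m, v = adam_step(1.0, 2.0, 0.0, 0.0, t=1)
print(p)
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so the parameter moves by roughly `learning_rate` regardless of the gradient's magnitude. This is why the learning rate is the main knob exposed by the builders.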
### build_lstm

`build_lstm(learning_rate: float = 0.001) -> Model`

Build and compile the naive LSTM classifier.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `learning_rate` | `float` | Adam optimizer learning rate. | `0.001` |

Returns:

| Type | Description |
|---|---|
| `Model` | Compiled Keras model. |

### build_cnn_lstm

`build_cnn_lstm(learning_rate: float = 0.001) -> Model`

Build and compile the hybrid CNN-LSTM classifier.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `learning_rate` | `float` | Adam optimizer learning rate. | `0.001` |

Returns:

| Type | Description |
|---|---|
| `Model` | Compiled Keras model. |

### build

Build a named TensorFlow classifier.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | Classifier name: | required |
| `learning_rate` | `float` | Adam optimizer learning rate. | `0.001` |

Returns:

| Type | Description |
|---|---|
| `Model` | Compiled Keras model. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If the classifier name is unknown. |

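
`build` dispatches on a name to one of the three builders and rejects unknown names with `ValueError`. The sketch below shows that pattern with a hypothetical `ToyClassifierFactory` whose builders return strings instead of Keras models. The accepted keys `"cnn"`, `"lstm"`, and `"cnn_lstm"` are assumptions, because this page does not list the valid names.

```python
from typing import Callable, Dict


class ToyClassifierFactory:
    """Hypothetical stand-in that mimics a name-based build() dispatch."""

    def build_cnn(self, learning_rate: float = 0.001) -> str:
        return f"cnn(lr={learning_rate})"

    def build_lstm(self, learning_rate: float = 0.001) -> str:
        return f"lstm(lr={learning_rate})"

    def build_cnn_lstm(self, learning_rate: float = 0.001) -> str:
        return f"cnn_lstm(lr={learning_rate})"

    def build(self, name: str, learning_rate: float = 0.001) -> str:
        # Map accepted names to builder methods (these keys are assumed).
        builders: Dict[str, Callable[[float], str]] = {
            "cnn": self.build_cnn,
            "lstm": self.build_lstm,
            "cnn_lstm": self.build_cnn_lstm,
        }
        try:
            builder = builders[name]
        except KeyError:
            # Unknown names raise ValueError, matching the documented contract.
            raise ValueError(
                f"Unknown classifier {name!r}; expected one of {sorted(builders)}"
            ) from None
        return builder(learning_rate)


factory = ToyClassifierFactory()
print(factory.build("cnn_lstm", learning_rate=0.01))
```

A dictionary lookup keeps the list of valid names in one place, so the error message can report them without a separate hard-coded list.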
## TensorFlowConditionalGAN

`TensorFlowConditionalGAN(discriminator: Model | None = None, generator: Model | None = None, config: ModelConfig | None = None)`

Bases: `_TensorFlowGANBase`

Conditional GAN model for synthetic EEG training samples.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `discriminator` | `Model \| None` | Optional discriminator model. | `None` |
| `generator` | `Model \| None` | Optional generator model. | `None` |
| `config` | `ModelConfig \| None` | Optional architecture configuration. | `None` |

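
This page does not document the training step of `TensorFlowConditionalGAN`. As a reference point, here is the standard conditional-GAN objective written as binary cross-entropy over discriminator probabilities, together with the commonly used non-saturating generator loss. Whether this class uses exactly these losses is an assumption, and the helper names are hypothetical.

```python
import math
from typing import Sequence


def bce(p: float, target: float, eps: float = 1e-7) -> float:
    """Binary cross-entropy for one predicted probability p."""
    p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def discriminator_loss(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    # D should output 1 for real (x, y) pairs and 0 for generated (G(z, y), y) pairs.
    real = sum(bce(p, 1.0) for p in d_real) / len(d_real)
    fake = sum(bce(p, 0.0) for p in d_fake) / len(d_fake)
    return real + fake


def generator_loss(d_fake: Sequence[float]) -> float:
    # Non-saturating loss: the generator pushes D(G(z, y), y) toward 1.
    return sum(bce(p, 1.0) for p in d_fake) / len(d_fake)


print(discriminator_loss([0.5], [0.5]), generator_loss([0.5]))
```

When the discriminator cannot tell real from generated pairs (every probability is 0.5), its loss is 2 ln 2 and the generator's loss is ln 2.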
## TensorFlowGANAugmenter (dataclass)

`TensorFlowGANAugmenter(training: TrainingConfig, model_config: ModelConfig | None = None)`

Train a TensorFlow conditional GAN and append generated samples.
The augmenter is classifier-independent: it receives prepared training arrays and returns augmented arrays that can be passed to CNN, LSTM, or CNN-LSTM classifiers.

Attributes:

| Name | Type | Description |
|---|---|---|
| `training` | `TrainingConfig` | Training and GAN augmentation configuration. |
| `model_config` | `ModelConfig \| None` | Optional architecture configuration. |

Methods:

| Name | Description |
|---|---|
| `augment_training_data` | Create GAN-augmented training arrays. |

### augment_training_data

`augment_training_data(x_train: ndarray, y_train: ndarray) -> GANAugmentationResult`

Create GAN-augmented training arrays.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x_train` | `ndarray` | Real training inputs shaped | required |
| `y_train` | `ndarray` | One-hot real training labels. | required |

Returns:

| Type | Description |
|---|---|
| `GANAugmentationResult` | Augmented training arrays plus synthetic sample metadata. |

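
Because the augmenter only sees prepared arrays, its final step amounts to appending generated samples and their one-hot labels to the real ones. The sketch below shows that step on plain Python lists. `append_synthetic` and `AugmentedArrays` are hypothetical names; the fields of the real `GANAugmentationResult` are not documented here, and the joint shuffle is a design choice of this sketch rather than documented behavior.

```python
import random
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class AugmentedArrays:
    """Hypothetical result container, not the package's GANAugmentationResult."""
    x: List[Sequence[float]]
    y: List[Sequence[float]]
    n_real: int
    n_synthetic: int


def append_synthetic(x_real, y_real, x_fake, y_fake,
                     seed: Optional[int] = None) -> AugmentedArrays:
    if len(x_real) != len(y_real) or len(x_fake) != len(y_fake):
        raise ValueError("inputs and labels must have matching lengths")
    # Real samples first, then the generated ones.
    x = list(x_real) + list(x_fake)
    y = list(y_real) + list(y_fake)
    # Shuffle sample/label pairs together so batches mix real and synthetic data.
    order = list(range(len(x)))
    random.Random(seed).shuffle(order)
    return AugmentedArrays(
        x=[x[i] for i in order],
        y=[y[i] for i in order],
        n_real=len(x_real),
        n_synthetic=len(x_fake),
    )


result = append_synthetic([[0.0], [1.0]], [[1, 0], [0, 1]], [[0.5]], [[1, 0]], seed=0)
print(result.n_real, result.n_synthetic, len(result.x))
```

Recording how many samples are real versus synthetic keeps that information available after shuffling, which is the kind of metadata a result object like this can carry.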
## TensorFlowTrainer (dataclass)

`TensorFlowTrainer(training: TrainingConfig)`

Thin TensorFlow trainer wrapper used by notebooks and examples.

Attributes:

| Name | Type | Description |
|---|---|---|
| `training` | `TrainingConfig` | Training configuration controlling batch size, epochs, and dev mode. |

Methods:

| Name | Description |
|---|---|
| `fit` | Train a compiled Keras model. |
| `evaluate` | Evaluate a trained Keras model. |

### fit

Train a compiled Keras model.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Model` | Compiled classifier. | required |
| `x_train` | `Any` | Training inputs. | required |
| `y_train` | `Any` | Training one-hot labels. | required |
| `x_valid` | `Any` | Validation inputs. | required |
| `y_valid` | `Any` | Validation one-hot labels. | required |

Returns:

| Type | Description |
|---|---|
| `Any` | Keras |

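
`fit` expects one-hot training and validation labels. If your labels start out as integer class indices, the conversion looks like the sketch below. In a TensorFlow pipeline, `tf.keras.utils.to_categorical` does the same job on arrays. The helpers `one_hot` and `argmax` are hypothetical and exist only to show the encoding.

```python
from typing import List, Sequence


def one_hot(labels: Sequence[int], num_classes: int) -> List[List[float]]:
    """Convert integer class indices into one-hot rows."""
    rows = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} outside [0, {num_classes})")
        row = [0.0] * num_classes
        row[label] = 1.0
        rows.append(row)
    return rows


def argmax(row: Sequence[float]) -> int:
    """Recover the class index from a one-hot (or softmax) row."""
    return max(range(len(row)), key=row.__getitem__)


print(one_hot([0, 2], 3))
```

`argmax` inverts the encoding. It also works on a classifier's softmax output, which is how predicted class indices are usually recovered for comparison with the labels.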
### evaluate

Evaluate a trained Keras model.

## build_generator

`build_generator(config: ModelConfig | None = None) -> Model`

Build the conditional GAN generator.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `ModelConfig \| None` | Optional architecture configuration. | `None` |

Returns:

| Type | Description |
|---|---|
| `Model` | Keras generator model. |

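
A conditional generator maps random noise plus a class label to a synthetic sample. A common way to condition it is to concatenate a latent noise vector `z` with the one-hot label. The sketch below builds such inputs in plain Python. `generator_inputs` and `latent_dim` are hypothetical, and whether `build_generator` concatenates labels or uses a label embedding is not documented on this page.

```python
import random
from typing import List, Optional, Sequence


def generator_inputs(labels: Sequence[Sequence[float]], latent_dim: int,
                     seed: Optional[int] = None) -> List[List[float]]:
    """Build conditional generator inputs: [z ; one-hot label] per sample."""
    rng = random.Random(seed)
    inputs = []
    for label in labels:
        # Standard-normal latent noise drives diversity between samples.
        z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
        # Appending the one-hot label tells the generator which class to produce.
        inputs.append(z + list(label))
    return inputs


batch = generator_inputs([[1.0, 0.0], [0.0, 1.0]], latent_dim=4, seed=1)
print([len(row) for row in batch])
```

Because the label occupies a fixed slice of every input row, you can request samples of a specific class by choosing the labels, which is what makes class-balanced augmentation possible.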
## build_discriminator

`build_discriminator(config: ModelConfig | None = None) -> Model`

Build the conditional GAN discriminator.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `ModelConfig \| None` | Optional architecture configuration. | `None` |

Returns:

| Type | Description |
|---|---|
| `Model` | Keras discriminator model. |
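
A conditional discriminator scores a (sample, label) pair rather than a sample alone, so a realistic sample paired with the wrong label can still be rejected. The toy below replaces the Keras network with a single logistic unit to show that input contract. `discriminator_score` and its weights are hypothetical, and how the real discriminator combines the label with EEG inputs is not documented here.

```python
import math
from typing import Sequence


def discriminator_score(sample: Sequence[float], label: Sequence[float],
                        weights: Sequence[float], bias: float = 0.0) -> float:
    """Toy conditional discriminator: one logistic unit over [sample ; label]."""
    features = list(sample) + list(label)  # condition on the class label
    if len(features) != len(weights):
        raise ValueError("weights must match len(sample) + len(label)")
    logit = sum(w * f for w, f in zip(weights, features)) + bias
    return 1.0 / (1.0 + math.exp(-logit))  # probability that the pair is real


print(discriminator_score([1.0, 1.0], [0.0, 1.0], [0.5, 0.5, 0.0, 2.0]))
```

In this toy, the same sample receives a different score depending on the label it is paired with, because the label weights contribute to the logit.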