TensorFlow Models

TensorFlow/Keras models for EEG classification.

Classes:

| Name | Description |
|---|---|
| TensorFlowClassifierFactory | Build Keras classifiers for the package model families. |
| TensorFlowConditionalGAN | Conditional GAN model for synthetic EEG training samples. |
| TensorFlowGANAugmenter | Train a TensorFlow conditional GAN and append generated samples. |
| TensorFlowTrainer | Thin TensorFlow trainer wrapper used by notebooks and examples. |

Functions:

| Name | Description |
|---|---|
| build_generator | Build the conditional GAN generator. |
| build_discriminator | Build the conditional GAN discriminator. |

TensorFlowClassifierFactory

```python
TensorFlowClassifierFactory(
    model_config: ModelConfig | None = None,
)
```

Build Keras classifiers for the package model families.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model_config | ModelConfig \| None | Optional architecture configuration. | None |

Methods:

| Name | Description |
|---|---|
| build_cnn | Build and compile the CNN classifier. |
| build_lstm | Build and compile the naive LSTM classifier. |
| build_cnn_lstm | Build and compile the hybrid CNN-LSTM classifier. |
| build | Build a named TensorFlow classifier. |

build_cnn

```python
build_cnn(learning_rate: float = 0.001) -> Model
```

Build and compile the CNN classifier.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| learning_rate | float | Adam optimizer learning rate. | 0.001 |

Returns:

| Type | Description |
|---|---|
| Model | Compiled Keras model. |

build_lstm

```python
build_lstm(learning_rate: float = 0.001) -> Model
```

Build and compile the naive LSTM classifier.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| learning_rate | float | Adam optimizer learning rate. | 0.001 |

Returns:

| Type | Description |
|---|---|
| Model | Compiled Keras model. |

build_cnn_lstm

```python
build_cnn_lstm(learning_rate: float = 0.001) -> Model
```

Build and compile the hybrid CNN-LSTM classifier.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| learning_rate | float | Adam optimizer learning rate. | 0.001 |

Returns:

| Type | Description |
|---|---|
| Model | Compiled Keras model. |

build

```python
build(name: str, learning_rate: float = 0.001) -> Model
```

Build a named TensorFlow classifier.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Classifier name: `cnn`, `lstm`, or `cnn_lstm`. | required |
| learning_rate | float | Adam optimizer learning rate. | 0.001 |

Returns:

| Type | Description |
|---|---|
| Model | Compiled Keras model. |

Raises:

| Type | Description |
|---|---|
| ValueError | If the classifier name is unknown. |
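The name-based dispatch in `build` can be pictured as a lookup from classifier name to builder method. The sketch below is an assumption about how such a factory might be organized, not the package's code. `SketchFactory` is a hypothetical class, and its stub builders return descriptive strings instead of compiled Keras models so the example runs without TensorFlow.

```python
from typing import Callable, Dict


class SketchFactory:
    """Hypothetical stand-in for the name dispatch in TensorFlowClassifierFactory."""

    # Stub builders: the documented methods return compiled Keras models;
    # these return strings so the sketch runs without TensorFlow.
    def build_cnn(self, learning_rate: float = 0.001) -> str:
        return f"cnn(lr={learning_rate})"

    def build_lstm(self, learning_rate: float = 0.001) -> str:
        return f"lstm(lr={learning_rate})"

    def build_cnn_lstm(self, learning_rate: float = 0.001) -> str:
        return f"cnn_lstm(lr={learning_rate})"

    def build(self, name: str, learning_rate: float = 0.001) -> str:
        # Map each supported classifier name to its builder method.
        builders: Dict[str, Callable[[float], str]] = {
            "cnn": self.build_cnn,
            "lstm": self.build_lstm,
            "cnn_lstm": self.build_cnn_lstm,
        }
        if name not in builders:
            # Unknown names fail loudly, mirroring the documented ValueError.
            raise ValueError(
                f"Unknown classifier {name!r}; expected one of {sorted(builders)}"
            )
        return builders[name](learning_rate)


factory = SketchFactory()
print(factory.build("cnn_lstm", learning_rate=0.0005))  # cnn_lstm(lr=0.0005)
```

Keeping the mapping in a single dictionary lets the error message list the valid names. Adding a model family then only needs one new entry.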

TensorFlowConditionalGAN

```python
TensorFlowConditionalGAN(
    discriminator: Model | None = None,
    generator: Model | None = None,
    config: ModelConfig | None = None,
)
```

Bases: _TensorFlowGANBase

Conditional GAN model for synthetic EEG training samples.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| discriminator | Model \| None | Optional discriminator model. | None |
| generator | Model \| None | Optional generator model. | None |
| config | ModelConfig \| None | Optional architecture configuration. | None |

TensorFlowGANAugmenter (dataclass)

```python
TensorFlowGANAugmenter(
    training: TrainingConfig,
    model_config: ModelConfig | None = None,
)
```

Train a TensorFlow conditional GAN and append generated samples.

The augmenter is classifier-independent: it receives prepared training arrays and returns augmented arrays that can be passed to CNN, LSTM, or CNN-LSTM classifiers.

Attributes:

| Name | Type | Description |
|---|---|---|
| training | TrainingConfig | Training and GAN augmentation configuration. |
| model_config | ModelConfig \| None | Optional architecture configuration. |

Methods:

| Name | Description |
|---|---|
| augment_training_data | Create GAN-augmented training arrays. |

augment_training_data

```python
augment_training_data(
    x_train: ndarray, y_train: ndarray
) -> GANAugmentationResult
```

Create GAN-augmented training arrays.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x_train | ndarray | Real training inputs shaped `(n, time, 1, channels)`. | required |
| y_train | ndarray | One-hot labels for the real training inputs. | required |

Returns:

| Type | Description |
|---|---|
| GANAugmentationResult | Augmented training arrays plus synthetic sample metadata. |
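The following is a minimal sketch of the augmentation contract. It assumes the augmenter asks the conditional generator for a fixed number of trials per class and appends them after the real ones; the actual sampling policy is not documented here. The real method works on NumPy arrays with a trained Keras generator. In the sketch, nested lists and a stub generator stand in for both. `AugmentationSketch`, `samples_per_class`, and `fake_generator` are hypothetical names, and the field names of `GANAugmentationResult` are not documented on this page.

```python
from dataclasses import dataclass
from typing import Callable, List

# One trial shaped (time, 1, channels), stored as nested lists.
Trial = List[List[List[float]]]


@dataclass
class AugmentationSketch:
    """Hypothetical stand-in for GANAugmentationResult; field names are assumptions."""

    x_train: List[Trial]
    y_train: List[List[float]]
    n_synthetic: int


def one_hot(index: int, n_classes: int) -> List[float]:
    return [1.0 if i == index else 0.0 for i in range(n_classes)]


def augment(
    x_real: List[Trial],
    y_real: List[List[float]],
    generate: Callable[[List[float]], Trial],
    samples_per_class: int,
) -> AugmentationSketch:
    """Append label-conditioned synthetic trials after the real ones."""
    n_classes = len(y_real[0])
    x_syn: List[Trial] = []
    y_syn: List[List[float]] = []
    for cls in range(n_classes):
        label = one_hot(cls, n_classes)
        for _ in range(samples_per_class):
            # The conditional generator is asked for a trial of this class.
            x_syn.append(generate(label))
            y_syn.append(label)
    # Real samples keep their positions; synthetic ones follow along the sample axis.
    return AugmentationSketch(x_real + x_syn, y_real + y_syn, len(x_syn))


# Stub generator (hypothetical): a (2, 1, 3) trial filled with the class index.
def fake_generator(label: List[float]) -> Trial:
    cls = label.index(1.0)
    return [[[float(cls)] * 3] for _ in range(2)]


x_real = [[[[0.5] * 3] for _ in range(2)] for _ in range(4)]
y_real = [one_hot(i % 2, 2) for i in range(4)]
result = augment(x_real, y_real, fake_generator, samples_per_class=3)
print(result.n_synthetic, len(result.x_train))  # 6 10
```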

TensorFlowTrainer (dataclass)

```python
TensorFlowTrainer(training: TrainingConfig)
```

Thin TensorFlow trainer wrapper used by notebooks and examples.

Attributes:

| Name | Type | Description |
|---|---|---|
| training | TrainingConfig | Training configuration controlling batch size, epochs, and dev mode. |

Methods:

| Name | Description |
|---|---|
| fit | Train a compiled Keras model. |
| evaluate | Evaluate a trained Keras model. |

fit

```python
fit(
    model: Model,
    x_train: Any,
    y_train: Any,
    x_valid: Any,
    y_valid: Any,
) -> Any
```

Train a compiled Keras model.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Model | Compiled classifier. | required |
| x_train | Any | Training inputs. | required |
| y_train | Any | Training one-hot labels. | required |
| x_valid | Any | Validation inputs. | required |
| y_valid | Any | Validation one-hot labels. | required |

Returns:

| Type | Description |
|---|---|
| Any | Keras `History` object. |
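`TrainingConfig` supplies the batch size and epoch count. This page does not say how `fit` uses them, and dev mode is not described. Assume they are forwarded to Keras `Model.fit` with in-memory arrays. In that case each epoch processes ceil(n / batch_size) mini-batches, and the last one may be smaller. The standard-library sketch below shows only that arithmetic. The helper names are hypothetical, and the per-epoch shuffling that Keras applies by default is omitted.

```python
import math
from typing import Iterator, List, Tuple


def steps_per_epoch(n_samples: int, batch_size: int) -> int:
    # The final batch may hold fewer than batch_size samples.
    return math.ceil(n_samples / batch_size)


def minibatches(n_samples: int, batch_size: int) -> Iterator[Tuple[int, int]]:
    """Yield (start, stop) index pairs that cover every sample once."""
    for start in range(0, n_samples, batch_size):
        yield start, min(start + batch_size, n_samples)


def epoch_schedule(n_samples: int, batch_size: int, epochs: int) -> List[List[int]]:
    """Batch sizes seen in each epoch (shuffling omitted)."""
    return [
        [stop - start for start, stop in minibatches(n_samples, batch_size)]
        for _ in range(epochs)
    ]


schedule = epoch_schedule(n_samples=100, batch_size=32, epochs=2)
print(steps_per_epoch(100, 32), schedule[0])  # 4 [32, 32, 32, 4]
```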

evaluate

```python
evaluate(
    model: Model, x_test: Any, y_test: Any
) -> dict[str, float]
```

Evaluate a trained Keras model.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Model | Trained classifier. | required |
| x_test | Any | Test inputs. | required |
| y_test | Any | Test one-hot labels. | required |

Returns:

| Type | Description |
|---|---|
| dict[str, float] | Dictionary with loss and accuracy. |
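This page does not state which loss the classifiers are compiled with. Assuming categorical cross-entropy over one-hot labels, a common choice for Keras softmax classifiers, the two reported numbers can be computed as below. `evaluate_sketch` is a hypothetical helper, and the dictionary keys `"loss"` and `"accuracy"` are assumptions about the returned mapping.

```python
import math
from typing import Dict, List


def evaluate_sketch(
    y_true: List[List[float]], y_prob: List[List[float]], eps: float = 1e-7
) -> Dict[str, float]:
    """Mean categorical cross-entropy and accuracy for one-hot labels."""
    total_loss = 0.0
    correct = 0
    for true_row, prob_row in zip(y_true, y_prob):
        true_idx = true_row.index(max(true_row))
        # Only the probability of the true class contributes; clip to avoid log(0).
        p = min(max(prob_row[true_idx], eps), 1.0 - eps)
        total_loss -= math.log(p)
        # A prediction counts as correct when its argmax matches the true class.
        pred_idx = prob_row.index(max(prob_row))
        correct += int(pred_idx == true_idx)
    n = len(y_true)
    return {"loss": total_loss / n, "accuracy": correct / n}


result = evaluate_sketch([[1.0, 0.0], [0.0, 1.0]], [[0.8, 0.2], [0.6, 0.4]])
print(result["accuracy"])  # 0.5
```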

build_generator

```python
build_generator(config: ModelConfig | None = None) -> Model
```

Build the conditional GAN generator.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | ModelConfig \| None | Optional architecture configuration. | None |

Returns:

| Type | Description |
|---|---|
| Model | Keras generator model. |
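A conditional GAN generator receives the target class alongside its noise input. This page does not say how `build_generator` injects the label; a label embedding layer would be an equally valid design. The sketch below uses the simplest common scheme, which concatenates a standard-normal noise vector with the one-hot label. `generator_input`, `generator_batch`, and `latent_dim` are hypothetical names chosen for illustration.

```python
import random
from typing import List


def generator_input(label: List[float], latent_dim: int, rng: random.Random) -> List[float]:
    """Concatenate a standard-normal noise vector with a one-hot class label."""
    noise = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
    # The appended label tells the generator which class to synthesize.
    return noise + label


def generator_batch(n_classes: int, latent_dim: int, seed: int = 0) -> List[List[float]]:
    """One conditioned input per class, e.g. to request a class-balanced batch."""
    rng = random.Random(seed)
    return [
        generator_input(
            [1.0 if i == c else 0.0 for i in range(n_classes)], latent_dim, rng
        )
        for c in range(n_classes)
    ]


batch = generator_batch(n_classes=4, latent_dim=8)
print(len(batch), len(batch[0]))  # 4 12
```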

build_discriminator

```python
build_discriminator(
    config: ModelConfig | None = None,
) -> Model
```

Build the conditional GAN discriminator.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | ModelConfig \| None | Optional architecture configuration. | None |

Returns:

| Type | Description |
|---|---|
| Model | Keras discriminator model. |
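The losses used to train the conditional GAN are not documented here. The sketch below assumes the standard setup: the discriminator outputs one sigmoid probability per (sample, label) pair and is trained with binary cross-entropy. Real pairs target 1 and generated pairs target 0. The generator uses the non-saturating objective, which pushes its samples toward being scored as real. All function names are hypothetical.

```python
import math
from typing import List


def bce(targets: List[float], probs: List[float], eps: float = 1e-7) -> float:
    """Mean binary cross-entropy between 0/1 targets and predicted probabilities."""
    total = 0.0
    for t, p in zip(targets, probs):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total -= t * math.log(p) + (1.0 - t) * math.log(1.0 - p)
    return total / len(targets)


def discriminator_loss(d_real: List[float], d_fake: List[float]) -> float:
    # Real (sample, label) pairs should score 1; generated pairs should score 0.
    return bce([1.0] * len(d_real), d_real) + bce([0.0] * len(d_fake), d_fake)


def generator_loss(d_fake: List[float]) -> float:
    # Non-saturating objective: generated pairs should be scored as real.
    return bce([1.0] * len(d_fake), d_fake)


d_real = [0.9, 0.8]  # discriminator outputs on real pairs
d_fake = [0.1, 0.3]  # discriminator outputs on generated pairs
print(discriminator_loss(d_real, d_fake), generator_loss(d_fake))
```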