JAX Models

JAX/Flax models for EEG classification.

Classes:

FlaxCNN: CNN classifier implemented with Flax.
FlaxLSTM: Stacked bidirectional LSTM classifier implemented with Flax.
FlaxCNNLSTM: Hybrid CNN-LSTM classifier implemented with Flax.
FlaxGenerator: Conditional GAN generator implemented with Flax.
FlaxDiscriminator: Conditional GAN discriminator implemented with Flax.
JAXGANAugmenter: Train a JAX/Flax conditional GAN and append generated samples.
FlaxTrainState: Small training-state container for Flax examples.
JAXTrainer: Minimal JAX/Flax trainer for notebook reproduction.

Functions:

build_classifier: Build a named JAX/Flax classifier.

FlaxCNN

Bases: Module

CNN classifier implemented with Flax.

Attributes:

config (ModelConfig): Architecture configuration.

FlaxLSTM

Bases: Module

Stacked bidirectional LSTM classifier implemented with Flax.

Attributes:

config (ModelConfig): Architecture configuration.

FlaxCNNLSTM

Bases: Module

Hybrid CNN-LSTM classifier implemented with Flax.

Attributes:

config (ModelConfig): Architecture configuration.

FlaxGenerator

Bases: Module

Conditional GAN generator implemented with Flax.

Attributes:

config (ModelConfig): Architecture configuration.

FlaxDiscriminator

Bases: Module

Conditional GAN discriminator implemented with Flax.

Attributes:

config (ModelConfig): Architecture configuration.
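Both GAN modules are described as conditional, but this page does not say how they condition on the class. A common cGAN approach, assumed here rather than confirmed, has two parts. The generator receives the one-hot label concatenated to its latent noise. The discriminator receives the label as extra channels alongside the (n, time, 1, channels) input. The NumPy sketch below shows only that input plumbing. `generator_input`, `discriminator_input` and the latent size of 16 are hypothetical.

```python
import numpy as np


def generator_input(noise: np.ndarray, labels_one_hot: np.ndarray) -> np.ndarray:
    """Join latent noise (n, latent_dim) with one-hot labels (n, classes)."""
    return np.concatenate([noise, labels_one_hot], axis=1)


def discriminator_input(x: np.ndarray, labels_one_hot: np.ndarray) -> np.ndarray:
    """Append labels as extra channels to inputs shaped (n, time, 1, channels)."""
    n, time = x.shape[:2]
    num_classes = labels_one_hot.shape[1]
    # Repeat each sample's label vector at every time step.
    label_maps = np.broadcast_to(
        labels_one_hot[:, None, None, :], (n, time, 1, num_classes)
    )
    return np.concatenate([x, label_maps], axis=-1)


rng = np.random.default_rng(0)
labels = np.eye(2, dtype=np.float32)[[0, 1, 1, 0]]      # 4 samples, 2 classes
noise = rng.standard_normal((4, 16)).astype(np.float32)  # hypothetical latent_dim = 16
x = rng.standard_normal((4, 128, 1, 4)).astype(np.float32)

print(generator_input(noise, labels).shape)  # (4, 18)
print(discriminator_input(x, labels).shape)  # (4, 128, 1, 6)
```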

JAXGANAugmenter (dataclass)

JAXGANAugmenter(
    training: TrainingConfig,
    model_config: ModelConfig | None = None,
)

Train a JAX/Flax conditional GAN and append generated samples.

The augmenter is independent of the downstream classifier and can be used before CNN, LSTM, or CNN-LSTM training.

Attributes:

training (TrainingConfig): Training and GAN augmentation configuration.
model_config (ModelConfig | None): Optional architecture configuration.

Methods:

augment_training_data: Create GAN-augmented training arrays.

augment_training_data

augment_training_data(
    x_train: ndarray, y_train: ndarray
) -> GANAugmentationResult

Create GAN-augmented training arrays.

Parameters:

x_train (ndarray, required): Real training inputs shaped (n, time, 1, channels).
y_train (ndarray, required): One-hot labels for the real training inputs.
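The inputs must be shaped (n, time, 1, channels) and the labels must be one-hot. This page does not say how the raw data is stored. The NumPy sketch below assumes EEG epochs held as (n, channels, time) and converts them to that layout. `to_model_layout` and `one_hot` are hypothetical helpers, not part of this module.

```python
import numpy as np


def to_model_layout(epochs: np.ndarray) -> np.ndarray:
    """Reshape EEG epochs from (n, channels, time) to (n, time, 1, channels)."""
    # Swap channels and time, then insert a singleton axis at position 2.
    return np.transpose(epochs, (0, 2, 1))[:, :, np.newaxis, :]


def one_hot(labels: np.ndarray, num_classes: int) -> np.ndarray:
    """Turn integer class labels into float32 one-hot rows."""
    return np.eye(num_classes, dtype=np.float32)[labels]


# Hypothetical data: 8 trials, 4 channels, 128 time samples, 2 classes.
rng = np.random.default_rng(0)
epochs = rng.standard_normal((8, 4, 128)).astype(np.float32)
labels = rng.integers(0, 2, size=8)

x_train = to_model_layout(epochs)
y_train = one_hot(labels, num_classes=2)
print(x_train.shape, y_train.shape)  # (8, 128, 1, 4) (8, 2)
```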

Returns:

GANAugmentationResult: Augmented training arrays plus synthetic sample metadata.
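This page does not list the fields of GANAugmentationResult, so the sketch below returns a plain tuple. It illustrates only the "append" step of augmentation. Synthetic samples are concatenated after the real ones, and a boolean mask records which rows are synthetic. In the sketch, the synthetic samples are random stand-ins rather than real generator output. `append_synthetic` and the mask are illustrative assumptions, not the library's API.

```python
import numpy as np


def append_synthetic(x_real, y_real, x_fake, y_fake):
    """Concatenate real and synthetic samples and mark which rows are synthetic."""
    if x_real.shape[1:] != x_fake.shape[1:]:
        raise ValueError("synthetic samples must match the real sample shape")
    x_aug = np.concatenate([x_real, x_fake], axis=0)
    y_aug = np.concatenate([y_real, y_fake], axis=0)
    # False for the original real rows, True for the appended synthetic rows.
    is_synthetic = np.concatenate(
        [np.zeros(len(x_real), dtype=bool), np.ones(len(x_fake), dtype=bool)]
    )
    return x_aug, y_aug, is_synthetic


rng = np.random.default_rng(0)
x_real = rng.standard_normal((10, 128, 1, 4)).astype(np.float32)
y_real = np.eye(2, dtype=np.float32)[rng.integers(0, 2, size=10)]

# Stand-in for conditional generator output: 3 synthetic samples per class.
y_fake = np.eye(2, dtype=np.float32)[np.repeat([0, 1], 3)]
x_fake = rng.standard_normal((6, 128, 1, 4)).astype(np.float32)

x_aug, y_aug, is_synthetic = append_synthetic(x_real, y_real, x_fake, y_fake)
print(x_aug.shape, int(is_synthetic.sum()))  # (16, 128, 1, 4) 6
```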

FlaxTrainState (dataclass)

FlaxTrainState(
    params: dict,
    batch_stats: dict | None,
    optimizer_state: OptState,
)

Small training-state container for Flax examples.

Attributes:

params (dict): Model parameters.
batch_stats (dict | None): Optional batch-normalization statistics.
optimizer_state (OptState): Optax optimizer state.
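In JAX code, state is usually updated functionally: each step builds a new state object instead of mutating the old one. The sketch below uses a hypothetical frozen dataclass with the same three fields as FlaxTrainState. It assumes, without confirmation from this page, that FlaxTrainState is used the same way. `dataclasses.replace` builds the updated copy.

```python
from __future__ import annotations

from dataclasses import FrozenInstanceError, dataclass, replace
from typing import Any


@dataclass(frozen=True)
class TrainStateSketch:
    """Hypothetical stand-in with the same three fields as FlaxTrainState."""

    params: dict
    batch_stats: dict | None
    optimizer_state: Any


state = TrainStateSketch(params={"w": 1.0}, batch_stats=None, optimizer_state=())

# A step returns a new state; the original object stays unchanged.
new_state = replace(state, params={"w": 0.9})
print(state.params, new_state.params)  # {'w': 1.0} {'w': 0.9}

try:
    state.params = {}  # frozen dataclasses reject in-place assignment
except FrozenInstanceError:
    print("state is immutable")
```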

JAXTrainer (dataclass)

JAXTrainer(training: TrainingConfig)

Minimal JAX/Flax trainer for notebook reproduction.

Attributes:

training (TrainingConfig): Training configuration controlling epochs and learning rate.

Methods:

init_state: Initialize model parameters and optimizer state.
train_step: Run one full-batch optimization step.
fit: Train a Flax classifier.
evaluate: Evaluate a trained Flax classifier.

init_state

init_state(
    model: Module, rng: Any, sample_batch: Any
) -> FlaxTrainState

Initialize model parameters and optimizer state.

Parameters:

model (Module, required): Flax model.
rng (Any, required): JAX random key.
sample_batch (Any, required): Example input batch used for shape inference.

Returns:

FlaxTrainState: Initialized training state.
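In Flax Linen, `model.init(rng, sample_batch)` returns a mapping of variable collections. It normally contains "params", plus "batch_stats" when the model uses batch normalization. The sketch below splits such a mapping into the first two FlaxTrainState fields. So that it runs without Flax, it uses hand-written dicts in place of a real `model.init` call. `split_variables` is a hypothetical helper. The optimizer state would then typically come from an Optax transformation, for example `optax.adam(learning_rate).init(params)`.

```python
from __future__ import annotations

from typing import Any, Mapping


def split_variables(variables: Mapping[str, Any]) -> tuple[dict, dict | None]:
    """Separate trainable params from optional batch-norm statistics."""
    params = dict(variables["params"])
    batch_stats = variables.get("batch_stats")  # absent for models without BatchNorm
    return params, (None if batch_stats is None else dict(batch_stats))


# Hand-written stand-ins for the output of `model.init(rng, sample_batch)`.
with_bn = {
    "params": {"Dense_0": {"kernel": [[0.1]]}},
    "batch_stats": {"BatchNorm_0": {"mean": [0.0], "var": [1.0]}},
}
without_bn = {"params": {"Dense_0": {"kernel": [[0.1]]}}}

print(split_variables(with_bn)[1] is None)     # False
print(split_variables(without_bn)[1] is None)  # True
```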

train_step

train_step(
    model: Module,
    state: FlaxTrainState,
    x: Any,
    y: Any,
    rng: Any,
) -> tuple[FlaxTrainState, dict[str, float]]

Run one full-batch optimization step.

Parameters:

model (Module, required): Flax model.
state (FlaxTrainState, required): Current training state.
x (Any, required): Input batch.
y (Any, required): One-hot labels.
rng (Any, required): JAX random key for dropout.

Returns:

tuple[FlaxTrainState, dict[str, float]]: Updated state and scalar metrics.
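A full-batch step computes the loss and its gradient over the whole batch and returns updated parameters with scalar metrics. The sketch below reproduces that contract for a plain softmax-regression model in NumPy. The cross-entropy gradient is derived by hand, and fixed-step gradient descent stands in for the Flax model and Optax optimizer the real trainer uses. Dropout is omitted, so there is no `rng` argument. The usage loop assumes one step per epoch, which the full-batch design suggests but the page does not state. It shows how a fit-style method can collect one metrics dictionary per epoch.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def train_step(w: np.ndarray, x: np.ndarray, y: np.ndarray, lr: float = 0.5):
    """One full-batch gradient step; returns (new_w, metrics before the update)."""
    probs = softmax(x @ w)
    loss = float(-np.mean(np.sum(y * np.log(probs + 1e-12), axis=1)))
    accuracy = float(np.mean(probs.argmax(axis=1) == y.argmax(axis=1)))
    grad = x.T @ (probs - y) / x.shape[0]  # gradient of mean cross-entropy w.r.t. w
    return w - lr * grad, {"loss": loss, "accuracy": accuracy}


# Toy data: 64 samples, 3 features, label = whether feature 0 is positive.
rng = np.random.default_rng(0)
x = rng.standard_normal((64, 3))
y = np.eye(2)[(x[:, 0] > 0).astype(int)]
w = np.zeros((3, 2))

history = []  # one metrics dict per epoch, like fit's return value
for _ in range(50):
    w, metrics = train_step(w, x, y)
    history.append(metrics)
print(history[-1]["loss"] < history[0]["loss"])  # True
```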

fit

fit(
    model: Module, x_train: ndarray, y_train: ndarray
) -> list[dict[str, float]]

Train a Flax classifier.

Parameters:

model (Module, required): Flax classifier.
x_train (ndarray, required): Training inputs.
y_train (ndarray, required): One-hot training labels.

Returns:

list[dict[str, float]]: Per-epoch metrics with loss and accuracy.

evaluate

evaluate(
    model: Module, x: ndarray, y: ndarray
) -> dict[str, float]

Evaluate a trained Flax classifier.

Parameters:

model (Module, required): Flax classifier.
x (ndarray, required): Evaluation inputs.
y (ndarray, required): One-hot or soft-label evaluation targets.

Returns:

dict[str, float]: Dictionary with loss and accuracy.

Raises:

RuntimeError: If fit has not been called and no state is available.
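evaluate accepts either one-hot or soft-label targets. The NumPy sketch below computes metrics that work for both. It assumes two things this page does not state: the loss is the mean cross-entropy against the target distribution, and accuracy compares the argmax of the predictions with the argmax of the targets. `evaluate_probs` is a hypothetical helper that takes predicted probabilities directly rather than a model and trained state.

```python
from __future__ import annotations

import numpy as np


def evaluate_probs(probs: np.ndarray, targets: np.ndarray) -> dict[str, float]:
    """Cross-entropy loss and accuracy for one-hot or soft-label targets."""
    eps = 1e-12  # avoids log(0)
    loss = float(-np.mean(np.sum(targets * np.log(probs + eps), axis=1)))
    # For soft labels, the "true" class is taken to be the most probable target class.
    accuracy = float(np.mean(probs.argmax(axis=1) == targets.argmax(axis=1)))
    return {"loss": loss, "accuracy": accuracy}


probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
one_hot = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
soft = np.array([[0.8, 0.2], [0.3, 0.7], [0.4, 0.6]])

print(evaluate_probs(probs, one_hot)["accuracy"])  # 0.6666666666666666
print(evaluate_probs(probs, soft))
```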

build_classifier

build_classifier(
    name: str, config: ModelConfig | None = None
) -> Module

Build a named JAX/Flax classifier.

Parameters:

name (str, required): Classifier name: cnn, lstm, or cnn_lstm.
config (ModelConfig | None, default None): Optional architecture configuration.

Returns:

Module: Flax module for the selected classifier.

Raises:

ValueError: If the classifier name is unknown.
ImportError: If JAX, Flax, or Optax is not installed.