JAX Models
JAX/Flax models for EEG classification.
Classes:
| Name | Description |
|---|---|
| FlaxCNN | CNN classifier implemented with Flax. |
| FlaxLSTM | Stacked bidirectional LSTM classifier implemented with Flax. |
| FlaxCNNLSTM | Hybrid CNN-LSTM classifier implemented with Flax. |
| FlaxGenerator | Conditional GAN generator implemented with Flax. |
| FlaxDiscriminator | Conditional GAN discriminator implemented with Flax. |
| JAXGANAugmenter | Train a JAX/Flax conditional GAN and append generated samples. |
| FlaxTrainState | Small training-state container for Flax examples. |
| JAXTrainer | Minimal JAX/Flax trainer for notebook reproduction. |
Functions:
| Name | Description |
|---|---|
| build_classifier | Build a named JAX/Flax classifier. |
FlaxCNN
Bases: Module
CNN classifier implemented with Flax.
Attributes:
| Name | Type | Description |
|---|---|---|
| config | ModelConfig | Architecture configuration. |
FlaxLSTM
Bases: Module
Stacked bidirectional LSTM classifier implemented with Flax.
Attributes:
| Name | Type | Description |
|---|---|---|
| config | ModelConfig | Architecture configuration. |
FlaxCNNLSTM
Bases: Module
Hybrid CNN-LSTM classifier implemented with Flax.
Attributes:
| Name | Type | Description |
|---|---|---|
| config | ModelConfig | Architecture configuration. |
FlaxGenerator
Bases: Module
Conditional GAN generator implemented with Flax.
Attributes:
| Name | Type | Description |
|---|---|---|
| config | ModelConfig | Architecture configuration. |
FlaxDiscriminator
Bases: Module
Conditional GAN discriminator implemented with Flax.
Attributes:
| Name | Type | Description |
|---|---|---|
| config | ModelConfig | Architecture configuration. |
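Both GAN networks are described as conditional, meaning they also receive the class label. One common way to do this, shown here as a NumPy sketch, is to concatenate the one-hot label onto the generator's noise vector and onto the discriminator's flattened input. Whether `FlaxGenerator` and `FlaxDiscriminator` condition exactly this way is an assumption; the helper names, the latent size, and the `(batch, channels, times)` input layout are hypothetical.

```python
import numpy as np


def generator_input(noise: np.ndarray, labels_one_hot: np.ndarray) -> np.ndarray:
    """Hypothetical helper: condition the generator by appending the label to the noise."""
    # noise: (batch, latent_dim); labels_one_hot: (batch, n_classes)
    return np.concatenate([noise, labels_one_hot], axis=-1)


def discriminator_input(x: np.ndarray, labels_one_hot: np.ndarray) -> np.ndarray:
    """Hypothetical helper: flatten each sample, then append its label."""
    flat = x.reshape(x.shape[0], -1)  # (batch, channels * times)
    return np.concatenate([flat, labels_one_hot], axis=-1)


rng = np.random.default_rng(0)
batch, latent_dim, n_classes = 4, 16, 3
labels = np.eye(n_classes)[rng.integers(0, n_classes, size=batch)]  # one-hot rows
z = rng.standard_normal((batch, latent_dim))
x = rng.standard_normal((batch, 8, 32))  # assumed (batch, channels, times) layout

print(generator_input(z, labels).shape)      # (4, 19)
print(discriminator_input(x, labels).shape)  # (4, 259)
```

Because the label is part of the input, the generator can be asked for samples of a specific class, which is what makes class-targeted augmentation possible.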
JAXGANAugmenter
dataclass
JAXGANAugmenter(training: TrainingConfig, model_config: ModelConfig | None = None)
Train a JAX/Flax conditional GAN and append its generated samples to the real training data.
The augmenter is independent of the downstream classifier and can be used before CNN, LSTM, or CNN-LSTM training.
Attributes:
| Name | Type | Description |
|---|---|---|
| training | TrainingConfig | Training and GAN augmentation configuration. |
| model_config | ModelConfig \| None | Optional architecture configuration. |
Methods:
| Name | Description |
|---|---|
| augment_training_data | Create GAN-augmented training arrays. |
augment_training_data
augment_training_data(x_train: ndarray, y_train: ndarray) -> GANAugmentationResult
Create GAN-augmented training arrays.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x_train | ndarray | Real training inputs shaped | required |
| y_train | ndarray | One-hot real training labels. | required |
Returns:
| Type | Description |
|---|---|
| GANAugmentationResult | Augmented training arrays plus synthetic sample metadata. |
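A minimal NumPy sketch of the "append generated samples" step, assuming the generated arrays already exist: real and synthetic rows are stacked along the batch axis, and a boolean mask plus per-class counts record what was added. `AugmentedArrays` and `append_synthetic` are hypothetical stand-ins; the actual fields of `GANAugmentationResult` are not documented here.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class AugmentedArrays:
    """Hypothetical stand-in for GANAugmentationResult; real field names may differ."""

    x: np.ndarray
    y: np.ndarray
    is_synthetic: np.ndarray  # True for generated rows
    synthetic_per_class: np.ndarray  # number of generated samples per class


def append_synthetic(
    x_train: np.ndarray, y_train: np.ndarray, x_synth: np.ndarray, y_synth: np.ndarray
) -> AugmentedArrays:
    # Generated samples must match the real sample shape and label width.
    if x_synth.shape[1:] != x_train.shape[1:] or y_synth.shape[1:] != y_train.shape[1:]:
        raise ValueError("synthetic arrays do not match the real training arrays")
    x = np.concatenate([x_train, x_synth], axis=0)
    y = np.concatenate([y_train, y_synth], axis=0)
    is_synthetic = np.concatenate(
        [np.zeros(len(x_train), dtype=bool), np.ones(len(x_synth), dtype=bool)]
    )
    per_class = np.bincount(y_synth.argmax(axis=1), minlength=y_train.shape[1])
    return AugmentedArrays(x, y, is_synthetic, per_class)


rng = np.random.default_rng(0)
x_real = rng.standard_normal((6, 8, 32))
y_real = np.eye(2)[[0, 1, 0, 1, 0, 1]]
x_fake = rng.standard_normal((4, 8, 32))  # would come from the generator
y_fake = np.eye(2)[[1, 1, 1, 0]]

result = append_synthetic(x_real, y_real, x_fake, y_fake)
print(result.x.shape, result.synthetic_per_class)  # (10, 8, 32) [1 3]
```

Keeping the mask next to the arrays lets later code separate real from generated rows, for example to report metrics on real samples only.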
FlaxTrainState
dataclass
JAXTrainer
dataclass
JAXTrainer(training: TrainingConfig)
Minimal JAX/Flax trainer for notebook reproduction.
Attributes:
| Name | Type | Description |
|---|---|---|
| training | TrainingConfig | Training configuration controlling epochs and learning rate. |
Methods:
| Name | Description |
|---|---|
| init_state | Initialize model parameters and optimizer state. |
| train_step | Run one full-batch optimization step. |
| fit | Train a Flax classifier. |
| evaluate | Evaluate a trained Flax classifier. |
init_state
init_state(model: Module, rng: Any, sample_batch: Any) -> FlaxTrainState
Initialize model parameters and optimizer state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | Flax model. | required |
| rng | Any | JAX random key. | required |
| sample_batch | Any | Example input batch used for shape inference. | required |
Returns:
| Type | Description |
|---|---|
| FlaxTrainState | Initialized training state. |
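Flax modules are usually initialized with `model.init(rng, sample_batch)`, which runs the module once on the example batch so each layer can infer its parameter shapes; only the batch's shape and dtype matter, not its values. The NumPy sketch below mimics that idea for a single dense layer and adds optimizer buffers shaped like the parameters. The function, the dictionary layout, and the choice of momentum-SGD buffers are hypothetical and not taken from `init_state`.

```python
import numpy as np


def init_dense_state(rng: np.random.Generator, sample_batch: np.ndarray, features: int) -> dict:
    """Hypothetical init: infer the kernel's input size from an example batch."""
    # Only the shape of sample_batch matters; its values are never used.
    in_features = int(np.prod(sample_batch.shape[1:]))
    params = {
        "kernel": rng.normal(0.0, 1.0 / np.sqrt(in_features), size=(in_features, features)),
        "bias": np.zeros(features),
    }
    # Optimizer state mirrors the parameter shapes (here: SGD momentum buffers).
    opt_state = {name: np.zeros_like(value) for name, value in params.items()}
    return {"params": params, "opt_state": opt_state, "step": 0}


state = init_dense_state(np.random.default_rng(0), np.zeros((2, 8, 32)), features=3)
print(state["params"]["kernel"].shape)  # (256, 3)
```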
train_step
train_step(model: Module, state: FlaxTrainState, x: Any, y: Any, rng: Any) -> tuple[FlaxTrainState, dict[str, float]]
Run one full-batch optimization step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | Flax model. | required |
| state | FlaxTrainState | Current training state. | required |
| x | Any | Input batch. | required |
| y | Any | One-hot labels. | required |
| rng | Any | JAX random key for dropout. | required |
Returns:
| Type | Description |
|---|---|
| tuple[FlaxTrainState, dict[str, float]] | Updated state and scalar metrics. |
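What a full-batch step does can be sketched without Flax: the loss and its gradient are computed over the whole batch at once, and a new state is returned instead of mutating the old one. The NumPy example below assumes a softmax-regression model with the closed-form cross-entropy gradient and plain SGD; `SketchState` is hypothetical, and dropout (the purpose of `rng` in `train_step`) is omitted. In JAX the gradient would normally come from `jax.value_and_grad` rather than a hand-derived formula.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class SketchState:
    """Hypothetical immutable training state for a softmax-regression model."""

    weights: np.ndarray  # (features, classes)
    bias: np.ndarray  # (classes,)
    step: int
    learning_rate: float


def softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max(axis=1, keepdims=True)  # avoid overflow in exp
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def train_step(
    state: SketchState, x: np.ndarray, y: np.ndarray
) -> tuple[SketchState, dict[str, float]]:
    """One full-batch SGD step on mean cross-entropy; metrics use pre-update weights."""
    n = x.shape[0]
    probs = softmax(x @ state.weights + state.bias)
    loss = -np.sum(y * np.log(probs + 1e-12)) / n
    accuracy = np.mean(probs.argmax(axis=1) == y.argmax(axis=1))
    grad_logits = (probs - y) / n  # gradient of the mean loss w.r.t. the logits
    new_state = replace(
        state,
        weights=state.weights - state.learning_rate * (x.T @ grad_logits),
        bias=state.bias - state.learning_rate * grad_logits.sum(axis=0),
        step=state.step + 1,
    )
    return new_state, {"loss": float(loss), "accuracy": float(accuracy)}


rng = np.random.default_rng(0)
x = rng.standard_normal((64, 5))
y = np.eye(2)[(x[:, 0] > 0).astype(int)]  # label depends on the first feature
state = SketchState(np.zeros((5, 2)), np.zeros(2), step=0, learning_rate=0.5)

state, first = train_step(state, x, y)
for _ in range(99):
    state, metrics = train_step(state, x, y)
print(state.step, metrics["loss"] < first["loss"])  # 100 True
```

Returning a fresh frozen dataclass mirrors the functional style JAX encourages, where parameters are passed into and out of each step rather than updated in place.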
fit
evaluate
Evaluate a trained Flax classifier.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | Flax classifier. | required |
| x | ndarray | Evaluation inputs. | required |
| y | ndarray | One-hot or soft-label evaluation targets. | required |
Returns:
| Type | Description |
|---|---|
| dict[str, float] | Dictionary with |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If |
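Since `evaluate` accepts either one-hot or soft-label targets, its metrics have to work for both. A sketch of one way to do that, starting from predicted class probabilities rather than a model: cross-entropy uses the full target distribution, while accuracy compares against each target's most probable class. The function name and the `"loss"`/`"accuracy"` keys are assumptions, not the keys the real method returns.

```python
import numpy as np


def evaluate_probs(probs: np.ndarray, targets: np.ndarray) -> dict[str, float]:
    """Hypothetical metrics for one-hot or soft targets; key names are illustrative."""
    if probs.shape != targets.shape:
        raise ValueError("probs and targets must have the same (batch, classes) shape")
    # Cross-entropy works unchanged for soft labels: each target row is a distribution.
    loss = -np.mean(np.sum(targets * np.log(np.clip(probs, 1e-12, 1.0)), axis=1))
    # For soft labels, accuracy compares against the most probable target class.
    accuracy = np.mean(probs.argmax(axis=1) == targets.argmax(axis=1))
    return {"loss": float(loss), "accuracy": float(accuracy)}


probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
one_hot = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
soft = np.array([[0.8, 0.2], [0.3, 0.7], [0.4, 0.6]])
print(evaluate_probs(probs, one_hot)["accuracy"])  # 0.6666666666666666
print(evaluate_probs(probs, soft)["accuracy"])     # 0.6666666666666666
```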
build_classifier
build_classifier(name: str, config: ModelConfig | None = None) -> Module
Build a named JAX/Flax classifier.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Classifier name: | required |
| config | ModelConfig \| None | Optional architecture configuration. | None |
|
Returns:
| Type | Description |
|---|---|
| Module | Flax module for the selected classifier. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the classifier name is unknown. |
| ImportError | If JAX/Flax/Optax are not installed. |
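Selecting a classifier by name is commonly done with a registry lookup. The sketch below pairs one with an up-front dependency check that mirrors the documented `ValueError` and `ImportError` cases. The stub classes and the keys `"cnn"`, `"lstm"`, and `"cnn_lstm"` are hypothetical and are not necessarily the names `build_classifier` accepts. `importlib.util.find_spec` returns `None` for a top-level module that cannot be found, so the check runs without importing JAX itself.

```python
import importlib.util
from dataclasses import dataclass
from typing import Optional


@dataclass
class StubConfig:
    """Hypothetical stand-in for ModelConfig."""

    num_classes: int = 2


@dataclass
class StubCNN:
    config: StubConfig


@dataclass
class StubLSTM:
    config: StubConfig


@dataclass
class StubCNNLSTM:
    config: StubConfig


# Hypothetical registry; the names build_classifier really accepts may differ.
_REGISTRY = {"cnn": StubCNN, "lstm": StubLSTM, "cnn_lstm": StubCNNLSTM}


def build_classifier_sketch(
    name: str,
    config: Optional[StubConfig] = None,
    required: tuple[str, ...] = ("jax", "flax", "optax"),
):
    # Report missing optional dependencies before touching any model code.
    missing = [module for module in required if importlib.util.find_spec(module) is None]
    if missing:
        raise ImportError(f"missing optional dependencies: {', '.join(missing)}")
    if name not in _REGISTRY:
        raise ValueError(f"unknown classifier {name!r}; expected one of {sorted(_REGISTRY)}")
    # Fall back to a default configuration when none is given.
    return _REGISTRY[name](config if config is not None else StubConfig())


model = build_classifier_sketch("cnn_lstm", required=())  # skip the dependency check
print(type(model).__name__)  # StubCNNLSTM
```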