
GAN Augmentation

Shared utilities for GAN-based training-data augmentation.

Classes:

GANAugmentationResult: Output from a framework-specific GAN augmentation pass.

Functions:

interpolation_labels: Create GAN interpolation labels for all ordered class pairs.
classifier_images_to_gan_sequences: Normalize classifier inputs into the sequence format expected by GANs.
gan_sequences_to_classifier_images: Denormalize generated GAN sequences back to classifier input layout.
append_synthetic_training_data: Append generated GAN samples to classifier training arrays.

GANAugmentationResult (dataclass)

GANAugmentationResult(
    x_train: ndarray,
    y_train: ndarray,
    x_synthetic: ndarray,
    y_synthetic: ndarray,
    history: list[dict[str, float]] = list(),
    scale: float = 1.0,
    metadata: dict[str, Any] = dict(),
)

Output from a framework-specific GAN augmentation pass.

Attributes:

x_train (ndarray): Training inputs after appending synthetic samples.
y_train (ndarray): Training labels after appending synthetic labels.
x_synthetic (ndarray): Synthetic classifier inputs shaped (n, time, 1, channels).
y_synthetic (ndarray): Synthetic soft labels shaped (n, classes).
history (list[dict[str, float]]): Per-epoch GAN losses or framework-specific training history.
scale (float): Scale factor used to denormalize generated GAN sequences.
metadata (dict[str, Any]): Extra JSON-serializable details about the augmentation run.
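The signature shows history and metadata defaulting to list() and dict(). Python dataclasses reject plain list and dict defaults, so these fields are presumably declared with field(default_factory=...), which gives every result its own containers. The following is a minimal sketch of such a declaration, assuming NumPy arrays for the array fields; it illustrates the pattern and is not the library's actual source.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

import numpy as np


@dataclass
class GANAugmentationResult:
    """Sketch of the documented result container (not the real source)."""

    x_train: np.ndarray
    y_train: np.ndarray
    x_synthetic: np.ndarray
    y_synthetic: np.ndarray
    # default_factory creates a fresh list/dict per instance, so results
    # never share a mutable history or metadata object.
    history: list[dict[str, float]] = field(default_factory=list)
    scale: float = 1.0
    metadata: dict[str, Any] = field(default_factory=dict)
```

Appending to one result's history leaves every other result's history empty, which is the behavior the list() and dict() defaults imply.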

interpolation_labels

interpolation_labels(
    n_classes: int = 4, samples_per_class: int = 10
) -> ndarray

Create GAN interpolation labels for all ordered class pairs.

The routine generates samples_per_class points between every ordered pair of class labels, including same-class pairs. The generated labels are soft labels and can be used directly with categorical cross-entropy style losses.

Parameters:

n_classes (int, default 4): Number of class labels.
samples_per_class (int, default 10): Interpolation points generated for each ordered class pair.

Returns:

ndarray: Float32 label matrix shaped (n_classes * n_classes * samples_per_class, n_classes).

Raises:

ValueError: If n_classes or samples_per_class is not positive.
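A minimal sketch of interpolation_labels, assuming straight-line interpolation between one-hot vectors with evenly spaced weights from 0 to 1 (both endpoints included), and blocks ordered by source class, then target class. The documentation does not state the spacing or block order, so the real implementation may differ; the sketch only reproduces the documented shape, dtype, inclusion of same-class pairs, and error behavior.

```python
import numpy as np


def interpolation_labels(n_classes: int = 4, samples_per_class: int = 10) -> np.ndarray:
    """Sketch: linearly interpolate one-hot labels for every ordered class pair.

    Assumption: weights t are evenly spaced in [0, 1], endpoints included.
    """
    if n_classes <= 0 or samples_per_class <= 0:
        raise ValueError("n_classes and samples_per_class must be positive")

    eye = np.eye(n_classes, dtype=np.float32)
    # Interpolation weights shaped (samples_per_class, 1) so they broadcast
    # against a one-hot row of shape (n_classes,).
    t = np.linspace(0.0, 1.0, samples_per_class, dtype=np.float32)[:, None]

    blocks = []
    for src in range(n_classes):
        for dst in range(n_classes):  # ordered pairs, including src == dst
            # Row k moves from class src (t = 0) to class dst (t = 1).
            blocks.append((1.0 - t) * eye[src] + t * eye[dst])
    return np.concatenate(blocks, axis=0).astype(np.float32)
```

For n_classes=3 and samples_per_class=5 this returns a (45, 3) array whose rows each sum to 1; same-class pairs contribute rows that are plain one-hot vectors.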

classifier_images_to_gan_sequences

classifier_images_to_gan_sequences(
    x_train: ndarray,
) -> tuple[ndarray, float]

Normalize classifier inputs into the sequence format expected by GANs.

Parameters:

x_train (ndarray, required): Classifier inputs shaped (n, time, 1, channels).

Returns:

tuple[ndarray, float]: The normalized sequences, shaped (n, time, channels), and the scale factor required to map generated values back to the classifier scale.

Raises:

ValueError: If the input does not use the classifier image layout (n, time, 1, channels).
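A sketch of classifier_images_to_gan_sequences, assuming the scale factor is the largest absolute value in x_train (falling back to 1.0 for empty or all-zero input) and that inputs are non-negative, so the normalized result lies in [0, 1]. The documentation does not specify the normalization rule; these choices are illustrative only.

```python
from __future__ import annotations

import numpy as np


def classifier_images_to_gan_sequences(x_train: np.ndarray) -> tuple[np.ndarray, float]:
    """Sketch: drop the singleton axis and divide by one global scale.

    Assumption: scale = max |x|, so non-negative inputs map into [0, 1].
    """
    x = np.asarray(x_train)
    if x.ndim != 4 or x.shape[2] != 1:
        raise ValueError(
            f"expected classifier layout (n, time, 1, channels), got {x.shape}"
        )
    scale = float(np.max(np.abs(x))) if x.size else 1.0
    if scale == 0.0:
        scale = 1.0  # avoid dividing by zero for all-zero input
    # (n, time, 1, channels) -> (n, time, channels), then normalize.
    sequences = (x[:, :, 0, :] / scale).astype(np.float32)
    return sequences, scale
```

Returning the scale alongside the sequences lets the caller undo the normalization later without recomputing statistics on the training set.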

gan_sequences_to_classifier_images

gan_sequences_to_classifier_images(
    sequences: ndarray, scale: float
) -> ndarray

Denormalize generated GAN sequences back to classifier input layout.

Parameters:

sequences (ndarray, required): GAN outputs shaped (n, time, channels) in the normalized [0, 1] range.
scale (float, required): Scale factor returned by classifier_images_to_gan_sequences.

Returns:

ndarray: Float32 classifier inputs shaped (n, time, 1, channels).

Raises:

ValueError: If sequences is not three-dimensional.
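Assuming normalization was a plain division by the scale factor, denormalization is a multiplication followed by re-inserting the singleton axis. A minimal sketch under that assumption:

```python
import numpy as np


def gan_sequences_to_classifier_images(sequences: np.ndarray, scale: float) -> np.ndarray:
    """Sketch: undo an assumed divide-by-scale and restore the image layout."""
    seq = np.asarray(sequences)
    if seq.ndim != 3:
        raise ValueError(f"expected (n, time, channels), got {seq.shape}")
    # (n, time, channels) -> (n, time, 1, channels) after rescaling.
    return (seq * scale)[:, :, np.newaxis, :].astype(np.float32)
```

If the real normalization also subtracts an offset, a single scale factor would not be enough to invert it, which is consistent with the documented function taking only scale.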

append_synthetic_training_data

append_synthetic_training_data(
    x_train: ndarray,
    y_train: ndarray,
    synthetic_sequences: ndarray,
    synthetic_labels: ndarray,
    scale: float,
    history: list[dict[str, float]] | None = None,
    metadata: dict[str, Any] | None = None,
) -> GANAugmentationResult

Append generated GAN samples to classifier training arrays.

Parameters:

x_train (ndarray, required): Real classifier inputs shaped (n, time, 1, channels).
y_train (ndarray, required): Real one-hot training labels.
synthetic_sequences (ndarray, required): Generated GAN samples shaped (m, time, channels).
synthetic_labels (ndarray, required): Soft labels aligned with synthetic_sequences.
scale (float, required): Scale factor used to denormalize generated sequences.
history (list[dict[str, float]] | None, default None): Optional GAN training history.
metadata (dict[str, Any] | None, default None): Optional run metadata.

Returns:

GANAugmentationResult: Result with real and synthetic samples concatenated.
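The following sketch shows one way append_synthetic_training_data could combine the pieces: denormalize the GAN output, concatenate along the sample axis, and wrap everything in the result dataclass. It assumes denormalization is a multiplication by scale, redefines a minimal GANAugmentationResult so the block runs on its own, and adds a label-count check that the documentation does not list.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

import numpy as np


@dataclass
class GANAugmentationResult:
    """Minimal stand-in for the documented result dataclass."""

    x_train: np.ndarray
    y_train: np.ndarray
    x_synthetic: np.ndarray
    y_synthetic: np.ndarray
    history: list[dict[str, float]] = field(default_factory=list)
    scale: float = 1.0
    metadata: dict[str, Any] = field(default_factory=dict)


def append_synthetic_training_data(
    x_train: np.ndarray,
    y_train: np.ndarray,
    synthetic_sequences: np.ndarray,
    synthetic_labels: np.ndarray,
    scale: float,
    history: list[dict[str, float]] | None = None,
    metadata: dict[str, Any] | None = None,
) -> GANAugmentationResult:
    """Sketch: denormalize GAN output and stack it after the real data."""
    seq = np.asarray(synthetic_sequences, dtype=np.float32)
    labels = np.asarray(synthetic_labels, dtype=np.float32)
    if seq.ndim != 3:
        raise ValueError(f"expected (m, time, channels), got {seq.shape}")
    # Extra check in this sketch: one label row per synthetic sequence.
    if labels.shape[0] != seq.shape[0]:
        raise ValueError("synthetic_labels must align with synthetic_sequences")

    # Assumed inverse normalization: rescale, then restore the singleton axis.
    x_synthetic = (seq * scale)[:, :, np.newaxis, :]
    return GANAugmentationResult(
        x_train=np.concatenate([x_train, x_synthetic], axis=0),
        y_train=np.concatenate([y_train, labels], axis=0),
        x_synthetic=x_synthetic,
        y_synthetic=labels,
        history=list(history) if history is not None else [],
        scale=float(scale),
        metadata=dict(metadata) if metadata is not None else {},
    )
```

With 5 real samples and 3 synthetic ones, the returned x_train holds 8 samples, and x_synthetic keeps the 3 denormalized samples separately for inspection.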