GAN Augmentation
Shared utilities for GAN-based training-data augmentation.
Classes:

| Name | Description |
|---|---|
| `GANAugmentationResult` | Output from a framework-specific GAN augmentation pass. |

Functions:

| Name | Description |
|---|---|
| `interpolation_labels` | Create GAN interpolation labels for all ordered class pairs. |
| `classifier_images_to_gan_sequences` | Normalize classifier inputs into the sequence format expected by GANs. |
| `gan_sequences_to_classifier_images` | Denormalize generated GAN sequences back to classifier input layout. |
| `append_synthetic_training_data` | Append generated GAN samples to classifier training arrays. |

GANAugmentationResult (dataclass)
GANAugmentationResult(
x_train: ndarray,
y_train: ndarray,
x_synthetic: ndarray,
y_synthetic: ndarray,
history: list[dict[str, float]] = field(default_factory=list),
scale: float = 1.0,
metadata: dict[str, Any] = field(default_factory=dict),
)
Output from a framework-specific GAN augmentation pass.
Attributes:

| Name | Type | Description |
|---|---|---|
| `x_train` | `ndarray` | Training inputs after appending synthetic samples. |
| `y_train` | `ndarray` | Training labels after appending synthetic labels. |
| `x_synthetic` | `ndarray` | Synthetic classifier inputs shaped |
| `y_synthetic` | `ndarray` | Synthetic soft labels shaped |
| `history` | `list[dict[str, float]]` | Per-epoch GAN losses or framework-specific training history. |
| `scale` | `float` | Scale factor used to denormalize generated GAN sequences. |
| `metadata` | `dict[str, Any]` | Extra JSON-serializable details about the augmentation run. |

interpolation_labels
Create GAN interpolation labels for all ordered class pairs.
The routine generates `samples_per_class` points between every ordered pair of
class labels, including same-class pairs, so it covers `n_classes ** 2` pairs in total. The generated labels are soft labels, so they
can be used directly with categorical cross-entropy-style losses.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `n_classes` | `int` | Number of class labels. | `4` |
| `samples_per_class` | `int` | Interpolation points generated for each ordered class pair. | `10` |

Returns:

| Type | Description |
|---|---|
| `ndarray` | Float32 label matrix shaped |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If |

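The interpolation can be pictured with a short NumPy sketch. It assumes linear blending with evenly spaced weights running from the source class's one-hot label to the target class's, and it uses an assumed `ValueError` rule (non-positive arguments) because the documented condition is not spelled out here. `interpolation_labels_sketch` is a hypothetical name, not the module's implementation.

```python
import numpy as np


def interpolation_labels_sketch(n_classes: int = 4, samples_per_class: int = 10) -> np.ndarray:
    """Hypothetical sketch: linear blends between every ordered pair of one-hot labels."""
    # Assumed validation rule; the real function's ValueError condition may differ.
    if n_classes < 1 or samples_per_class < 1:
        raise ValueError("n_classes and samples_per_class must be positive")

    eye = np.eye(n_classes, dtype=np.float32)  # one-hot label for each class
    # Assumption: blend weights are evenly spaced from 0 (pure source) to 1 (pure target).
    weights = np.linspace(0.0, 1.0, samples_per_class, dtype=np.float32)

    rows = []
    for src in range(n_classes):  # ordered pairs, including src == dst
        for dst in range(n_classes):
            # (1 - t) * onehot(src) + t * onehot(dst) for every blend weight t
            rows.append((1.0 - weights)[:, None] * eye[src] + weights[:, None] * eye[dst])
    return np.concatenate(rows, axis=0).astype(np.float32)


labels = interpolation_labels_sketch(n_classes=4, samples_per_class=10)
print(labels.shape)  # (160, 4): 16 ordered pairs x 10 points each
```

Every row sums to 1 because the two blend weights always add up to 1, which is what lets these soft labels feed a categorical cross-entropy loss directly.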
classifier_images_to_gan_sequences
Normalize classifier inputs into the sequence format expected by GANs.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x_train` | `ndarray` | Classifier inputs shaped | required |

Returns:

| Type | Description |
|---|---|
| `ndarray` | First tuple element: normalized sequences shaped |
| `float` | Second tuple element: scale factor required to map generated values back to the classifier scale. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If the input does not use the classifier image layout. |

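A minimal sketch of the normalization step. It assumes the classifier image layout is `(n_samples, n_steps, n_features, 1)` and that values are scaled into `[-1, 1]` by their maximum absolute value; both the layout and the max-abs rule are assumptions, and `classifier_images_to_gan_sequences_sketch` is a hypothetical stand-in.

```python
import numpy as np


def classifier_images_to_gan_sequences_sketch(x_train: np.ndarray) -> tuple[np.ndarray, float]:
    """Hypothetical sketch: drop a trailing channel axis and scale values into [-1, 1]."""
    # Assumed classifier layout: (n_samples, n_steps, n_features, 1).
    if x_train.ndim != 4 or x_train.shape[-1] != 1:
        raise ValueError(f"expected (n, steps, features, 1), got {x_train.shape}")

    sequences = x_train[..., 0].astype(np.float32)  # (n_samples, n_steps, n_features)
    # Assumed normalization: divide by the largest absolute value.
    scale = float(np.max(np.abs(sequences)))
    if scale == 0.0:
        scale = 1.0  # avoid dividing by zero when the input is all zeros
    return sequences / scale, scale


x = np.arange(24, dtype=np.float64).reshape(2, 3, 4, 1) - 5.0
seq, scale = classifier_images_to_gan_sequences_sketch(x)
print(seq.shape, scale)  # (2, 3, 4) 18.0
```

Scaling into `[-1, 1]` suits generators whose output layer is `tanh`, a common GAN choice; the module itself may normalize differently.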
gan_sequences_to_classifier_images
Denormalize generated GAN sequences back to classifier input layout.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sequences` | `ndarray` | GAN outputs shaped | required |
| `scale` | `float` | Scale factor returned by | required |

Returns:

| Type | Description |
|---|---|
| `ndarray` | Float32 classifier inputs shaped |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If |

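The inverse direction can be sketched under assumed layouts: GAN sequences shaped `(n_samples, n_steps, n_features)` and classifier inputs carrying a trailing channel axis of size 1. The positive-scale check is also an assumption, and `gan_sequences_to_classifier_images_sketch` is hypothetical.

```python
import numpy as np


def gan_sequences_to_classifier_images_sketch(sequences: np.ndarray, scale: float) -> np.ndarray:
    """Hypothetical sketch: undo a max-abs normalization and restore the channel axis."""
    # Assumed GAN layout and assumed validation rules.
    if sequences.ndim != 3:
        raise ValueError(f"expected (n, steps, features), got {sequences.shape}")
    if not np.isfinite(scale) or scale <= 0:
        raise ValueError(f"scale must be a positive finite number, got {scale}")

    images = np.asarray(sequences, dtype=np.float32) * np.float32(scale)
    return images[..., np.newaxis]  # (n_samples, n_steps, n_features, 1)


out = gan_sequences_to_classifier_images_sketch(np.array([[[0.5, -1.0]]]), scale=4.0)
print(out.shape)  # (1, 1, 2, 1)
```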
append_synthetic_training_data
append_synthetic_training_data(
x_train: ndarray,
y_train: ndarray,
synthetic_sequences: ndarray,
synthetic_labels: ndarray,
scale: float,
history: list[dict[str, float]] | None = None,
metadata: dict[str, Any] | None = None,
) -> GANAugmentationResult
Append generated GAN samples to classifier training arrays.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x_train` | `ndarray` | Real classifier inputs shaped | required |
| `y_train` | `ndarray` | Real one-hot training labels. | required |
| `synthetic_sequences` | `ndarray` | Generated GAN samples shaped | required |
| `synthetic_labels` | `ndarray` | Soft labels aligned with | required |
| `scale` | `float` | Scale factor used to denormalize generated sequences. | required |
| `history` | `list[dict[str, float]] \| None` | Optional GAN training history. | `None` |
| `metadata` | `dict[str, Any] \| None` | Optional run metadata. | `None` |

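Returns:

| Type | Description |
|---|---|
| `GANAugmentationResult` | Result holding the augmented training arrays, the synthetic samples and labels, the scale factor, and any history and metadata. |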
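An end-to-end sketch of the append step. It assumes denormalization means multiplying by `scale` and adding a trailing channel axis of size 1, and it defines a hypothetical `AugmentationResultSketch` dataclass with the documented `GANAugmentationResult` fields so the block runs on its own. `append_synthetic_training_data_sketch` is likewise hypothetical.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

import numpy as np


@dataclass
class AugmentationResultSketch:
    """Hypothetical stand-in with the documented GANAugmentationResult fields."""
    x_train: np.ndarray
    y_train: np.ndarray
    x_synthetic: np.ndarray
    y_synthetic: np.ndarray
    history: list[dict[str, float]] = field(default_factory=list)
    scale: float = 1.0
    metadata: dict[str, Any] = field(default_factory=dict)


def append_synthetic_training_data_sketch(
    x_train: np.ndarray,
    y_train: np.ndarray,
    synthetic_sequences: np.ndarray,
    synthetic_labels: np.ndarray,
    scale: float,
    history: list[dict[str, float]] | None = None,
    metadata: dict[str, Any] | None = None,
) -> AugmentationResultSketch:
    # Assumed denormalization: rescale, then add a trailing channel axis of size 1.
    x_synthetic = (np.asarray(synthetic_sequences, dtype=np.float32) * np.float32(scale))[..., np.newaxis]
    y_synthetic = np.asarray(synthetic_labels, dtype=np.float32)
    if len(x_synthetic) != len(y_synthetic):
        raise ValueError("synthetic_sequences and synthetic_labels must have the same length")

    return AugmentationResultSketch(
        # Real rows come first; synthetic rows are stacked after them on axis 0.
        x_train=np.concatenate([np.asarray(x_train, dtype=np.float32), x_synthetic], axis=0),
        # One-hot real labels are cast to float32 so they can sit next to soft labels.
        y_train=np.concatenate([np.asarray(y_train, dtype=np.float32), y_synthetic], axis=0),
        x_synthetic=x_synthetic,
        y_synthetic=y_synthetic,
        history=list(history or []),
        scale=float(scale),
        metadata=dict(metadata or {}),
    )
```

The `field(default_factory=...)` defaults matter: a dataclass rejects a bare `list` or `dict` default, and the factory gives each result its own empty history and metadata.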