
Generative Adversarial Networks (GANs) for training on imbalanced data


sinaenjuni/ECOGAN


ECOGAN is a Generative Adversarial Network (GAN) for data generation under class imbalance. It applies a similarity-based metric learning method to learn from imbalanced data.

Data distribution of imbalanced data

Imbalanced data refers to data in which the constituent elements (class, object, scale, etc.) are not uniformly represented. In this experiment, we train on data with a long-tailed distribution, in which the number of samples per class varies widely, as shown in Figure (b).
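
As an illustration, per-class sample counts for such long-tailed datasets are commonly generated with an exponential decay profile. The sketch below assumes this standard construction; the imbalance factor of 100 and the head-class size of 5000 are illustrative choices, not necessarily the settings used in this repository.

```python
def long_tail_counts(n_max: int, num_classes: int, imbalance_factor: float) -> list[int]:
    """Per-class sample counts following the common exponential long-tail profile.

    Class 0 keeps n_max samples; the last class keeps n_max / imbalance_factor.
    """
    return [
        int(n_max * (1.0 / imbalance_factor) ** (i / (num_classes - 1)))
        for i in range(num_classes)
    ]

# e.g. a CIFAR10-LT-style split: 5000 images in the head class, 50 in the tail
print(long_tail_counts(n_max=5000, num_classes=10, imbalance_factor=100))
# [5000, 2997, 1796, 1077, 645, 387, 232, 139, 83, 50]
```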

Schematic diagram of discriminators

The following figure shows schematic diagrams of discriminators previously proposed for imbalanced data learning. BAGAN (a) was the first to point out the problems that arise when learning imbalanced data with generative models, and proposed pre-training with an autoencoder. Unlike BAGAN, IDA-GAN (b) pre-trains with a variational autoencoder and splits the discriminator's single output into two, to alleviate the training conflict between the generator and the discriminator. EBGAN (c) enables class information to be learned during pre-training by multiplying the latent vector with an embedding of the class label. Finally, ours (d) proposes a novel structure that enables cosine-similarity-based contrastive learning methods to be applied to imbalanced data learning.
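
To make structure (d) concrete, here is a minimal, hypothetical sketch of a discriminator that outputs both an adversarial logit and an L2-normalized embedding on which cosine similarity can be computed. The backbone, layer sizes, and input shape are illustrative assumptions, not the repository's actual architecture.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProjectionDiscriminator(nn.Module):
    """Backbone -> (adversarial logit, L2-normalized embedding)."""

    def __init__(self, feat_dim: int = 256, embed_dim: int = 128):
        super().__init__()
        self.backbone = nn.Sequential(              # stand-in for a conv backbone
            nn.Flatten(),
            nn.Linear(3 * 32 * 32, feat_dim),       # assumes 3x32x32 inputs
            nn.LeakyReLU(0.2),
        )
        self.adv_head = nn.Linear(feat_dim, 1)            # real/fake logit
        self.embed_head = nn.Linear(feat_dim, embed_dim)  # for cosine similarity

    def forward(self, x: torch.Tensor):
        h = self.backbone(x)
        return self.adv_head(h), F.normalize(self.embed_head(h), dim=1)
```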

Visualization of learning methods

The figure above is a schematic diagram of the training processes of previously proposed metric learning methods and of methods used in conditional generative models. Unlike those methods, our method (f) uses the relations among all samples within a batch for training, which alleviates the learning imbalance of minority-class data; a sketch of such a batch-wise loss follows.
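
The sketch below shows one way such a batch-wise cosine-similarity objective can look: a supervised-contrastive loss in which every sample is compared against every other sample in the batch, so minority-class samples contribute to many pair terms. It is a sketch in the spirit of the description above, not the exact ECO loss; the temperature value is an illustrative choice.

```python
import torch
import torch.nn.functional as F

def batch_cosine_contrastive_loss(embeddings: torch.Tensor,
                                  labels: torch.Tensor,
                                  temperature: float = 0.1) -> torch.Tensor:
    """Supervised contrastive loss over all pairs in the batch.

    Pulls together embeddings that share a class label and pushes apart the
    rest, using cosine similarity. Every sample is compared with every other
    sample, so minority classes appear in many pair terms.
    """
    z = F.normalize(embeddings, dim=1)      # move to cosine-similarity space
    sim = z @ z.t() / temperature           # (B, B) pairwise similarities
    B = z.size(0)
    self_mask = torch.eye(B, dtype=torch.bool, device=z.device)
    pos_mask = (labels[:, None] == labels[None, :]) & ~self_mask

    # log-softmax over each row, excluding the sample itself
    sim = sim.masked_fill(self_mask, float("-inf"))
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)

    # average log-probability of positives per anchor; anchors with no
    # same-class partner in the batch are skipped
    pos_log_prob = log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1)
    pos_counts = pos_mask.sum(dim=1)
    valid = pos_counts > 0
    return -(pos_log_prob[valid] / pos_counts[valid]).mean()
```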

Experiment result

Experiments were conducted from three angles to compare performance:

  1. Performance comparison with existing metric learning methods
  2. Why hinge-loss-based loss functions struggle to learn imbalanced data
  3. Performance comparison with existing pre-training methods

1. Experiments for Performance Comparison with Existing Metric Learning Methods

Because the existing metric learning methods were proposed for balanced data environments, we first ran experiments on balanced data. We then ran the same experiments on imbalanced data to confirm that our proposed method is more useful for imbalanced data learning than the existing metric learning methods.

The figure below visualizes the evaluation metrics (FID, IS) measured during generator training. The top two rows show training on balanced data; the bottom two rows show training on imbalanced data. On balanced data, our method performs comparably to or better than conventional metric learning. In particular, it performs on par with the D2D-CE loss, which was designed to mitigate the misclassification problem of earlier metric learning losses, confirming that our method is likewise robust to misclassification. On imbalanced data, however, the existing metric learning methods stopped improving due to mode collapse. This confirms that our method remains robust to the misclassification problem in imbalanced data learning and does not exhibit training failures such as mode collapse.

| Method | Data | FID (↓) | IS score (↑) |
| --- | --- | --- | --- |
| 2C [20] | balanced | 6.63 | 9.22 |
| D2D-CE [27] | balanced | 4.71 | 9.76 |
| ECO (ours) | balanced | 4.88 | 9.77 |
| 2C [20] | imbalanced | 29.04 | 6.15 |
| D2D-CE [27] | imbalanced | 42.65 | 5.74 |
| ECO (ours) | imbalanced | 25.53 | 6.56 |

2. Why hinge-loss-based loss functions are disadvantageous for imbalanced data learning

The following figure visualizes the evaluation metrics measured while training neural networks of different sizes with the D2D-CE loss. D2D-CE applies a hinge so that easily classified samples are excluded from learning, focusing training on samples that are hard to classify. With imbalanced data, however, minority classes have far fewer training samples in absolute terms, so the generator can target the still under-trained regions of the discriminator before the discriminator learns the minority classes accurately, and mode collapse occurs early in training.
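
The following sketch isolates the hinge mechanism described above; it is not the exact D2D-CE objective, and the margin of 0.9 is an illustrative choice. Once a sample's positive similarity exceeds the margin, its gradient is exactly zero:

```python
import torch

def hinge_clamped_similarity_loss(pos_sim: torch.Tensor, margin: float = 0.9) -> torch.Tensor:
    """Hinge-style clamp: once a positive cosine similarity exceeds the
    margin, the term (and its gradient) becomes exactly zero, so 'easy'
    samples stop contributing. Majority classes reach the margin quickly;
    minority classes, seen rarely, may never be pushed past it before the
    generator exploits the under-trained regions of the discriminator.
    """
    return torch.clamp(margin - pos_sim, min=0.0).mean()

sims = torch.tensor([0.95, 0.4], requires_grad=True)  # easy vs. hard sample
hinge_clamped_similarity_loss(sims).backward()
print(sims.grad)  # ≈ tensor([0.0, -0.5]) -> the easy sample receives zero gradient
```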

3. Experiments for performance comparison with existing pre-training methods

| Model | Data | Best step | FID (↓) | IS score (↑) | Pre-trained | Sampling |
| --- | --- | --- | --- | --- | --- | --- |
| BAGAN [10] | FashionMNIST_LT | 64000 | 92.61 | 2.81 | TRUE | - |
| EBGAN [12] | FashionMNIST_LT | 120000 | 27.40 | 2.43 | TRUE | - |
| EBGAN [12] | FashionMNIST_LT | 150000 | 30.10 | 2.38 | FALSE | - |
| ECOGAN (ours) | FashionMNIST_LT | 126000 | 32.91 | 2.91 | - | FALSE |
| ECOGAN (ours) | FashionMNIST_LT | 120000 | 20.02 | 2.63 | - | TRUE |
| BAGAN [10] | CIFAR10_LT | 76000 | 125.77 | 2.14 | TRUE | - |
| EBGAN [12] | CIFAR10_LT | 144000 | 60.11 | 2.36 | TRUE | - |
| EBGAN [12] | CIFAR10_LT | 150000 | 68.90 | 2.29 | FALSE | - |
| ECOGAN (ours) | CIFAR10_LT | 144000 | 51.71 | 2.83 | - | FALSE |
| ECOGAN (ours) | CIFAR10_LT | 138000 | 43.79 | 2.74 | - | TRUE |
| EBGAN [12] | Places_LT | 150000 | 136.92 | 2.57 | FALSE | - |
| EBGAN [12] | Places_LT | 144000 | 144.04 | 2.46 | TRUE | - |
| ECOGAN (ours) | Places_LT | 105000 | 91.55 | 3.02 | - | FALSE |
| ECOGAN (ours) | Places_LT | 75000 | 95.43 | 3.01 | - | TRUE |

Visualization of generated data for quality evaluation

Usage

Data preprocessing

Modifying code

Data path

data
└── CIFAR10_LT, FashionMNIST_LT or Places_LT
    ├── train
    │   ├── cls0
    │   │   ├── train0.png
    │   │   ├── train1.png
    │   │   └── ...
    │   ├── cls1
    │   └── ...
    └── valid
        ├── cls0
        │   ├── valid0.png
        │   ├── valid1.png
        │   └── ...
        ├── cls1
        └── ...
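
Assuming the layout above, the data can be loaded with torchvision's ImageFolder, which maps each class subdirectory (cls0, cls1, ...) to an integer label. The transform and batch size below are illustrative, not the repository's actual preprocessing:

```python
import torch
from torchvision import datasets, transforms

# Illustrative preprocessing; the repository's actual settings may differ.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Each class subdirectory becomes one label index.
train_set = datasets.ImageFolder("data/CIFAR10_LT/train", transform=transform)
valid_set = datasets.ImageFolder("data/CIFAR10_LT/valid", transform=transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
```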

Training

Modifying code
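
The repository's training code is not reproduced here; the sketch below is a hypothetical single training step showing how a hinge adversarial loss can be combined with an auxiliary embedding loss such as the contrastive objective sketched earlier. All model, optimizer, and function names are placeholders, not the repository's API.

```python
import torch
import torch.nn.functional as F

def train_step(G, D, opt_g, opt_d, real, labels, z_dim, aux_loss_fn, lam=1.0):
    """One hypothetical update: D returns (logit, embedding); aux_loss_fn is
    an embedding loss such as the batch contrastive sketch above."""
    z = torch.randn(real.size(0), z_dim, device=real.device)
    fake = G(z, labels)                       # assumes a conditional generator

    # --- discriminator update: hinge adversarial loss + auxiliary term ---
    opt_d.zero_grad()
    real_logit, real_emb = D(real)
    fake_logit, _ = D(fake.detach())
    d_loss = (F.relu(1.0 - real_logit).mean()
              + F.relu(1.0 + fake_logit).mean()
              + lam * aux_loss_fn(real_emb, labels))
    d_loss.backward()
    opt_d.step()

    # --- generator update ---
    opt_g.zero_grad()
    fake_logit, fake_emb = D(fake)
    g_loss = -fake_logit.mean() + lam * aux_loss_fn(fake_emb, labels)
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()
```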