ogulcanakca/Improving-Classification-Success-with-GAN

Improving Image Classification Performance Using GAN-Based Data Augmentation

Introduction and Purpose

The main objective of this project is to investigate whether synthetic images generated by a Generative Adversarial Network (GAN) can improve classification performance when they are used for data augmentation in an image classification task with limited data. Specifically, a WGAN-GP (Wasserstein GAN with Gradient Penalty) model was used to generate the synthetic images, and the Pseudo-Labeling technique was used to assign labels to them.
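For reference, the standard WGAN-GP objective (Gulrajani et al., 2017) is given below. The README does not state the hyperparameters used in this project, so the usual penalty weight λ = 10 from the original formulation is only an assumption here. D is the critic, G the Generator, P_r the real-data distribution and P_g the distribution of generated images.

```latex
\begin{aligned}
L_D &= \mathbb{E}_{\tilde{x}\sim\mathbb{P}_g}\!\left[D(\tilde{x})\right]
     - \mathbb{E}_{x\sim\mathbb{P}_r}\!\left[D(x)\right]
     + \lambda\,\mathbb{E}_{\hat{x}\sim\mathbb{P}_{\hat{x}}}\!\left[\left(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 - 1\right)^2\right] \\
\hat{x} &= \epsilon\,x + (1-\epsilon)\,\tilde{x}, \qquad \epsilon \sim U[0,1] \\
L_G &= -\,\mathbb{E}_{z\sim p(z)}\!\left[D(G(z))\right]
\end{aligned}
```

The penalty term pushes the critic's gradient norm toward 1 at points interpolated between real and generated samples. This is a soft version of the 1-Lipschitz constraint that the Wasserstein formulation requires.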

Dataset

  • Source: The CIFAR-10 dataset was used.
  • Preparation: To simulate a limited-data scenario, 3 classes ('airplane', 'automobile', 'bird') were selected from CIFAR-10, and 1000 examples were drawn at random from each class, giving a limited training set of 3000 examples. The test set contains all CIFAR-10 test examples from these 3 classes (1000 per class, 3000 in total).
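The subset construction can be sketched with NumPy as follows. This is a minimal sketch, not the notebook's code. The function names and the random seed are hypothetical, since the README does not give a seed. It relies on the standard CIFAR-10 label order, in which airplane, automobile and bird are labels 0, 1 and 2, so no label remapping is needed.

```python
import numpy as np

# Standard CIFAR-10 label indices: 0 = airplane, 1 = automobile, 2 = bird.
SELECTED_CLASSES = [0, 1, 2]


def limited_train_indices(labels, classes, per_class=1000, seed=0):
    """Sample `per_class` indices per class, without replacement.

    The seed is an assumption; the original experiment's seed is not stated.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    picked = [
        rng.choice(np.flatnonzero(labels == c), size=per_class, replace=False)
        for c in classes
    ]
    return np.concatenate(picked)


def full_test_indices(labels, classes):
    """Indices of every example whose label is one of `classes`."""
    return np.flatnonzero(np.isin(np.asarray(labels), classes))
```

With torchvision, `labels` could be the `targets` attribute of `torchvision.datasets.CIFAR10(...)` for the train and test splits, and the returned indices could be passed to `torch.utils.data.Subset`.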

Methodology

The project includes the following steps:

  1. Baseline Model: A simple Convolutional Neural Network (SimpleCNN) was trained on the limited training set (3000 examples), and its performance was measured on the test set.
  2. WGAN-GP Training: A WGAN-GP model was trained on the same limited training set (3000 examples) to learn the distribution of this data.
  3. Synthetic Data Generation: The Generator of the trained WGAN-GP was used to create new synthetic images (approximately 1500 in this experiment; they were filtered in the next step).
  4. Pseudo-Labeling: Because earlier attempts to label the generated images directly or with simple methods had failed, pseudo-labeling was applied at this stage:
    • The synthetic images were fed into the baseline classifier, which had been trained only on real data.
    • Predictions of the baseline model with a confidence score above 95% were accepted as correct and assigned as the label (pseudo-label) of the synthetic image.
    • Synthetic images with a confidence score below 95% were discarded. After this filtering, 1255 high-confidence, pseudo-labeled synthetic images remained.
  5. Training with Augmented Dataset: The original 3000 real images and 1255 pseudo-labeled synthetic images were combined to create an augmented training set of 4255 examples in total.
  6. Final Classifier: A new model with exactly the same CNN architecture as the baseline model was trained from scratch on this augmented dataset.
  7. Evaluation: The performance of the final classifier trained with augmented data was measured on the same test set used to evaluate the baseline model, and the results were compared.
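Steps 3 to 5 can be sketched in NumPy as shown here. This is a minimal sketch under stated assumptions, not the project's PyTorch code. `generator` is a hypothetical callable that maps a `(batch, latent_dim)` array of Gaussian noise to images. `synth_logits` stands for the baseline classifier's raw outputs on the synthetic images. The latent dimension and batch size are not given in the README. In PyTorch, the same filter would typically use `torch.softmax(logits, dim=1)` followed by `.max(dim=1)`.

```python
import numpy as np


def generate_samples(generator, n, latent_dim, batch_size=256, seed=0):
    """Draw n synthetic images from `generator` using batches of Gaussian latent vectors."""
    rng = np.random.default_rng(seed)
    batches = []
    remaining = n
    while remaining > 0:
        size = min(batch_size, remaining)
        z = rng.standard_normal((size, latent_dim))
        batches.append(generator(z))
        remaining -= size
    return np.concatenate(batches, axis=0)


def softmax(logits):
    """Row-wise softmax; subtracting the row max keeps exp() from overflowing."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def pseudo_label(logits, threshold=0.95):
    """Return (keep_mask, labels) for samples whose top class probability is above threshold."""
    probs = softmax(logits)
    keep = probs.max(axis=1) > threshold
    labels = probs.argmax(axis=1)[keep]
    return keep, labels


def build_augmented_set(real_x, real_y, synth_x, synth_logits, threshold=0.95):
    """Append the synthetic images that pass the confidence filter to the real data."""
    keep, synth_y = pseudo_label(synth_logits, threshold)
    x = np.concatenate([real_x, synth_x[keep]], axis=0)
    y = np.concatenate([real_y, synth_y], axis=0)
    return x, y
```

With 3000 real images and roughly 1500 generated ones, the size of the returned set is 3000 plus the number of synthetic images that pass the filter (1255 in this experiment, giving 4255). The strict `>` comparison matches the "above 95%" rule.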

Results

The performance of the Baseline and WGAN-GP (with Pseudo-Labeling) augmented models on the test set is as follows:

Metric | Baseline Model | Augmented Model (WGAN-GP + Pseudo-Labeling) | Difference
--- | --- | --- | ---
Accuracy | 85.90% | 84.77% | -1.13 percentage points
F1 Score (weighted) | 0.8586 | 0.8474 | -0.0112

Classification Reports Summary:

  • Baseline Model:
    • Airplane: Precision=0.84, Recall=0.81, F1=0.83
    • Automobile: Precision=0.89, Recall=0.92, F1=0.90
    • Bird: Precision=0.85, Recall=0.84, F1=0.85
    • Overall Accuracy: 85.90%
  • Augmented Model (WGAN-GP + Pseudo-Labeling):
    • Airplane: Precision=0.82, Recall=0.80, F1=0.81
    • Automobile: Precision=0.88, Recall=0.90, F1=0.89
    • Bird: Precision=0.84, Recall=0.84, F1=0.84
    • Overall Accuracy: 84.77%

(See the relevant Kaggle Notebook for detailed classification reports and confusion matrices.)
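As a consistency check, the per-class figures above can be recombined into the headline metrics. Support-weighted recall always equals accuracy, because each class recall is that class's correct predictions divided by its support. The test split has 1000 images per class, so the weighted averages reduce to plain means. The small gaps from the table values come from rounding the per-class numbers to two decimals. The helper name is illustrative.

```python
def weighted_average(values, supports):
    """Support-weighted mean, as used by scikit-learn's average='weighted'."""
    total = sum(supports)
    return sum(v * s for v, s in zip(values, supports)) / total


# CIFAR-10 test split: 1000 images per class (airplane, automobile, bird).
SUPPORT = [1000, 1000, 1000]

reports = {
    "baseline": {"f1": [0.83, 0.90, 0.85], "recall": [0.81, 0.92, 0.84]},
    "augmented": {"f1": [0.81, 0.89, 0.84], "recall": [0.80, 0.90, 0.84]},
}

for name, r in reports.items():
    f1_w = weighted_average(r["f1"], SUPPORT)
    # Support-weighted recall equals overall accuracy.
    acc = weighted_average(r["recall"], SUPPORT)
    print(f"{name}: weighted F1 ~ {f1_w:.4f}, accuracy ~ {acc:.4f}")
# baseline: weighted F1 ~ 0.8600, accuracy ~ 0.8567
# augmented: weighted F1 ~ 0.8467, accuracy ~ 0.8467
```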

Discussion and Conclusions

  • In this experiment, augmenting the training set with WGAN-GP images filtered by high-confidence pseudo-labeling did not surpass the baseline model, but it came close (about 1.1 percentage points lower in accuracy).
  • Pseudo-labeling worked substantially better than the earlier, simpler labeling strategies. This highlights how important it is to assign accurate labels to synthetic data.
  • The relatively high baseline performance (around 86%) may have limited the additional benefit that data augmentation could provide for this specific task and model.
  • Limited quality or diversity of the WGAN-GP images, or residual errors in the pseudo-labels, may have prevented the augmented model from surpassing the baseline.
  • In conclusion, GAN-based data augmentation combined with a careful labeling strategy is promising, but it does not guarantee improved performance in every case. Its success depends on the quality of the generated data, labeling accuracy, task difficulty, dataset characteristics, and the capacity of the classifier. This project demonstrates the GAN-based data augmentation process in practice, along with the challenges that may arise.

Environment and Libraries

  • Platform: Kaggle Notebooks (with GPU)
  • Main Libraries: PyTorch, NumPy, Matplotlib, Scikit-learn, Pretty-MIDI (listed for the overall project; not used in the GAN part), and OpenCV (if used). The WGAN-GP and CNN models were implemented from scratch in PyTorch.

How to Run

Open the relevant Kaggle Notebook and run the cells in order. The required datasets (CIFAR-10 and the relevant part of LMD) must be added to the Kaggle environment.

About

An Experiment to Improve Image Classification Performance Using Data Augmentation with GAN
