
# 🧠 MNIST Classification: Simple Gaussian vs. Gaussian Mixture Model (GMM)



## 📘 Overview

This project implements and compares two parametric probabilistic classifiers for handwritten digit recognition on the MNIST dataset:

  • 🧩 Simple Gaussian Model (Single Gaussian per class)
  • 🔢 Gaussian Mixture Model (GMM) — multiple Gaussians per class using Expectation-Maximization (EM)

Both models are implemented from scratch in Python, focusing on understanding the math behind probabilistic classifiers rather than using pre-built ML libraries.


## ⚙️ Experimental Setup

| Parameter | Description |
| --- | --- |
| Dataset | MNIST (70,000 samples, 784-dimensional feature space, 10 classes) |
| Preprocessing | Principal Component Analysis (PCA), reduced to 40 components |
| Data Split | 70% Training, 30% Testing |
| GMM Configuration | 10 Gaussian components (K = 10) per class |
| Initialization | K-means, refined by the EM algorithm |
| Evaluation Metrics | Accuracy, Error Rate, and ROC Curves (One-vs-Rest) |
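The preprocessing step above can be sketched with plain NumPy (an illustrative sketch, not the repository's exact code; the synthetic `X` stands in for flattened MNIST images):

```python
import numpy as np

def pca_fit_transform(X, n_components=40):
    """From-scratch PCA: project data onto the top principal components."""
    mean = X.mean(axis=0)
    Xc = X - mean
    # Eigendecomposition of the covariance matrix
    eigvals, eigvecs = np.linalg.eigh(np.cov(Xc, rowvar=False))
    # eigh returns eigenvalues in ascending order; take the largest n_components
    W = eigvecs[:, ::-1][:, :n_components]
    return Xc @ W, mean, W

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 784))          # stand-in for flattened 28x28 images
Z, mean, W = pca_fit_transform(X, n_components=40)

# 70% / 30% train/test split on the projected data
idx = rng.permutation(len(Z))
split = int(0.7 * len(Z))
train, test = Z[idx[:split]], Z[idx[split:]]
```

Reducing 784 dimensions to 40 keeps the class-conditional covariance matrices small and well-conditioned, which matters when they are estimated per class (and per mixture component) from limited data.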

## 📊 Performance Summary

| Model | Accuracy | Error Rate |
| --- | --- | --- |
| Simple Gaussian Model | 95.96% | 4.04% |
| Gaussian Mixture Model (GMM) | 96.03% | 3.97% |

**Key Insight:**
Both models perform well, but the GMM slightly outperforms the Simple Gaussian model (96.03% vs. 95.96%) because it can capture multi-modal feature distributions within each digit class, rather than forcing each class into a single Gaussian.


## 📈 ROC Curve Analysis

The Receiver Operating Characteristic (ROC) curves are computed using a one-vs-rest approach for all 10 digit classes.

Both models demonstrate excellent discriminative performance, with ROC curves approaching the top-left corner — signifying high True Positive Rate (TPR) and low False Positive Rate (FPR).

💡 The GMM’s improved performance arises from its ability to represent diverse handwriting variations (e.g., different styles of writing the same digit) through multiple Gaussian components.


## 💻 Implementation Details

### 🔹 Simple Gaussian Model

  • Functions: TrainGaussian(), TestGaussian()
  • Method: Maximum Likelihood Estimation (MLE)
  • Core Function: log_multivariate_pdf() to compute class-conditional log-probabilities

### 🔹 Gaussian Mixture Model (GMM)

  • Functions: train_gaussian_mixture(), test_gaussian_mixture_model()
  • Method: Expectation-Maximization (EM) Algorithm
    • E-Step: Compute responsibilities
    • M-Step: Update means, covariances, and mixture weights
  • Core Class: multivariate_gaussian for component log-PDFs
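The E- and M-steps above can be sketched as follows (an illustrative EM loop for one class's data, initialized from random samples rather than the K-means initialization the project uses; names are not the repository's):

```python
import numpy as np
from scipy.stats import multivariate_normal
from scipy.special import logsumexp

def em_gmm(X, K=10, n_iter=20, reg=1e-6, seed=0):
    """Fit a K-component GMM to one class's data with EM."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    means = X[rng.choice(n, K, replace=False)]
    covs = np.stack([np.cov(X, rowvar=False) + reg * np.eye(d)] * K)
    weights = np.full(K, 1.0 / K)
    for _ in range(n_iter):
        # E-step: responsibilities, computed in log space for stability
        log_r = np.column_stack([
            np.log(weights[k]) + multivariate_normal.logpdf(X, means[k], covs[k])
            for k in range(K)
        ])
        log_r -= logsumexp(log_r, axis=1, keepdims=True)
        r = np.exp(log_r)
        # M-step: re-estimate weights, means, covariances
        Nk = r.sum(axis=0)
        weights = Nk / n
        means = (r.T @ X) / Nk[:, None]
        for k in range(K):
            diff = X - means[k]
            covs[k] = (r[:, k, None] * diff).T @ diff / Nk[k] + reg * np.eye(d)
    return weights, means, covs
```

At test time, the class score sums each component's weighted density (again via `logsumexp`), and the class with the highest total wins.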

### 🔹 ROC Curve Computation

  • Function: roc_curve()
    • Calculates TPR and FPR over multiple thresholds
    • Generates One-vs-Rest ROC plots
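A simplified sketch of the threshold sweep (not the repository's exact `roc_curve()`; `labels` is the one-vs-rest binary vector for the positive class):

```python
import numpy as np

def roc_curve(scores, labels, n_thresholds=100):
    """One-vs-rest ROC: sweep thresholds over the score range, record TPR/FPR."""
    thresholds = np.linspace(scores.min(), scores.max(), n_thresholds)
    P = (labels == 1).sum()            # number of positives
    N = (labels == 0).sum()            # number of negatives
    tpr, fpr = [], []
    for t in thresholds:
        pred = scores >= t
        tpr.append((pred & (labels == 1)).sum() / P)
        fpr.append((pred & (labels == 0)).sum() / N)
    return np.array(fpr), np.array(tpr)
```

Plotting one such (FPR, TPR) curve per digit class gives the ten one-vs-rest ROC plots described above; a curve hugging the top-left corner indicates a model that separates that digit from the rest at almost every threshold.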

## 🧩 Dependencies

Install required dependencies:

```bash
pip install numpy matplotlib scipy
```