Step-by-Step Guide: Image Classification with JAX, Flax, and Optax


Introduction

This tutorial guides you through building, training, and evaluating a CNN model using JAX, Flax, and Optax on the MNIST dataset. It covers environment setup, data preprocessing, CNN architecture definition, training, and testing. You’ll discover how JAX’s numerical efficiency, Flax’s flexible neural networks, and Optax’s advanced optimization tools come together to streamline and enhance deep learning workflows. By the end, you’ll understand how these tools contribute to optimizing deep learning models and improving performance.



Learning Objectives

  • Efficient Neural Network Design: Leverage JAX, Flax, and Optax for seamless integration and powerful model construction.  
  • Dataset Preparation: Master preprocessing and loading datasets using TensorFlow Datasets (TFDS).  
  • Building a CNN: Develop a Convolutional Neural Network (CNN) tailored for accurate image classification.  
  • Training Visualization: Monitor progress with essential metrics like loss and accuracy throughout the training process.  
  • Model Evaluation: Test and apply the trained model on custom images for practical, real-world scenarios.  


JAX, Flax, and Optax: A Powerful Trio

JAX, Flax, and Optax form a powerful trio for modern deep learning workflows. JAX provides high-performance numerical computing with automatic differentiation and seamless GPU/TPU acceleration, making it ideal for efficient model computations. Flax complements JAX with its flexible and modular library, enabling the easy construction of sophisticated neural network architectures. Optax rounds out the trio by offering advanced optimization algorithms, essential for training and fine-tuning deep learning models. Together, these tools integrate seamlessly, simplifying the process of designing, training, and evaluating models while boosting performance for real-world applications.

JAX: The Backbone of Numerical Computing

JAX is a high-performance numerical computing library that combines a NumPy-like syntax with cutting-edge capabilities. It is designed for tasks requiring hardware acceleration and automatic differentiation. Key features, illustrated in the short sketch after this list, include:

  • Autograd: Enables automatic differentiation for complex functions effortlessly.  
  • JIT Compilation: Enhances execution speed on CPUs, GPUs, or TPUs through just-in-time compilation.  
  • Vectorization: Simplifies batch processing with powerful tools like vmap.  
  • Seamless Hardware Integration: Delivers optimized performance on GPUs and TPUs right out of the box.  
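
To make these features concrete, here is a minimal sketch that is not part of the tutorial pipeline; the toy function and values are purely illustrative:

import jax
import jax.numpy as jnp

def square_sum(x):
    return jnp.sum(x ** 2)

grad_fn = jax.grad(square_sum)       # automatic differentiation
fast_grad_fn = jax.jit(grad_fn)      # just-in-time compiled version

x = jnp.arange(3.0)
print(grad_fn(x))                    # [0. 2. 4.], since d/dx sum(x^2) = 2x
print(fast_grad_fn(x))               # same result, compiled for the available backend

# vmap applies the per-example function across a batch without an explicit loop
print(jax.vmap(square_sum)(jnp.ones((4, 3))))   # [3. 3. 3. 3.]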

Flax: A Flexible Neural Network Library

Flax is a JAX-based library that makes building neural networks both intuitive and highly customizable. Its key features, illustrated in the minimal module sketch after this list, include:

  • Stateful Modules: Streamlines the management of parameters and model states.  
  • Compact API: Enables concise and intuitive model definitions using the @nn.compact decorator.  
  • Customizability: Supports a wide range of architectures, from simple to highly complex models.  
  • Seamless JAX Integration: Effortlessly harnesses the full power of JAX’s advanced capabilities.  
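
As a minimal illustration, separate from the CNN built later in this guide, here is a one-layer module defined with @nn.compact; the name TinyModel and the layer width are arbitrary choices for this sketch:

import jax
import jax.numpy as jnp
from flax import linen as nn

class TinyModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(features=4)(x)   # one fully connected layer defined inline

model = TinyModel()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))   # returns {'params': ...}
out = model.apply(variables, jnp.ones((1, 8)))
print(out.shape)   # (1, 4)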

Optax: A Comprehensive Optimization Library

Optax streamlines gradient processing and optimization with a versatile set of features, illustrated in the short sketch after this list:

  • Optimizers: Includes popular algorithms like SGD, Adam, and RMSProp.  
  • Gradient Processing: Provides tools for clipping, scaling, and normalizing gradients.  
  • Modularity: Facilitates seamless composition of gradient transformations and optimizers.  
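
As a small sketch of that modularity (the clipping threshold and learning rate here are arbitrary illustrative values, not the settings used later in this tutorial), gradient clipping can be chained with Adam:

import jax.numpy as jnp
import optax

# Compose gradient transformations: clip the global norm, then apply Adam
tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=1e-3),
)

params = {'w': jnp.ones(3)}
opt_state = tx.init(params)
grads = {'w': jnp.array([10.0, 0.0, 0.0])}              # toy gradients
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)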

Together with JAX and Flax, Optax forms a robust and modular ecosystem for building, training, and fine-tuning deep learning models efficiently.  


Getting Started with JAX: Installation and Setup

To dive into JAX and explore its full potential, the first step is to set it up on your system. Here’s a quick guide to installing JAX and getting started with its powerful features, such as automatic differentiation, vectorization, and high-performance computing on CPUs and GPUs. Once installed, you’ll be ready to leverage all that JAX has to offer.

!pip install --upgrade -q pip jax jaxlib flax optax tensorflow-datasets


This command installs the following libraries (a quick sanity check of the setup follows the list):

  • jax and jaxlib: Used for numerical computations on GPUs/TPUs.
  • flax: A library for neural networks.
  • optax: Provides optimization functions.
  • tensorflow-datasets: Simplifies the process of loading datasets.
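
After installation, an optional sanity check confirms the JAX version and which devices it can see; the exact output depends on your runtime:

import jax

print(jax.__version__)
print(jax.devices())   # lists the available CPU/GPU/TPU devices for this runtime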

Importing Essential Libraries for JAX, Flax, and Optax

To unlock the full potential of JAX, Flax, and Optax, the initial step is to import the essential libraries into your development environment. This section will walk you through the process of importing these crucial libraries, ensuring that you're fully prepared to execute machine learning tasks efficiently. By importing JAX, Flax, and Optax correctly, you'll be setting up the foundation for building high-performance models that utilize advanced features like GPU/TPU acceleration and automatic differentiation. Let’s begin with the essential imports!

import jax
import jax.numpy as jnp               # JAX NumPy

from flax import linen as nn          # The Linen API
from flax.training import train_state
import optax                          # The Optax gradient processing and optimization library

import numpy as np                    # Ordinary NumPy
import tensorflow_datasets as tfds    # TFDS for MNIST


  • JAX: Enables GPU-accelerated computations.
  • Flax: Used for defining and training Convolutional Neural Networks (CNNs).
  • Optax: Provides optimization algorithms like SGD.
  • TFDS: Facilitates loading datasets such as MNIST.
  • Matplotlib: Used later for visualizing training and testing metrics (imported in the visualization section).

Data Preparation: Loading and Preprocessing MNIST

In this section, we will focus on loading and preprocessing the MNIST dataset, a standard dataset widely used in machine learning. The MNIST dataset consists of images of handwritten digits, and by properly preparing it, we ensure the model can effectively learn from the data. We will guide you through importing the dataset, resizing the images, and structuring the data appropriately for both training and evaluation.

def get_datasets():
  ds_builder = tfds.builder('mnist')
  ds_builder.download_and_prepare()
  # Split into training/test sets
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  # Convert images to float32 and scale pixel values to [0, 1]
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.0
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.0
  return train_ds, test_ds

train_ds, test_ds = get_datasets()


We use TFDS to load and preprocess the MNIST dataset:

  • The dataset contains 28×28 grayscale images of digits 0–9.
  • Images are normalized by dividing pixel values by 255, scaling them between 0 and 1. This normalization helps improve convergence during training.

The function returns dictionaries for train_ds and test_ds, each containing the keys 'image' and 'label'.
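
An optional check of the returned arrays confirms the expected MNIST shapes (60,000 training and 10,000 test examples of size 28×28×1):

print(train_ds['image'].shape, train_ds['label'].shape)   # (60000, 28, 28, 1) (60000,)
print(test_ds['image'].shape, test_ds['label'].shape)     # (10000, 28, 28, 1) (10000,)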

Building the Convolutional Neural Network (CNN)

Convolutional Neural Networks (CNNs) are the preferred architecture for image classification tasks due to their ability to automatically learn spatial hierarchies from image data through convolutional layers. In this section, we will build a CNN using the JAX + Flax + Optax stack. We’ll walk through defining the convolutional layers, adding activation functions, and constructing the final output layer for digit recognition on the MNIST dataset.

class CNN(nn.Module):

  # The @nn.compact decorator lets layers and their parameters be defined inline in __call__
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1)) # Flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)    # There are 10 classes in MNIST
    return x


  • Convolutional Layers: Use nn.Conv to extract features from the input and apply nn.relu for non-linearity.  
  • Pooling Layers: Reduce the dimensionality of feature maps with `nn.avg_pool`.  
  • Flatten Layer: Transform the feature maps into a 1D vector for the dense layers.  
  • Dense Layers: Include a fully connected layer with 256 neurons for feature learning, followed by an output layer with 10 neurons for MNIST classification (parameter shapes are checked in the sketch below).
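
As an optional sanity check of the architecture, you can initialize the model with a dummy batch, exactly as the training section does later, and inspect the parameter shapes:

params = CNN().init(jax.random.PRNGKey(0), jnp.ones([1, 28, 28, 1]))['params']
print(jax.tree_util.tree_map(lambda p: p.shape, params))
# Expected kernel shapes: Conv_0 (3, 3, 1, 32), Conv_1 (3, 3, 32, 64),
# Dense_0 (3136, 256), Dense_1 (256, 10)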

Model Evaluation: Metrics and Performance Tracking

To know whether the Convolutional Neural Network (CNN) is learning, we need appropriate metrics. Here we define the loss and accuracy computations used to track performance on both the training and test sets.

def compute_metrics(logits, labels):
  loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, num_classes=10)))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy
  }
  return metrics


To evaluate the model's performance, we define the following metrics:  

  • Loss: Computed using `optax.softmax_cross_entropy`, it quantifies the difference between the predicted and actual labels.  
  • Accuracy: Calculates the fraction of correctly predicted labels using `jnp.argmax`.  

The function returns a dictionary with the keys 'loss' and 'accuracy' for the given batch, as the quick check below illustrates.
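
A quick check with dummy values (all-zero logits and arbitrary labels) illustrates the returned dictionary:

dummy_logits = jnp.zeros((4, 10))
dummy_labels = jnp.array([0, 1, 2, 3])
print(compute_metrics(dummy_logits, dummy_labels))
# loss ≈ 2.3026 (ln 10, i.e. a uniform prediction over 10 classes), accuracy = 0.25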

Training and Evaluation Functions

We implement functions to train the model on the dataset and evaluate its performance. These functions manage the forward pass, compute the loss, perform backpropagation, and monitor the model's accuracy during both training and evaluation.

@jax.jit
def train_step(state, batch):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, batch['image'])
    loss = jnp.mean(optax.softmax_cross_entropy(
        logits=logits,
        labels=jax.nn.one_hot(batch['label'], num_classes=10)))
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits, batch['label'])
  return state, metrics
  
@jax.jit
def eval_step(params, batch):
  logits = CNN().apply({'params': params}, batch['image'])
  return compute_metrics(logits, batch['label'])


Training Step:

  • Loss and Gradients: Computes the loss and gradients with respect to the model parameters using jax.value_and_grad().  
  • Parameter Updates: Updates the model parameters using the optimizer.  
  • Output: Returns the updated state along with metrics to monitor performance.  

Evaluation Step:

  • Model Evaluation: Assesses the model's performance on the provided batch.  
  • Metric Calculation: Computes loss and accuracy using the trained parameters.  

Both functions are JIT-compiled to optimize execution speed and improve performance.

Implementing the Training Loop

We incorporate the training process into an iterative loop that trains the model over multiple epochs. In each iteration, the model is updated using the computed gradients, and performance metrics are monitored to ensure consistent progress toward optimization.

def train_epoch(state, train_ds, batch_size, epoch, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]  # Skip an incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))

  batch_metrics = []

  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state, metrics = train_step(state, batch)
    batch_metrics.append(metrics)

  training_batch_metrics = jax.device_get(batch_metrics)
  training_epoch_metrics = {
      k: np.mean([metrics[k] for metrics in training_batch_metrics])
      for k in training_batch_metrics[0]}

  print('Training - epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch, training_epoch_metrics['loss'], training_epoch_metrics['accuracy'] * 100))

  return state, training_epoch_metrics


- Determines the number of training steps per epoch from the dataset size and batch size.  
- Shuffles the dataset and generates batches using `jax.random.permutation`.  
- Executes `train_step` for each batch to update the model.  
- Computes and logs the average training loss and accuracy at the end of each epoch.  

Evaluate the Model

def eval_model(model, test_ds):
  metrics = eval_step(model, test_ds)
  metrics = jax.device_get(metrics)
  eval_summary = jax.tree.map(lambda x: x.item(), metrics)
  return eval_summary['loss'], eval_summary['accuracy']

- Calculates the loss and accuracy on the test data using eval_step.  
- Returns the evaluation results, including loss and accuracy.  

Executing the Training and Evaluation Process

This step involves running the training loop while evaluating the model's performance at the end of each epoch. By monitoring both training and validation metrics, we ensure that the model is learning effectively and progressing as expected. Additionally, this process helps assess the model's ability to generalize to data it has not encountered before, ensuring robust performance on unseen inputs.

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

cnn = CNN()
params = cnn.init(init_rng, jnp.ones([1, 28, 28, 1]))['params']

learning_rate = 0.001
momentum = 0.9
tx = optax.sgd(learning_rate=learning_rate, momentum=momentum, nesterov=True)

state = train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)

# Initialize lists to store metrics for graph visualization
training_losses = []
training_accuracies = []
testing_losses = []
testing_accuracies = []
num_epochs = 10
batch_size = 64

for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  # Run an optimization step over a training batch
  state, train_metrics = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch
  test_loss, test_accuracy = eval_model(state.params, test_ds)
  print('Testing - epoch: %d, loss: %.2f, accuracy: %.2f' % (epoch, test_loss, test_accuracy * 100))
  # Store metrics for graph visualization
  training_losses.append(train_metrics['loss'])
  training_accuracies.append(train_metrics['accuracy'])
  testing_losses.append(test_loss)
  testing_accuracies.append(test_accuracy)


- RNG Initialization: Initialize a random number generator (`rng`) to ensure reproducibility and enable randomness for data shuffling and parameter initialization.  
- Model Initialization: Define the CNN model and initialize its parameters using a dummy input.  
- Optimizer and Training State:  
  - Use `optax.sgd` as the optimizer with a learning rate of 0.001 and Nesterov momentum of 0.9.  
  - Store the model parameters and optimizer state in TrainState.  
- Training Loop:  
  - Shuffle the training data with a new random key (input_rng).  
  - Train the model for one epoch using `train_epoch`, completing a full pass through the dataset.  
  - Evaluate the model on the test dataset using `eval_model`, which calls eval_step.  

- Print Metrics: Log the test loss and accuracy at the end of each epoch.

Visualizing Training and Testing Metrics

In this step, we visualize the accuracy and loss metrics for both training and testing over time. This allows us to identify trends, detect potential issues such as overfitting or underfitting, and evaluate the model's overall performance throughout the training process.

import matplotlib.pyplot as plt
# Graph visualization for training/testing loss and accuracy
epochs = range(1, num_epochs + 1)

plt.figure(figsize=(14, 5))

# Plot for Loss
plt.subplot(1, 2, 1)
plt.plot(epochs, training_losses, label='Training Loss', marker='o')
plt.plot(epochs, testing_losses, label='Testing Loss', marker='o')
plt.title('Loss Over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
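
The snippet above draws only the loss panel; a matching accuracy panel, sketched below from the metric lists collected during training, completes the figure before it is displayed:

# Plot for Accuracy
plt.subplot(1, 2, 2)
plt.plot(epochs, training_accuracies, label='Training Accuracy', marker='o')
plt.plot(epochs, testing_accuracies, label='Testing Accuracy', marker='o')
plt.title('Accuracy Over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()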

Graph visualization of training and testing loss and accuracy over epochs.


Predicting Custom Images

Next, we demonstrate how to use the trained model to make predictions on custom images. This step helps evaluate the model's performance on unseen data and assess its ability to generalize to new, real-world scenarios.

from google.colab import files
from PIL import Image
import numpy as np

# Step 1: Upload the image file
uploaded = files.upload()

# Step 2: Process the uploaded image
def load_and_preprocess_image(file_path):
    img = Image.open(file_path).convert('L')  # Convert to grayscale
    img = img.resize((28, 28))               # Resize to 28x28
    img = np.array(img) / 255.0              # Normalize pixel values to [0, 1]
    img = img.reshape(1, 28, 28, 1)          # Add batch and channel dimensions
    return img

# Step 3: Load and preprocess each uploaded image
for file_name in uploaded.keys():
    test_image = load_and_preprocess_image(file_name)
    print(f"Processed image from {file_name}.")


import jax.numpy as jnp

# Convert to JAX array
test_image_jax = jnp.array(test_image, dtype=jnp.float32)

# Step 4: Use your trained model for predictions
logits = state.apply_fn({'params': state.params}, test_image_jax)
prediction = jnp.argmax(logits, axis=-1)
print(f"Predicted class: {prediction[0]}")



# Display the uploaded image
plt.imshow(test_image[0].squeeze(), cmap='gray')
plt.title(f"Predicted Class: {prediction[0]}")
plt.axis('off')
plt.show()


Uploading Images

  • Uploading Images: Begin by uploading custom handwritten digit images.  
  • File Upload Interface: Use the `files.upload()` function in the Colab environment to open an interface for uploading files.  
  • Image Selection: Select one or more images from the local machine in supported formats such as PNG or JPG.  
  • Access Uploaded Files: Once uploaded, the images become accessible for further processing in the code.  

Preprocessing

After uploading, each image is preprocessed to match the input format the model expects.

  • Convert to Grayscale: The image is converted to grayscale using `Image.convert('L')`, as MNIST images are single-channel.
  • Resize to 28×28 Pixels: The image is resized to the standard MNIST dimensions using `Image.resize((28, 28))`.
  • Normalize Pixel Values: The pixel values are scaled to the range [0, 1] by dividing by 255.0 to ensure consistent input.
  • Reshape for Model Input: The image is reshaped into a tensor with dimensions [1, 28, 28, 1], which includes the batch size and channel dimensions.

from PIL import Image
import numpy as np

# Function to preprocess the image
def preprocess_image(image_path):
    # Open the image file
    img = Image.open(image_path)
    
    # Convert to grayscale
    img = img.convert('L')
    
    # Resize to 28x28 pixels
    img = img.resize((28, 28))
    
    # Normalize pixel values to the range [0, 1]
    img_array = np.array(img) / 255.0
    
    # Reshape the image for model input (batch size, height, width, channels)
    img_array = img_array.reshape((1, 28, 28, 1))
    
    return img_array

# Example usage:
image_path = 'path_to_your_image.jpg'
processed_image = preprocess_image(image_path)

Prediction

The preprocessed image is converted into a JAX-compatible array (jnp.array), optimizing it for efficient computation. We then pass this array through the trained model using the apply_fn function, which generates the logits (raw output scores for each class). Finally, we use jnp.argmax to identify the index of the highest logit value, corresponding to the class with the highest confidence.

import jax
import jax.numpy as jnp

# Function to make a prediction
def predict(params, preprocessed_image, apply_fn):
    # Convert the preprocessed image to a JAX-compatible array
    image_array = jnp.array(preprocessed_image, dtype=jnp.float32)

    # Pass the image through the model to get the logits (raw scores)
    logits = apply_fn({'params': params}, image_array)

    # Use jnp.argmax to get the class with the highest logit value
    predicted_class = jnp.argmax(logits, axis=-1)

    return predicted_class

# Example usage with the training state created earlier:
predicted_class = predict(state.params, processed_image, state.apply_fn)

print(f"Predicted class: {predicted_class[0]}")

Visualization

  • The processed image is displayed using Matplotlib to offer a visual reference for the user.  
  • The predicted class is shown as the image’s title, making it easy to interpret the results.  
  • This visualization step helps validate the model's predictions and enhances the user-friendliness of the classification process.

import matplotlib.pyplot as plt

# Function to display the image with predicted class
def display_prediction(image, predicted_class):
    # Plot the image
    plt.imshow(image.squeeze(), cmap='gray')  # Remove the extra dimension
    plt.title(f"Predicted Class: {predicted_class}")
    plt.axis('off')  # Hide the axes for a cleaner view
    plt.show()

# Example usage:
# `processed_image` is the preprocessed image and `predicted_class` is the output of predict()
display_prediction(processed_image, int(predicted_class[0]))

Conclusion

This step-by-step guide highlighted the power and versatility of JAX, Flax, and Optax in building a robust deep learning pipeline for image classification. By utilizing their strengths—such as efficient hardware acceleration, modular design, and advanced optimization techniques—we were able to train a Convolutional Neural Network (CNN) on the dataset with ease. The integration with TensorFlow Datasets (TFDS) streamlined data loading and preprocessing, while visualizing performance metrics provided valuable insights into the model's effectiveness.

The pipeline culminated in testing the model on custom images, demonstrating its practical application. This approach is not only scalable for more complex datasets but also provides a solid foundation for exploring advanced deep learning methods.


Key Takeaways

  • JAX, Flax, and Optax offer powerful tools for efficiently building and optimizing deep learning models.  
  • Data preprocessing and augmentation play a crucial role in improving model performance on real-world datasets.  
  • Convolutional Neural Networks (CNNs) are highly effective for image classification tasks, such as MNIST.  
  • Evaluating model performance with the right metrics helps monitor progress and pinpoint areas for improvement.  
  • Visualizing both training and testing metrics provides valuable insights into the model's behavior and development throughout the training process.