Train a PyTorch Model on Fashion MNIST: Jupyter Notebook

This document describes how to run a job on your cluster that distributes the training workload across multiple workers using Ray's distributed computing capabilities. Parallelizing training in this way can reduce the overall training time. The instructions below run the Train a PyTorch Model on Fashion MNIST job from a Jupyter Notebook.

For more information about Jupyter Notebook, see the Jupyter documentation.
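
To make the idea of distributing the workload concrete, the sketch below shows one simple way a dataset can be dealt out across workers, and how many training steps each worker then runs per epoch. It is an illustration only, not how Ray or PyTorch implement it: the helper names `shard_indices` and `steps_per_epoch` are hypothetical, and real distributed samplers typically also shuffle and pad the data. The numbers in the last line assume the 60,000-image Fashion MNIST training split, 2 workers, and a per-worker batch size of 16, which matches the settings used later in this guide.

```python
from typing import List


def shard_indices(num_samples: int, world_size: int, rank: int) -> List[int]:
    """Sample indices handled by worker `rank` when samples are dealt out round-robin."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    return list(range(rank, num_samples, world_size))


def steps_per_epoch(num_samples: int, world_size: int, batch_size_per_worker: int) -> int:
    """Batches each worker processes per epoch, counting a final partial batch."""
    per_worker = -(-num_samples // world_size)  # ceiling division
    return -(-per_worker // batch_size_per_worker)


# Toy dataset of 10 samples split across 2 workers
print(shard_indices(10, 2, 0))  # [0, 2, 4, 6, 8]
print(shard_indices(10, 2, 1))  # [1, 3, 5, 7, 9]

# Fashion MNIST training split with 2 workers and 16 samples per worker batch
print(steps_per_epoch(60_000, 2, 16))  # 1875
```

Each worker sees only its own share of the data, so adding workers shortens the number of steps per epoch while the per-worker batch size stays the same.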

Steps to run a test job in Jupyter Notebook:

  1. After your cluster deployment is complete, go to View Cluster.

  2. On the cluster detail page, copy the IDE Password and click Jupyter Notebook.

  3. Enter the IDE Password you copied in the Jupyter password field.

  4. Click File to create a new Python notebook.

  5. In the New dropdown, select Notebook.

  6. The notebook opens in a new browser tab and prompts you to select a kernel. Choose Python 3 for this example, then click Select.

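  7. Enter the code sample below into a cell and click Run. This cell only defines the data loaders, the model, and the training functions; training starts in the next step.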

    import os
    from typing import Dict
    
    import torch
    from filelock import FileLock
    from torch import nn
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms
    from torchvision.transforms import Normalize, ToTensor
    from tqdm import tqdm
    
    import ray.train
    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer
    
    
    def get_dataloaders(batch_size):
        transform = transforms.Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    
        with FileLock(os.path.expanduser("~/data.lock")):
            training_data = datasets.FashionMNIST(
                root="~/data",
                train=True,
                download=True,
                transform=transform,
            )
    
            test_data = datasets.FashionMNIST(
                root="~/data",
                train=False,
                download=True,
                transform=transform,
            )
    
        train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
        test_dataloader = DataLoader(test_data, batch_size=batch_size)
    
        return train_dataloader, test_dataloader
    
    
    class NeuralNetwork(nn.Module):
        def __init__(self):
            super(NeuralNetwork, self).__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(28 * 28, 512),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(512, 512),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(512, 10),
                nn.ReLU(),
            )
    
        def forward(self, x):
            x = self.flatten(x)
            logits = self.linear_relu_stack(x)
            return logits
    
    
    def train_func_per_worker(config: Dict):
        lr = config["lr"]
        epochs = config["epochs"]
        batch_size = config["batch_size_per_worker"]
    
        train_dataloader, test_dataloader = get_dataloaders(batch_size=batch_size)
    
        train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)
        test_dataloader = ray.train.torch.prepare_data_loader(test_dataloader)
    
        model = NeuralNetwork()
    
        model = ray.train.torch.prepare_model(model)
    
        loss_fn = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    
        # Model training loop
        for epoch in range(epochs):
            if ray.train.get_context().get_world_size() > 1:
                train_dataloader.sampler.set_epoch(epoch)
    
            model.train()
            for X, y in tqdm(train_dataloader, desc=f"Train Epoch {epoch}"):
                pred = model(X)
                loss = loss_fn(pred, y)
    
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
    
            model.eval()
            test_loss, num_correct, num_total = 0, 0, 0
            with torch.no_grad():
                for X, y in tqdm(test_dataloader, desc=f"Test Epoch {epoch}"):
                    pred = model(X)
                    loss = loss_fn(pred, y)
    
                    test_loss += loss.item()
                    num_total += y.shape[0]
                    num_correct += (pred.argmax(1) == y).sum().item()
    
            test_loss /= len(test_dataloader)
            accuracy = num_correct / num_total
    
            ray.train.report(metrics={"loss": test_loss, "accuracy": accuracy})
    
    
    def train_fashion_mnist(num_workers=2, use_gpu=False):
        global_batch_size = 32
    
        train_config = {
            "lr": 1e-3,
            "epochs": 10,
            "batch_size_per_worker": global_batch_size // num_workers,
        }
    
        # Configure computation resources
        scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
    
        # Initialize a Ray TorchTrainer
        trainer = TorchTrainer(
            train_loop_per_worker=train_func_per_worker,
            train_loop_config=train_config,
            scaling_config=scaling_config,
        )
    
        result = trainer.fit()
        print(f"Training result: {result}")
    
  8. Enter the Python command below in a new cell to start the training script, then click Run.

    train_fashion_mnist(num_workers=2, use_gpu=True)
    

    Note: This command starts 2 training workers and, because use_gpu=True, requests a GPU for each of them. Make sure your cluster has enough CPU and GPU capacity available. If it does not, change num_workers or set use_gpu=False.

  9. Scroll to the bottom of the output to see the training result, which looks similar to the following:

    Training result: Result(
      metrics={'loss': 0.3572742183404133, 'accuracy': 0.8728},
      path='/home/ray/ray_results/TorchTrainer_2024-05-17_18-55-55/TorchTrainer_c3725_00000_0_2024-05-17_18-55-55',
      filesystem='local',
      checkpoint=None
    )
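
Before choosing num_workers and use_gpu in step 8, it can help to estimate how many workers the cluster can host. The sketch below is illustrative only: `max_workers` and the sample resource dictionary are hypothetical, and it assumes each worker reserves one unit of a single resource (one GPU when use_gpu is True, one CPU otherwise), which is a simplification. On a live cluster, `ray.cluster_resources()` returns a dictionary of this general shape once Ray is initialized.

```python
from typing import Mapping


def max_workers(resources: Mapping[str, float], use_gpu: bool, per_worker: float = 1.0) -> int:
    """Upper bound on workers if each one reserves `per_worker` units of GPU (or CPU)."""
    key = "GPU" if use_gpu else "CPU"
    return int(resources.get(key, 0) // per_worker)


# Hypothetical snapshot; on the cluster, replace it with ray.cluster_resources()
example_resources = {"CPU": 16.0, "GPU": 2.0}

print(max_workers(example_resources, use_gpu=True))   # 2
print(max_workers(example_resources, use_gpu=False))  # 16
```

If the estimate is lower than the num_workers value you planned to pass, reduce num_workers or run with use_gpu=False.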
    

Congratulations on Successfully Training Your First Model

You can now track your model's progress using the Ray Dashboard. The dashboard provides detailed insights into your cluster, including cluster utilization, status, autoscaler activity, resource states, and more.

  1. Return to your cluster. On the cluster detail page, copy the IDE Password and click Ray Dashboard.

  2. In the password field, enter the IDE Password you copied. Click View All Jobs to see your job and its status.

  3. You can also check this in io.net by going to Clusters > select your cluster > click an IO Worker > Jobs.

Troubleshooting Model Training

  1. If the output after running the example code contains messages like the ones below, the installed ipywidgets package is outdated:

    2025-05-15 01:39:02,503	INFO util.py:154 -- Outdated packages:
      ipywidgets==7.7.2 found, needs ipywidgets>=8
    Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
    2025-05-15 01:39:03,181	INFO util.py:154 -- Outdated packages:
      ipywidgets==7.7.2 found, needs ipywidgets>=8
    Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
    2025-05-15 01:39:03,219	INFO util.py:154 -- Outdated packages:
      ipywidgets==7.7.2 found, needs ipywidgets>=8
    Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
    
  2. Copy the update command for outdated packages, paste it into a new cell, and click the Run button to install the updates:

    pip install -U ipywidgets
    
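  3. In the toolbar, click Kernel and select Restart the kernel from the dropdown. Restarting the kernel loads the updated packages.

  4. Restarting the kernel clears all definitions, so run the code cell from step 7 again, then run the train_fashion_mnist command from step 8 to execute the script.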
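
To confirm that the upgrade took effect, you can check the installed ipywidgets version from a notebook cell. The sketch below is a minimal check, not part of the original example: the helper names `installed_version` and `needs_upgrade` are hypothetical, and the minimum major version of 8 is taken from the log message above.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def needs_upgrade(version: Optional[str], minimum_major: int = 8) -> bool:
    """True if the package is missing or its major version is below `minimum_major`."""
    if version is None:
        return True
    major = int(version.split(".")[0])
    return major < minimum_major


current = installed_version("ipywidgets")
print(f"ipywidgets: {current}, upgrade needed: {needs_upgrade(current)}")
```

If the check still reports that an upgrade is needed after the install, the kernel is probably still using the old package and should be restarted.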