Preparation

Dataset
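
The snippets in this section assume the following imports (a minimal set, using the standard PyTorch / torchvision module paths):

from typing import Any, Callable, Dict, List, Tuple

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm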

The dataset function builds and returns the data loaders for a given seed and config:

def mnist_dataset_fn(seed: int, **config) -> Dict:
    """
    Returns data loaders for the given config
    Args:
        seed: random seed that will make shuffling and other random operations deterministic
    Returns:
        data_loaders: containing "train", "validation" and "test" data loaders
    """
    np.random.seed(seed)
    torch.manual_seed(seed)  # DataLoader shuffling uses torch's RNG, so seed it as well

    transform_list = [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),  # (mean,std) of MNIST train set
    ]
    if config.get("apply_augmentation"):
        transform_list = [transforms.RandomHorizontalFlip()] + transform_list
    transform = transforms.Compose(transform_list)
    train_dataset = datasets.MNIST("../data", train=True, download=True, transform=transform)
    validation_dataset = datasets.MNIST(
        "../data", train=False, download=True, transform=transform
    )  # for simplicity, we use the test set for validation
    test_dataset = datasets.MNIST("../data", train=False, download=True, transform=transform)
    batch_size = config.get("batch_size", 64)
    return {
        "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=config.get("shuffle", True)),
        "validation": DataLoader(validation_dataset, batch_size=batch_size),
        "test": DataLoader(test_dataset, batch_size=batch_size),
    }
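
As a quick sanity check, the loaders can be built and inspected directly (a minimal sketch; batch_size is one of the config keys the function reads):

loaders = mnist_dataset_fn(seed=42, batch_size=128)
images, targets = next(iter(loaders["train"]))
print(images.shape, targets.shape)  # torch.Size([128, 1, 28, 28]) torch.Size([128])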

Model

We define a simple two-layer neural network with a flexible hidden size h_dim and a ReLU non-linearity.

class MNISTModel(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, h_dim: int = 5) -> None:
        super().__init__()

        self.fc1 = nn.Linear(in_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, out_dim)
        self.nl = nn.ReLU()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.nl(self.fc1(x))
        return self.softmax(self.fc2(x))
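
A quick shape check of the forward pass (a minimal sketch with a dummy batch):

model = MNISTModel(in_dim=784, out_dim=10)  # 784 = 28 * 28 flattened MNIST pixels
x = torch.randn(8, 784)  # dummy batch of 8 flattened images
log_probs = model(x)
print(log_probs.shape)  # torch.Size([8, 10]); each row holds log-probabilities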

Initializing this model for a given config is then done in mnist_model_fn.

def mnist_model_fn(dataloaders: Dict, seed: int, **config) -> torch.nn.Module:
    """
    Builds a model object for the given config
    Args:
        dataloaders: a dictionary of data loaders
        seed: random seed (e.g. for model initialization)
    Returns:
        Instance of torch.nn.Module
    """
    # infer the input dimension from the first batch; MNIST has 10 classes
    first_inputs, _ = next(iter(dataloaders["train"]))
    in_dim = int(np.prod(first_inputs.shape[1:]))  # e.g. 1 * 28 * 28 = 784
    out_dim = 10

    torch.manual_seed(seed)  # for reproducibility (almost)
    model = MNISTModel(in_dim, out_dim, h_dim=config.get("h_dim", 5))

    return model
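
Chaining the two functions then looks like this (a minimal sketch; h_dim=16 is an arbitrary choice):

loaders = mnist_dataset_fn(seed=42, batch_size=64)
model = mnist_model_fn(loaders, seed=42, h_dim=16)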

Trainer

Finally, we define the trainer, which receives the model and data loaders and executes the actual training.

class MNISTTrainer:
    def __init__(
        self,
        model: nn.Module,
        dataloaders: Dict,
        seed: int,
        epochs: int = 5,
    ) -> None:

        self.model = model
        self.trainloader = dataloaders["train"]
        self.seed = seed
        self.epochs = epochs
        self.loss_fn = nn.NLLLoss()
        self.optimizer = optim.Adam(self.model.parameters())

    def train_loop(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[int, int]:
        # forward:
        self.optimizer.zero_grad()
        x_flat = x.flatten(1, -1)  # treat the images as flat vectors
        log_probs = self.model(x_flat)  # the model outputs log-probabilities (LogSoftmax)
        loss = self.loss_fn(log_probs, y)  # NLLLoss expects log-probabilities
        # backward:
        loss.backward()
        self.optimizer.step()
        # keep track of accuracy:
        _, predicted = log_probs.max(1)
        predicted_correct = predicted.eq(y).sum().item()
        total = y.shape[0]
        return predicted_correct, total

    def train(self) -> Tuple[float, Tuple[List[float], int], Dict]:
        if hasattr(tqdm, "_instances"):
            tqdm._instances.clear()  # To have tqdm output without line-breaks between steps
        torch.manual_seed(self.seed)
        accs = []
        for epoch in range(self.epochs):
            predicted_correct = 0
            total = 0
            for x, y in tqdm(self.trainloader):
                p, t = self.train_loop(x, y)
                predicted_correct += p
                total += t
            accs.append(100.0 * predicted_correct / total)  # accuracy of this epoch in percent

        return accs[-1], (accs, self.epochs), self.model.state_dict()

The corresponding trainer function sets up the training, executes it, and finally returns the score, the output, and the model state.

def mnist_trainer_fn(
    model: torch.nn.Module, dataloaders: Dict, seed: int, uid: Dict, cb: Callable, **config
) -> Tuple[float, Any, Dict]:
    """
    Args:
        model: initialized model to train
        dataloaders: containing "train", "validation" and "test" data loaders
        seed: random seed
        uid: database keys that uniquely identify this trainer call
        cb: callback function to ping the database and potentially save the checkpoint
    Returns:
        score: performance score of the model
        output: user-specified validation object based on the 'stop function'
        model_state: the full state_dict() of the trained model
    """
    trainer = MNISTTrainer(model, dataloaders, seed, epochs=config.get("epochs", 2))
    out = trainer.train()

    return out
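
End to end, the three functions compose as follows (a minimal sketch: uid and cb are placeholders here, since no database is involved):

def noop_cb(**kwargs):
    # placeholder callback; a real callback would ping the database / save checkpoints
    pass

loaders = mnist_dataset_fn(seed=42, batch_size=64)
model = mnist_model_fn(loaders, seed=42, h_dim=16)
score, output, model_state = mnist_trainer_fn(
    model, loaders, seed=42, uid={}, cb=noop_cb, epochs=2
)
print(f"final training accuracy: {score:.2f}%")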