Preparation¶
Dataset¶
The dataset function builds the data loaders for a given config:
from typing import Dict

import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def mnist_dataset_fn(seed: int, **config) -> Dict:
    """
    Returns data loaders for the given config

    Args:
        seed: random seed that will make shuffling and other random operations deterministic

    Returns:
        data_loaders: containing "train", "validation" and "test" data loaders
    """
    np.random.seed(seed)
    transform_list = [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),  # (mean, std) of MNIST train set
    ]
    if config.get("apply_augmentation"):
        transform_list = [transforms.RandomHorizontalFlip()] + transform_list
    transform = transforms.Compose(transform_list)
    train_dataset = datasets.MNIST("../data", train=True, download=True, transform=transform)
    validation_dataset = datasets.MNIST(
        "../data", train=False, download=True, transform=transform
    )  # for simplicity, we use the test set for validation
    test_dataset = datasets.MNIST("../data", train=False, download=True, transform=transform)
    batch_size = config.get("batch_size", 64)
    return {
        "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=config.get("shuffle", True)),
        "validation": DataLoader(validation_dataset, batch_size=batch_size),
        "test": DataLoader(test_dataset, batch_size=batch_size),
    }
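As a quick sanity check, the function can be called directly. This is a minimal sketch; the variable names (loaders, images, labels) are illustrative and not part of the tutorial interface:

    # Build the loaders and inspect one training batch.
    loaders = mnist_dataset_fn(seed=42, batch_size=32)
    images, labels = next(iter(loaders["train"]))
    print(images.shape)  # expected: torch.Size([32, 1, 28, 28])
    print(labels.shape)  # expected: torch.Size([32])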
Model¶
We define a simple two-layer neural network with a flexible hidden size h_dim and a ReLU non-linearity.
import torch
from torch import nn


class MNISTModel(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, h_dim: int = 5) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, out_dim)
        self.nl = nn.ReLU()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.nl(self.fc1(x))
        return self.softmax(self.fc2(x))
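A quick shape check of the forward pass. This sketch assumes flattened 28x28 MNIST inputs (hence the 784); the names are illustrative:

    # Run a random flattened batch through the model.
    model = MNISTModel(in_dim=784, out_dim=10)
    dummy = torch.randn(8, 784)        # batch of 8 flattened 28x28 images
    log_probs = model(dummy)
    print(log_probs.shape)             # torch.Size([8, 10])
    print(log_probs.exp().sum(dim=1))  # rows sum to ~1: LogSoftmax returns log-probabilities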
Initializing this model for a given config is then done in mnist_model_fn:
def mnist_model_fn(dataloaders: Dict, seed: int, **config) -> torch.nn.Module:
    """
    Builds a model object for the given config

    Args:
        dataloaders: a dictionary of data loaders
        seed: random seed (e.g. for model initialization)

    Returns:
        Instance of torch.nn.Module
    """
    # get the input and output dimension for the model from one example batch
    first_input, _ = next(iter(dataloaders["train"]))
    in_dim = int(np.prod(first_input.shape[1:]))  # e.g. 1 * 28 * 28 = 784 for MNIST
    out_dim = 10  # ten digit classes
    torch.manual_seed(seed)  # for reproducibility (almost)
    model = MNISTModel(in_dim, out_dim, h_dim=config.get("h_dim", 5))
    return model
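Putting the two functions together, again as an illustrative sketch:

    # Derive the model dimensions from the data loaders built above.
    loaders = mnist_dataset_fn(seed=42)
    model = mnist_model_fn(loaders, seed=42, h_dim=16)
    print(model)  # MNISTModel with fc1: 784 -> 16, fc2: 16 -> 10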
Trainer¶
Finally, we define the trainer, which takes the model and the data loaders and executes the actual training.
from typing import Any, Callable, Dict, List, Tuple

from torch import optim
from tqdm import tqdm


class MNISTTrainer:
    def __init__(
        self,
        model: nn.Module,
        dataloaders: Dict,
        seed: int,
        epochs: int = 5,
    ) -> None:
        self.model = model
        self.trainloader = dataloaders["train"]
        self.seed = seed
        self.epochs = epochs
        self.loss_fn = nn.NLLLoss()
        self.optimizer = optim.Adam(self.model.parameters())

    def train_loop(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[int, int]:
        # forward:
        self.optimizer.zero_grad()
        x_flat = x.flatten(1, -1)  # treat the images as flat vectors
        logits = self.model(x_flat)
        loss = self.loss_fn(logits, y)
        # backward:
        loss.backward()
        self.optimizer.step()
        # keep track of accuracy:
        _, predicted = logits.max(1)
        predicted_correct = predicted.eq(y).sum().item()
        total = y.shape[0]
        return predicted_correct, total

    def train(self) -> Tuple[float, Tuple[List[float], int], Dict]:
        if hasattr(tqdm, "_instances"):
            tqdm._instances.clear()  # To have tqdm output without line-breaks between steps
        torch.manual_seed(self.seed)
        accs = []
        for epoch in range(self.epochs):
            predicted_correct = 0
            total = 0
            for x, y in tqdm(self.trainloader):
                p, t = self.train_loop(x, y)
                predicted_correct += p
                total += t
            accs.append(100.0 * predicted_correct / total)  # per-epoch training accuracy in percent
        return accs[-1], (accs, self.epochs), self.model.state_dict()

The corresponding trainer function sets up the training, executes it, and finally returns the score, output, and model state.
def mnist_trainer_fn(
    model: torch.nn.Module, dataloaders: Dict, seed: int, uid: Dict, cb: Callable, **config
) -> Tuple[float, Any, Dict]:
    """
    Args:
        model: initialized model to train
        dataloaders: containing "train", "validation" and "test" data loaders
        seed: random seed
        uid: database keys that uniquely identify this trainer call
        cb: callback function to ping the database and potentially save the checkpoint

    Returns:
        score: performance score of the model
        output: user specified validation object based on the 'stop function'
        model_state: the full state_dict() of the trained model
    """
    trainer = MNISTTrainer(model, dataloaders, seed, epochs=config.get("epochs", 2))
    out = trainer.train()
    return out
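To see the three functions working together end to end, here is a hedged sketch of a standalone run. The dummy uid and the no-op cb below are placeholders invented for illustration; in the full setup those values come from the database layer:

    # Minimal end-to-end sketch: dataset_fn -> model_fn -> trainer_fn.
    seed = 42
    loaders = mnist_dataset_fn(seed, batch_size=64)
    model = mnist_model_fn(loaders, seed, h_dim=16)
    score, output, model_state = mnist_trainer_fn(
        model,
        loaders,
        seed,
        uid={"example_id": 0},  # hypothetical uid; real keys identify the trainer call
        cb=lambda: None,        # no-op callback for this standalone run
        epochs=2,
    )
    print(f"final training accuracy: {score:.2f}%")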