Tutorial 2: Train a model with checkpointing

We introduce two new methods in our trainer. First, a method to save the current training state:

    def save(self, epoch: int, score: float) -> None:
        state = {
            "action": "save",
            "score": score,
            "maximize_score": True,
            "tracker": self.accs,
            "optimizer": self.optimizer,
            **self.chkpt_options,
        }
        self.trained_model_cb(
            uid=self.uid,
            epoch=epoch,
            model=self.model,
            state=state,
        )

Then a method to restore the state if training is resumed after an interruption:


    def restore(self) -> int:
        loaded_state = {
            "action": "last",
            "maximize_score": True,
            "tracker": self.accs,
            "optimizer": self.optimizer.state_dict(),
        }
        self.trained_model_cb(
            uid=self.uid, epoch=-1, model=self.model, state=loaded_state
        )
        epoch = loaded_state.get("epoch", -1) + 1
        return epoch

Both methods make use of the callback function (the trained_model_cb argument) that is passed to the trainer by the TrainedModel table. The important difference is the epoch count that is passed to the callback: a count of -1 signals that we want to retrieve the last checkpoint if there is one, whereas a non-negative count signals that we are in the middle of training and want to save the current state as a checkpoint.
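
For intuition, here is a minimal sketch of what such a callback could do with these arguments. The function name example_checkpoint_cb and the file-based storage are illustrative assumptions only; the actual callback is supplied by the checkpointing table:

import torch

def example_checkpoint_cb(uid, epoch, model, state):
    # Hypothetical stand-in for the callback provided by the table.
    path = f"checkpoint_{uid}.pt"
    if epoch == -1:
        # Restore mode: load the last checkpoint, if any, into model and state.
        try:
            saved = torch.load(path)
            model.load_state_dict(saved["model"])
            state.update(saved["state"])  # fills in e.g. state["epoch"]
        except FileNotFoundError:
            pass  # no checkpoint yet, training starts from scratch
    else:
        # Save mode: persist the current model and training state for this epoch.
        torch.save(
            {"model": model.state_dict(), "state": {**state, "epoch": epoch}},
            path,
        )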

Finally, we have to update the training procedure itself to call self.restore() before training starts and self.save() after every epoch.

    def train(self) -> Tuple[float, Tuple[List[float], int], Dict]:
        if hasattr(tqdm, "_instances"):
            tqdm._instances.clear()  # To have tqdm output without line-breaks between steps
        torch.manual_seed(self.seed)
        start_epoch = self.restore()
        for epoch in range(start_epoch, self.epochs):
            print(f"Epoch {epoch}")
            predicted_correct = 0
            total = 0
            for x, y in tqdm(self.trainloader):
                p, t = self.train_loop(x, y)
                predicted_correct += p
                total += t
            self.accs.append(100.0 * predicted_correct / total)
            self.save(epoch=epoch, score=self.accs[-1])  # checkpoint after every epoch
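
For context, the trainer class is typically constructed inside an nnfabrik trainer function, which receives the callback from the table. A rough sketch of that wiring, assuming a hypothetical Trainer class holding the methods above (the uid and cb keyword names follow the common nnfabrik trainer interface, but may differ in your setup):

def my_trainer_fn(model, dataloaders, seed, uid=None, cb=None, **config):
    # Hand the table's callback to the trainer object as trained_model_cb.
    trainer = Trainer(model, dataloaders, seed, uid=uid, trained_model_cb=cb, **config)
    score, output, model_state = trainer.train()
    return score, output, model_state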

Once we have a trainer that supports checkpointing, all we need to do is switch from the TrainedModel table to TrainedModelChkpt:

from nnfabrik.templates.checkpoint import TrainedModelChkptBase, my_checkpoint

Checkpoint = my_checkpoint(nnfabrik)

@nnfabrik.schema
class TrainedModelChkpt(TrainedModelChkptBase):
    table_comment = "My Trained models with checkpointing"
    nnfabrik = nnfabrik
    checkpoint_table = Checkpoint

Now this table can be used just like TrainedModel, i.e. we can simply populate it:

TrainedModelChkpt.populate()
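
If you want interrupted or failed runs to leave a traceable entry in the schema's jobs table (used below for recovery), you can enable DataJoint's job reservation when populating:

# Reserve jobs so that crashed runs show up in schema.jobs
TrainedModelChkpt.populate(reserve_jobs=True)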

The training state is saved automatically in the Checkpoint table after each epoch and can be retrieved from there should the training be interrupted. Thus, after an interruption, we first clear the error state:

# delete all jobs in error state:
(schema.jobs & "status='error'").delete()

Then we simply restart the training by calling populate() again:
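
# Restart training; restore() resumes from the last checkpoint instead of epoch 0
TrainedModelChkpt.populate()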