Tutorial 2: Train a model with checkpointing
We introduce two new methods in our trainer. First, a method to save the current training state:
def save(self, epoch: int, score: float) -> None:
    state = {
        "action": "save",
        "score": score,
        "maximize_score": True,
        "tracker": self.accs,
        "optimizer": self.optimizer,  # handed to the callback for checkpointing
        **self.chkpt_options,
    }
    self.trained_model_cb(
        uid=self.uid, epoch=epoch, model=self.model, state=state
    )  # save the current state as a checkpoint
Then a method to restore the state if training is resumed after an interruption:
def restore(self) -> int:
    loaded_state = {
        "action": "last",
        "maximize_score": True,
        "tracker": self.accs,
        "optimizer": self.optimizer,  # the callback loads the saved state into this optimizer
    }
    self.trained_model_cb(
        uid=self.uid, epoch=-1, model=self.model, state=loaded_state
    )  # load the last checkpoint if one exists
    epoch = loaded_state.get("epoch", -1) + 1
    return epoch  # the epoch to resume training from
Both methods make use of the call_back function that should be passed to the trainer from the TrainedModel table. The important difference is the epoch count that is passed to call_back here: a count of -1 signals the function that we want to retrieve the last checkpoint if there is one, whereas a non-negative count signals that we are in the process of training and want to save the current state as a checkpoint.
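To make this contract concrete, here is a minimal stand-alone sketch of a callback that honors it. This is purely illustrative and not nnfabrik's actual implementation (the real callback is provided by the table and also restores the optimizer and accuracy tracker); the file-based storage and path scheme are assumptions made for this sketch:

import os
import torch

def call_back(uid=None, epoch=-1, model=None, state=None):
    # hypothetical storage location, keyed by the model's uid
    path = f"./checkpoints/{uid}_last.pt"
    if epoch == -1 and state["action"] == "last":
        # restore mode: load the last checkpoint, if any, and report its epoch
        if os.path.exists(path):
            chkpt = torch.load(path)
            model.load_state_dict(chkpt["model"])
            state["epoch"] = chkpt["epoch"]  # restore() resumes at epoch + 1
    elif state["action"] == "save":
        # save mode: persist the current state as the latest checkpoint
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(
            {"epoch": epoch, "model": model.state_dict(), "score": state["score"]},
            path,
        )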
Finally, we also have to update the training procedure itself to call self.restore() before the training starts and self.save() after every epoch.
def train(self) -> Tuple[float, Tuple[List[float], int], Dict]:
    if hasattr(tqdm, "_instances"):
        tqdm._instances.clear()  # to have tqdm output without line breaks between steps
    torch.manual_seed(self.seed)
    start_epoch = self.restore()  # resume from the last checkpoint, if any
    for epoch in range(start_epoch, self.epochs):
        print(f"Epoch {epoch}")
        predicted_correct = 0
        total = 0
        for x, y in tqdm(self.trainloader):
            p, t = self.train_loop(x, y)
            predicted_correct += p
            total += t
        self.accs.append(100.0 * predicted_correct / total)
        self.save(epoch, self.accs[-1])  # checkpoint the state after every epoch
    # score, output, and model state, as expected from an nnfabrik trainer
    return self.accs[-1], (self.accs, self.epochs), self.model.state_dict()
Once we have a trainer that supports the checkpointing feature, all we need to do is switch from the TrainedModel table to TrainedModelChkpt.
from nnfabrik.templates.checkpoint import TrainedModelChkptBase, my_checkpoint

Checkpoint = my_checkpoint(nnfabrik)

@nnfabrik.schema
class TrainedModelChkpt(TrainedModelChkptBase):
    table_comment = "My Trained models with checkpointing"
    nnfabrik = nnfabrik
    checkpoint_table = Checkpoint
Now this table can be used just like TrainedModel, i.e. we can simply populate it:
TrainedModelChkpt.populate()
The training state will be saved automatically in the Checkpoint table after each epoch and can be retrieved from there should the training be interrupted.
Thus, after an interruption, we can just clear the error state with:
# delete all jobs in error state:
(schema.jobs & "status='error'").delete()
And then we just restart the training by calling populate() again.
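Putting the two steps together, a typical resume cycle could look like the following sketch. Note that the jobs table only records error states when populate is run with DataJoint's reserve_jobs=True flag; whether the original run used it is an assumption here:

# remove the jobs that errored out, e.g. because the worker was killed
(schema.jobs & "status='error'").delete()

# re-run population: for every unfinished model, restore() picks up
# from the last checkpoint instead of training from scratch
TrainedModelChkpt.populate(reserve_jobs=True)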