[FIXED] Keras – no good way to stop and resume training?

Issue

After a lot of research, it seems like there is no good way to properly stop and resume training with a TensorFlow 2 / Keras model. This is true whether you are using model.fit() or a custom training loop.

There seem to be 2 supported ways to save a model while training:

  1. Save just the weights of the model, using model.save_weights() or save_weights_only=True with tf.keras.callbacks.ModelCheckpoint. This seems to be preferred by most of the examples I’ve seen; however, it has a number of major issues:

    • The optimizer state is not saved, meaning training resumption will not be correct.
    • Learning rate schedule is reset – this can be catastrophic for some models.
    • TensorBoard logs go back to step 0 – making logging essentially useless unless complex workarounds are implemented.
  2. Save the entire model, optimizer, etc. using model.save() or save_weights_only=False. The optimizer state is saved (good) but the following issues remain:

    • TensorBoard logs still go back to step 0
    • Learning rate schedule is still reset (!!!)
    • It is impossible to use custom metrics.
    • This doesn’t work at all when using a custom training loop – custom training loops use a non-compiled model, and saving/loading a non-compiled model doesn’t seem to be supported.

The best workaround I’ve found is to use a custom training loop and manually save the step. This fixes the TensorBoard logging, and the learning rate schedule can be fixed by doing something like keras.backend.set_value(model.optimizer.iterations, step). However, since a full model save is off the table, the optimizer state is not preserved. I can see no way to save the state of the optimizer independently, at least not without a lot of work. And messing with the LR schedule as I’ve done feels messy as well.
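For illustration, here is a minimal sketch of that workaround. The toy model, the decay schedule, the file names and the helper functions are all hypothetical, and optimizer.iterations.assign(step) stands in for the keras.backend.set_value call mentioned above. It only restores the weights and the step counter, so any optimizer slot state (momentum, Adam moments) would still be lost.

```python
import json
import os
import tempfile

import tensorflow as tf


def build():
    """Hypothetical toy model plus an optimizer with a decaying learning rate."""
    model = tf.keras.Sequential(
        [tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)]
    )
    schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=0.1, decay_steps=10, decay_rate=0.5
    )
    return model, tf.keras.optimizers.SGD(learning_rate=schedule)


def train(model, optimizer, x, y, num_steps):
    """Plain custom loop; optimizer.iterations advances once per step."""
    for _ in range(num_steps):
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))


def save_state(model, optimizer, ckpt_dir):
    """Save the weights and, separately, the current step."""
    model.save_weights(os.path.join(ckpt_dir, "model.weights.h5"))
    with open(os.path.join(ckpt_dir, "step.json"), "w") as f:
        json.dump({"step": int(optimizer.iterations.numpy())}, f)


def restore_state(model, optimizer, ckpt_dir):
    """Reload the weights and put the step back into the optimizer."""
    model.load_weights(os.path.join(ckpt_dir, "model.weights.h5"))
    with open(os.path.join(ckpt_dir, "step.json")) as f:
        step = json.load(f)["step"]
    # The schedule reads optimizer.iterations, so this resumes the LR decay.
    optimizer.iterations.assign(step)
    return step


x = tf.random.uniform((32, 4))
y = tf.random.uniform((32, 1))
ckpt_dir = tempfile.mkdtemp()

# First run: train for a few steps, then "stop".
model, optimizer = build()
train(model, optimizer, x, y, num_steps=5)
save_state(model, optimizer, ckpt_dir)

# Second run: fresh objects, restored weights and step, then keep training.
resumed_model, resumed_optimizer = build()
step = restore_state(resumed_model, resumed_optimizer, ckpt_dir)
train(resumed_model, resumed_optimizer, x, y, num_steps=5)
```

The restored step can also be passed as the step argument when writing TensorBoard summaries, so the curves continue instead of restarting at 0.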

Am I missing something? How are people out there saving/resuming using this API?

Solution

The tf.keras.callbacks.experimental.BackupAndRestore callback, which resumes training after an interruption, was added in TensorFlow 2.3. It works great in my experience.

Reference:
https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/experimental/BackupAndRestore
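A rough usage sketch, assuming a toy model and random data. The InterruptAt callback is a hypothetical helper that simulates a crash, and the getattr fallback is there because newer TensorFlow releases expose the callback as tf.keras.callbacks.BackupAndRestore without the experimental prefix. Treat it as an illustration rather than a reference setup.

```python
import tempfile

import numpy as np
import tensorflow as tf

# Use the non-experimental name when it exists, otherwise fall back to the
# path given in the answer above.
BackupAndRestore = getattr(tf.keras.callbacks, "BackupAndRestore", None)
if BackupAndRestore is None:
    BackupAndRestore = tf.keras.callbacks.experimental.BackupAndRestore


class InterruptAt(tf.keras.callbacks.Callback):
    """Hypothetical helper that simulates a crash when a given epoch begins."""

    def __init__(self, epoch):
        super().__init__()
        self.epoch = epoch

    def on_epoch_begin(self, epoch, logs=None):
        if epoch == self.epoch:
            raise RuntimeError("simulated interruption")


def build_model():
    """Toy model; the explicit Input makes sure it is built before fit()."""
    model = tf.keras.Sequential(
        [tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)]
    )
    model.compile(optimizer="adam", loss="mse")
    return model


x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")
backup_dir = tempfile.mkdtemp()

# First run: epochs 0 and 1 finish (and are backed up), then it "crashes".
try:
    build_model().fit(
        x, y, epochs=5, verbose=0,
        callbacks=[BackupAndRestore(backup_dir=backup_dir), InterruptAt(2)],
    )
except RuntimeError:
    pass

# Second run: same backup_dir, so fit() should pick up where it left off.
history = build_model().fit(
    x, y, epochs=5, verbose=0,
    callbacks=[BackupAndRestore(backup_dir=backup_dir)],
)
# Epoch indices actually run in the second call; these should start at 2
# rather than 0 if the backup was picked up.
print(history.epoch)
```

The key point is that both runs pass the same backup_dir; the callback removes the backup by default once training finishes normally.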

Answered By – yanp

Answer Checked By – Marie Seifert (Easybugfix Admin)
