Callbacks
Callbacks enable you, or the users of your code, to add new behavior to the training loop without needing to modify the source code.
Add a callback interface to your loop
Suppose we want to let anyone run arbitrary code at the end of a training iteration. Here is how to do that in Fabric:
```python
# my_callbacks.py
class MyCallback:
    def on_train_batch_end(self, loss, output):
        # Here, put any code you want to run at the end of a training step
        ...
```
```python
from lightning.fabric import Fabric

# The code of a callback can live anywhere, away from the training loop
from my_callbacks import MyCallback

# Add one or several callbacks:
fabric = Fabric(callbacks=[MyCallback()])

...  # model, optimizer, and train_dataloader setup omitted

for iteration, batch in enumerate(train_dataloader):
    ...
    fabric.backward(loss)
    optimizer.step()

    # Let a callback add some arbitrary processing at the appropriate place
    # Give the callback access to some variables
    fabric.call("on_train_batch_end", loss=loss, output=...)
```
The code inside the callback method is fully decoupled from the training loop: the loop only decides where a hook is called and which variables it receives, so it can be extended in arbitrary ways without editing its source.
Exercise: Implement a callback that computes and prints the time to complete an iteration.
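If you want to check your answer, here is one possible sketch. It is written as a plain Python class so it runs without Fabric. The class name `IterationTimer` and the extra hook `on_train_batch_start` are hypothetical: the loop above only calls `on_train_batch_end`, so this approach assumes you also add a `fabric.call("on_train_batch_start")` at the top of the loop body.

```python
import time


class IterationTimer:
    """Hypothetical callback that measures and prints the wall-clock time of each iteration."""

    def __init__(self):
        self._start = None
        self.durations = []  # elapsed seconds for each completed iteration

    def on_train_batch_start(self, **kwargs):
        # Assumed hook name: record a high-resolution timestamp when the iteration begins
        self._start = time.perf_counter()

    def on_train_batch_end(self, **kwargs):
        # Extra keyword arguments such as loss or output are accepted and ignored
        if self._start is None:
            return  # the start hook was never called, so there is nothing to measure
        elapsed = time.perf_counter() - self._start
        self.durations.append(elapsed)
        print(f"Iteration took {elapsed:.4f} s")
        self._start = None


# Drive the hooks by hand as a stand-in for the training loop
timer = IterationTimer()
for iteration in range(3):
    timer.on_train_batch_start()
    time.sleep(0.01)  # placeholder for the forward pass, backward pass, and optimizer step
    timer.on_train_batch_end(loss=0.0, output=None)
```

With Fabric, you would register it with `Fabric(callbacks=[IterationTimer()])` and replace the manual calls with `fabric.call(...)`. Accepting `**kwargs` keeps the hooks compatible with whatever keyword arguments the loop passes.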
Multiple callbacks
The callback system is designed to run several callbacks for the same hook, one after another. Pass them to Fabric as a list:
```python
from lightning.fabric import Fabric

# Add multiple callback implementations in a list
# (LearningRateMonitor and Profiler are example callback classes defined elsewhere)
callback1 = LearningRateMonitor()
callback2 = Profiler()
fabric = Fabric(callbacks=[callback1, callback2])

# Let Fabric call the implementations (if they exist)
fabric.call("any_callback_method", arg1=..., arg2=...)

# fabric.call is equivalent to doing this (for callbacks that implement the method)
callback1.any_callback_method(arg1=..., arg2=...)
callback2.any_callback_method(arg1=..., arg2=...)
```
`fabric.call()` invokes the callback objects in the order they were given to Fabric. Not every object registered via `Fabric(callbacks=...)` needs to implement a method with the given name; only the callbacks that define a matching method are called.
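The two dispatch rules just described (call in registration order, skip callbacks that lack the method) can be sketched in plain Python. The function `call_hook` and the example classes are hypothetical stand-ins for illustration, not Fabric's actual implementation, and this sketch deliberately leaves out the signature-based argument filtering.

```python
from typing import Any, List, Sequence


def call_hook(callbacks: Sequence[object], hook_name: str, **kwargs: Any) -> List[str]:
    """Hypothetical stand-in for fabric.call(): invoke `hook_name` on each callback in order.

    Returns the class names of the callbacks that were actually called.
    """
    called = []
    for callback in callbacks:  # preserves the order the list was given in
        method = getattr(callback, hook_name, None)
        if not callable(method):
            continue  # this callback does not implement the hook, so skip it
        method(**kwargs)
        called.append(type(callback).__name__)
    return called


class Printer:
    def on_validation_end(self, **kwargs):
        print("Printer saw", kwargs)


class Counter:
    def __init__(self):
        self.count = 0

    def on_validation_end(self, **kwargs):
        self.count += 1


class Unrelated:
    # Implements a different hook only, so call_hook() skips it for on_validation_end
    def on_train_batch_end(self, **kwargs):
        raise RuntimeError("should not be called")


counter = Counter()
called = call_hook([Printer(), Unrelated(), counter], "on_validation_end", val_loss=0.25)
print(called)  # ['Printer', 'Counter']
```

Using `getattr` with a default of `None` is what lets a single list mix callbacks that implement different subsets of hooks.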
Callbacks that implement the same hook can also have different method signatures. Fabric automatically filters the keyword arguments passed to `fabric.call()` against each callback's method signature, so callbacks with different signatures can be registered together and each one receives only the arguments it can handle:
```python
from lightning.fabric import Fabric


class TrainingMetricsCallback:
    def on_train_epoch_end(self, train_loss):
        print(f"Training loss: {train_loss:.4f}")


class ValidationMetricsCallback:
    def on_train_epoch_end(self, val_accuracy):
        print(f"Validation accuracy: {val_accuracy:.4f}")


class ComprehensiveCallback:
    def on_train_epoch_end(self, epoch, **kwargs):
        print(f"Epoch {epoch} complete with metrics: {kwargs}")


fabric = Fabric(
    callbacks=[TrainingMetricsCallback(), ValidationMetricsCallback(), ComprehensiveCallback()]
)

# Each callback receives only the arguments it can handle
fabric.call("on_train_epoch_end", epoch=5, train_loss=0.1, val_accuracy=0.95, learning_rate=0.001)
```
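To make the filtering rule concrete, here is one way such filtering could be written with `inspect.signature`. This is an illustrative sketch, not Fabric's source: `filter_kwargs` and the two hook functions are hypothetical, and it assumes that a method declaring `**kwargs` should receive every keyword argument, which matches how `ComprehensiveCallback` is used in the example above.

```python
import inspect
from typing import Any, Callable, Dict


def filter_kwargs(method: Callable[..., Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the keyword arguments that `method` can accept (illustrative sketch)."""
    params = inspect.signature(method).parameters.values()
    # A **kwargs parameter accepts any keyword, so pass everything through
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return dict(kwargs)
    # Otherwise keep only arguments that match a named parameter
    accepted = {
        p.name
        for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    return {name: value for name, value in kwargs.items() if name in accepted}


def log_train_loss(train_loss):
    return f"Training loss: {train_loss:.4f}"


def log_everything(epoch, **kwargs):
    return f"Epoch {epoch} complete with metrics: {kwargs}"


metrics = {"epoch": 5, "train_loss": 0.1, "val_accuracy": 0.95}
for hook in (log_train_loss, log_everything):
    print(hook(**filter_kwargs(hook, metrics)))
# Training loss: 0.1000
# Epoch 5 complete with metrics: {'train_loss': 0.1, 'val_accuracy': 0.95}
```

For bound methods, `inspect.signature` already excludes `self`, so the same helper works on callback instances. The sketch does not check for missing required parameters; calling a hook whose required argument was never supplied would still raise a `TypeError`.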
Next steps
Callbacks are a powerful tool for building a Trainer. See a real example of how they can be integrated in our Trainer template based on Fabric: