ModelCheckpoint callback is triggered by mistake after every validation stage when using manual optimization #20459
Comments
Thanks for reporting this. Is that the case with the latest master as well?
Ok, I verified and can reproduce:

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import lightning as L
from lightning.pytorch.demos import Transformer, WikiText2
from lightning.pytorch.callbacks import ModelCheckpoint


class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.model = Transformer(
            vocab_size=self.vocab_size,
            nlayers=2,
            nhid=4096,
            ninp=1024,
            nhead=8,
        )

    def training_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("val_loss", loss, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)


def train():
    L.seed_everything(42)

    dataset = WikiText2()
    train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1)
    val_dataloader = DataLoader(dataset, num_workers=8, batch_size=1)

    model = LanguageModel(vocab_size=dataset.vocab_size)

    # Checkpoints are expected every 2 epochs, but end up being saved after every validation run
    model_checkpoint = ModelCheckpoint(save_top_k=-1, every_n_epochs=2, save_last=False)

    trainer = L.Trainer(
        max_steps=100,
        precision="bf16-true",
        limit_train_batches=10,
        limit_val_batches=2,
        callbacks=model_checkpoint,
        val_check_interval=5,
    )
    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")
    train()
BTW, if you change model_checkpoint = ModelCheckpoint(save_top_k=-1, every_n_epochs=2, save_last=False) to model_checkpoint = ModelCheckpoint(save_top_k=-1, every_n_epochs=2, save_on_train_epoch_end=True, save_last=False), then checkpoints get saved at the right interval. However, if you keep the settings as in the snippet above, you indeed get a checkpoint at every validation run, but only within the training epochs where a save is due and not in the others. Which is, of course, consistent with the code in ModelCheckpoint.
So the resolution for this is to use ModelCheckpoint(..., save_on_train_epoch_end=True) to avoid saving on validation end. In the future we could introduce a more explicit option for this.
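For reference, a minimal sketch of the suggested workaround applied to the reproduction snippet above (only the save_on_train_epoch_end argument is added; the other arguments are unchanged):

```python
from lightning.pytorch.callbacks import ModelCheckpoint

# Saving is restricted to the end of training epochs, so the every_n_epochs
# interval is respected and validation runs no longer trigger extra checkpoints.
model_checkpoint = ModelCheckpoint(
    save_top_k=-1,
    every_n_epochs=2,
    save_on_train_epoch_end=True,  # skip the on_validation_end save path
    save_last=False,
)
```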
Bug description
I set the every_n_epochs parameter of ModelCheckpoint to 1 and the Trainer's val_check_interval to 200, with 1000 iterations per epoch in total. Checkpoint files should not be saved after each validation check, but they are.
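A minimal sketch of the configuration described above (the model and dataloaders are hypothetical placeholders, not taken from the report):

```python
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint

# Reported setup: 1000 iterations per epoch, validation every 200 steps,
# and every_n_epochs=1 so a checkpoint is expected once per epoch.
checkpoint = ModelCheckpoint(every_n_epochs=1)
trainer = L.Trainer(
    callbacks=[checkpoint],
    val_check_interval=200,  # validation runs 5 times per 1000-iteration epoch
)
# Observed behaviour: a checkpoint is written after each of these validation runs,
# not just once per epoch.
# trainer.fit(model, train_dataloader, val_dataloader)  # model/dataloaders are placeholders
```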
What version are you seeing the problem on?
v2.4
How to reproduce the bug
No response
Error messages and logs
Environment
More info
No response
cc @tchaton @justusschock @awaelchli @Borda