Documentation Example of Truncated BPTT is not working; self.optimizer.step() makes no sense #20517

Open
simon-bachhuber opened this issue Dec 22, 2024 · 2 comments
Labels: docs (Documentation related)

Comments

simon-bachhuber commented Dec 22, 2024

📚 Documentation

The TBPTT example at
https://lightning.ai/docs/pytorch/stable/common/tbptt.html
contains a couple of lines that make no sense: self.optimizer.step() and self.optimizer.zero_grad(). As far as I can tell, self.optimizer is never defined on the LightningModule; in manual optimization the optimizer has to be fetched via self.optimizers().

Also, shouldn't one use self.manual_backward instead of self.backward?

Also, on another documentation page you state that calling optimizer.step right before backward is preferred and good practice, yet you don't do it here.
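
For reference, the optimization loop in the linked example currently looks roughly like this (paraphrased from the page; the exact surrounding lines may differ slightly):

    for split_batch in range(split_batches):
        # 4. Perform the optimization in a loop
        loss, hiddens = self.my_rnn(split_batch, hiddens)
        self.backward(loss)
        self.optimizer.step()
        self.optimizer.zero_grad()

        # 5. "Truncate"
        hiddens = hiddens.detach()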

It would make more sense to write something like:

    # 2. Remove the `hiddens` argument
    def training_step(self, batch, batch_idx):

        # 3. Split the batch in chunks along the time dimension
        split_batches = split_batch(batch, self.truncated_bptt_steps)

        batch_size = 10
        hidden_dim = 20
        hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)
        # get optimizer
        optimizer = self.optimizers()
        for split in split_batches:
            # 4. Perform the optimization in a loop
            loss, hiddens = self.my_rnn(split, hiddens)
            optimizer.zero_grad()
            self.manual_backward(loss)
            optimizer.step()
            
            # 5. "Truncate"
            hiddens = hiddens.detach()

        # 6. Remove the return of `hiddens`
        # Returning loss in manual optimization is not needed
        return None
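
Here split_batch is the helper the docs example already assumes. Purely for illustration, a minimal sketch of what such a helper could look like (hypothetical; assumes batch = (x, y) with the time dimension at dim 1):

    def split_batch(batch, split_size):
        # Hypothetical helper: chop each tensor of the batch into chunks
        # of `split_size` steps along the time dimension (dim=1).
        x, y = batch
        splits = []
        for t in range(0, x.size(1), split_size):
            splits.append((x[:, t : t + split_size], y[:, t : t + split_size]))
        return splits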

cc @Borda

@simon-bachhuber added the docs (Documentation related) and needs triage (Waiting to be triaged by maintainers) labels on Dec 22, 2024

lantiga (Collaborator) commented Dec 29, 2024

cc @chualanagit

@lantiga removed the needs triage (Waiting to be triaged by maintainers) label on Dec 31, 2024

lantiga (Collaborator) commented Dec 31, 2024

@simon-bachhuber thanks for catching this.

Regarding the use of manual_backward vs backward, that's surely a mistake.

Regarding your other comment:

"calling optimizer.step right before backward is preferred and good practice"

do you mean zero_grad? That is also correct, although in this particular case it is less severe: the Trainer calls zero_grad at the beginning of the training stage, so the gradients will still end up correct. Having said that, the example needs to be fixed, 100%.
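
To make the ordering point concrete, a minimal sketch of the two variants inside the manual-optimization loop (assuming opt = self.optimizers() and a loss from the current split); given that the Trainer zero-grads before training starts, both end up with correct gradients here, but the second makes the intent explicit:

    # docs-style ordering: backward, step, then clear
    self.manual_backward(loss)
    opt.step()
    opt.zero_grad()

    # preferred ordering: clear, backward, step
    opt.zero_grad()
    self.manual_backward(loss)
    opt.step()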

Would you like to send a quick PR?
