
I want to solve a sequence-to-sequence text generation task (e.g. question answering, language translation, etc.).

For the purposes of this question, you may assume that the input side is already handled: I have a tensor of shape batch_size x num_input_tokens x input_dim representing the input sequences, and all input sequences in my problem have the same length, so no masking is required on the input side of things.

Now, I want to generate the output sequences using nn.TransformerDecoder. I'm aware of PyTorch's official tutorial "Sequence-to-Sequence Modeling with nn.Transformer and TorchText". Unfortunately, the official tutorial doesn't meet my needs, for the following reasons:

  • nn.TransformerDecoder is not used in the example.
  • The example is about language modeling, not text generation. There is no forward loop that generates text word by word.

I've searched around the web and I've found a few things, but nothing like a simple and minimal working example that directly applies to my problem setting. Concretely, on the output side of things I need the following:

  • I want to generate output sequences in batch. I've found code on GitHub where people appear to be doing text generation, but they do it one sequence at a time, not for a batch of multiple sequences.
  • The output sequences may have different lengths.
  • I want to train my model with the teacher-forcing strategy and batches of multiple sequences. Given that in training I know the lengths of the sequences in advance, you may assume that I already have my batches padded with zeroes. However, I still need to figure out how to implement the forward function of my model, with a generation loop that uses nn.TransformerDecoder. Basically, I need to figure out how to iterate word-wise over my batch of output sequences, masking out the future words in each step (so that the model doesn't cheat by trivially predicting the next words).
  • Then, I need a similar forward function for inference mode. I need to figure out how to implement the generation loop to do basically the same as in training mode, except that instead of teacher-forcing I want to implement greedy search (i.e. use the tokens with highest predicted probability at iteration i as the next input for iteration i+1).

I already know how to do all this using LSTMs. Below you can see the forward function of a model that I implemented in the past to do exactly what I just said with an LSTM. The same forward function is used for both training and inference, depending on the value of the variable 'mode':

  def forward(
      self,
      image_local_features,
      question_vectors,
      answers=None,
      max_answer_length=None,
      mode='train',
  ):
    if mode == 'train':
      assert answers is not None
      batch_size, max_answer_length = answers.shape
    else:
      assert max_answer_length is not None
      batch_size = image_local_features.size(0)
    
    y = self.embedding_table(self.start_idx).expand(batch_size, -1) # <start> token embedding for every sequence
    o = torch.zeros(batch_size, self.hidden_size).to(DEVICE)
    h = self.W_h(question_vectors) # initial hidden state
    c = self.W_c(question_vectors) # initial cell state

    if mode == 'train':
      answer_embeddings = self.embedding_table(answers.permute(1,0))
      assert answer_embeddings.shape == (max_answer_length, batch_size, self.embed_size)

    output = []

    for t in range(max_answer_length):
      y_bar = torch.cat((y,o),1)
      assert y_bar.shape == (batch_size, self.embed_size + self.hidden_size)
      assert h.shape == (batch_size, self.hidden_size)
      assert c.shape == (batch_size, self.hidden_size)
      h, c = self.lstm_cell(y_bar, (h, c))
      e = (self.W_attn(image_local_features) * h.unsqueeze(1)).sum(-1) # attention scores over image regions
      att = torch.softmax(e,-1) # attention weights
      a = (image_local_features * att.unsqueeze(2)).sum(1) # attended visual context vector
      assert a.shape == (batch_size, self.image_local_feat_size)
      u = torch.cat((a,h),1)
      assert u.shape == (batch_size, self.hidden_size + self.image_local_feat_size)
      v = self.W_u(u)
      o = self.dropout(torch.tanh(v))
      assert o.shape == (batch_size, self.hidden_size)
      output.append(self.W_vocab(o))
      if mode == 'train':
        y = answer_embeddings[t] # teacher-forcing
      else:
        y = self.embedding_table(torch.argmax(output[t], 1)) # greedy search
      assert y.shape == (batch_size, self.embed_size)

    output = torch.stack(output, 1)
    assert output.shape == (batch_size, max_answer_length, self.vocab_size)
    return output

Another way to phrase my question would be: how can I reimplement what I did with LSTMs using nn.TransformerDecoder instead?

Any minimal working / hello-world example that shows how to do batch training and batch inference with nn.TransformerDecoder for text generation would be very much appreciated.


Note: alternatively, if there is a straightforward way of accomplishing the same with an out-of-the-box solution from Hugging Face, that would be awesome too.

Pablo Messina
1 Answer


After Googling around, I think this tutorial may suit your needs.

However, it seems you have a misconception about the Transformer decoder: in training mode there is no iteration at all. While LSTM-based decoders are autoregressive by nature, Transformers are not. Instead, all predictions are generated at once from the real target tokens (i.e. teacher forcing). To train a Transformer decoder so that it can later be used autoregressively, we apply a causal self-attention mask, which ensures that each prediction only depends on the previous tokens even though the decoder receives the whole target sequence at once. You can have a look at the Training loop section of the Annotated Transformer tutorial to see how they do it.
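
To make this more concrete, here is a minimal sketch of what a single teacher-forced training forward pass with nn.TransformerDecoder could look like. This is my own illustration, not code from the tutorial: vocab_size, d_model, pad_idx, etc. are placeholder values, and the positional encodings discussed below still need to be added to the embeddings.

  import torch
  import torch.nn as nn

  # Placeholder dimensions, just for illustration
  vocab_size, d_model, nhead, num_layers, pad_idx = 10000, 512, 8, 6, 0

  embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
  decoder = nn.TransformerDecoder(nn.TransformerDecoderLayer(d_model, nhead), num_layers)
  out_proj = nn.Linear(d_model, vocab_size)

  def train_forward(memory, tgt_tokens):
      # memory:     (num_input_tokens, batch_size, d_model), the encoder output
      # tgt_tokens: (batch_size, T), padded gold tokens starting with <start>,
      #             so that position t is trained to predict token t+1
      tgt = embedding(tgt_tokens).transpose(0, 1)      # (T, B, d_model); add positional encodings here
      T = tgt.size(0)
      # Causal mask: position i may only attend to positions j <= i
      causal_mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)
      padding_mask = tgt_tokens == pad_idx             # (B, T), True at <pad> positions
      out = decoder(tgt, memory,
                    tgt_mask=causal_mask,
                    tgt_key_padding_mask=padding_mask)
      return out_proj(out).transpose(0, 1)             # (B, T, vocab_size) logits, all timesteps at once

Note that the whole (batch_size, T) target batch is consumed in a single call; the causal mask is what replaces the word-by-word loop of the LSTM version.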

Another difference between LSTMs and Transformers is positional encodings, which Transformers need in order to know the position of each token, since self-attention by itself is order-invariant.
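
For reference, here is a sketch of the standard sinusoidal positional encoding (essentially the version from the PyTorch tutorial / the Annotated Transformer), which would be added to the (T, B, d_model) target embeddings in the sketch above:

  import math
  import torch
  import torch.nn as nn

  class PositionalEncoding(nn.Module):
      # Fixed sinusoidal positional encodings from "Attention Is All You Need"
      def __init__(self, d_model, dropout=0.1, max_len=5000):
          super().__init__()
          self.dropout = nn.Dropout(p=dropout)
          position = torch.arange(max_len).unsqueeze(1)   # (max_len, 1)
          div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
          pe = torch.zeros(max_len, 1, d_model)
          pe[:, 0, 0::2] = torch.sin(position * div_term)
          pe[:, 0, 1::2] = torch.cos(position * div_term)
          self.register_buffer('pe', pe)

      def forward(self, x):
          # x: (T, B, d_model)
          x = x + self.pe[:x.size(0)]
          return self.dropout(x)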

Regarding inference time, the easiest approach is to implement greedy decoding (e.g. this), where at each timestep you simply take the most probable token. This decoding strategy, however, will probably give poor results (e.g. the typical token repetitions). A better option is beam search, where at each timestep you keep the K most probable partially decoded sequences, although it is more complex to implement and I have not found any online implementation meant for nn.TransformerDecoder; maybe you can have a look at OpenNMT's implementation.
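
Continuing the sketch above, a batched greedy decoding loop could look like the following. Here start_idx and end_idx are assumed ids for the <start> and <end> tokens, and the whole prefix is re-fed to the decoder at every step (the naive, uncached approach):

  @torch.no_grad()
  def greedy_decode(memory, max_len, start_idx, end_idx):
      batch_size = memory.size(1)
      # Every sequence in the batch starts with the <start> token
      generated = torch.full((batch_size, 1), start_idx, dtype=torch.long)
      for _ in range(max_len):
          tgt = embedding(generated).transpose(0, 1)        # (t, B, d_model); add positional encodings here too
          t = tgt.size(0)
          causal_mask = torch.triu(torch.full((t, t), float('-inf')), diagonal=1)
          out = decoder(tgt, memory, tgt_mask=causal_mask)
          logits = out_proj(out[-1])                        # logits at the last position, (B, vocab_size)
          next_token = logits.argmax(dim=-1, keepdim=True)  # greedy choice, (B, 1)
          generated = torch.cat([generated, next_token], dim=1)
          if (next_token.squeeze(1) == end_idx).all():      # simplification: stop only when all sequences emit <end> at the same step
              break
      return generated[:, 1:]                               # drop the <start> token

Sequences that finish early keep generating until the loop ends, so in practice you would also discard everything produced after each sequence's <end> token.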

noe
  • Thanks noe for the resources. I have a question regarding the training phase. If my encoder is quite expensive and I want to reuse the encoder output instead of re-running the full encoder for each word, is that possible? The training loop you linked to runs the full model again and again. So, if I have a batch of sentences of length 100, the encoder would run 100 times, instead of just running once and then re-using the computation while decoding. Is it possible to have a single forward for the encoder and multiple forwards for the decoder (one forward per word)? – Pablo Messina Apr 16 '21 at 17:05
  • Note: by expensive encoder, I mean something like processing text + images. I want to encode the expensive input just once and then decode the output sequences word by word with teacher-forcing in training. That's why I thought of a forward function that runs the encoder once, and then a for-loop that runs the transformer decoder word by word. Does this make sense? – Pablo Messina Apr 16 '21 at 17:06
  • Should I ask this as a separate question? – Pablo Messina Apr 16 '21 at 17:13
  • [The training loop I linked](https://nlp.seas.harvard.edu/2018/04/03/attention.html#training-loop) does not run the full model again and again; it performs a single forward pass (no iterations whatsoever) for each batch of sentences. At inference time, the encoder output is also computed only once and reused for every timestep; in fact, many Transformer decoder implementations also cache the past hidden states at inference time over subsequent timesteps. What makes you think the encoder is computed more than once per batch? – noe Apr 16 '21 at 17:15
  • It runs once, yeah, but for a given src_mask and a given trg_mask. The issue is, if your output sequences have length 100, you would need 100 different trg_mask to simulate the 100 generation steps, so in practice you multiply your training instances by the number of words per output sentence, unless I'm misunderstanding the inner workings of the transformer decoder (please enlighten me if that's the case). – Pablo Messina Apr 16 '21 at 17:19
  • Because you need to mask all except the first word, then all except the first two words, and so on and so forth. So you end up with as many masks (and training instances) as the longest sentence, correct? – Pablo Messina Apr 16 '21 at 17:22
  • No, not correct, there is a single mask for the whole batch. Check [this question](https://datascience.stackexchange.com/q/90290/14675) and the answer. – noe Apr 16 '21 at 17:27
  • You can see what the mask looks like in the [annotated transformer tutorial](https://nlp.seas.harvard.edu/2018/04/03/attention.html#decoder) (see also the mask sketch after this thread) – noe Apr 16 '21 at 17:29
  • Ohh I see, yeah, you are right. So in a single forward pass you simulate all the generation steps for each sentence of the batch in parallel, using a square matrix as the mask, right? – Pablo Messina Apr 16 '21 at 17:39
  • Yes, that's right. – noe Apr 16 '21 at 17:40
  • Also, take into account that in standard transformers masking is only applied in the decoder, not in the encoder, so there is no `src_mask` (you mentioned it in one of your comments). – noe Apr 16 '21 at 17:43
  • In [this training loop](https://nlp.seas.harvard.edu/2018/04/03/attention.html#training-loop) there is this line: out = model.forward(batch.src, batch.trg, **batch.src_mask**, batch.trg_mask) – Pablo Messina Apr 16 '21 at 17:47
  • Ahh, I see, sorry for the misunderstanding, that mask is the padding mask (makes the transformer only attend to non-padding tokens). I was referring to self-attention masking. – noe Apr 16 '21 at 17:48
  • Perfect. Only one last question: self-attention masking (the square matrix) only applies to training, right? At inference time it would not make sense, since you are generating the output from scratch. – Pablo Messina Apr 16 '21 at 17:54
  • That depends on the implementation. If the decoder is implemented naively and, for each timestep, it recomputes all past hidden states again, then you still need the self-attention masking to prevent the past hidden states from being recomputed using future tokens. If the already computed hidden states are cached and reused when computing future token predictions, then you don't need the masking. – noe Apr 16 '21 at 17:58
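
To make the distinction discussed in this thread concrete, here is a small sketch (with made-up token ids) of the two masks: the causal self-attention mask is a single (T, T) matrix shared by the whole batch, while the padding mask is per sequence, of shape (batch_size, T):

  import torch

  T = 4
  # Causal (subsequent) mask: position i may attend only to positions j <= i
  causal_mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)
  print(causal_mask)
  # tensor([[0., -inf, -inf, -inf],
  #         [0., 0., -inf, -inf],
  #         [0., 0., 0., -inf],
  #         [0., 0., 0., 0.]])

  # Padding mask for a toy batch of token ids padded with 0
  pad_idx = 0
  tokens = torch.tensor([[5, 7, 2, 0],
                         [3, 9, 4, 6]])
  padding_mask = tokens == pad_idx   # True where the decoder must not attend
  print(padding_mask)
  # tensor([[False, False, False,  True],
  #         [False, False, False, False]])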