fix: truncate finished output in stream_generate

#74
Files changed (1)
  1. modeling_chatglm.py +1 -0
modeling_chatglm.py CHANGED
@@ -1404,6 +1404,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
             else:
                 next_tokens = torch.argmax(probs, dim=-1)
+            next_tokens = torch.where(unfinished_sequences.bool(), next_tokens, eos_token_id[0])
 
             # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
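For context (not part of the diff): the added line applies the same masking step used in Hugging Face's GenerationMixin sampling loop. Once a sequence in the batch has finished, any token sampled for it afterwards is replaced by EOS, so stream_generate stops appending real content to already-finished sequences. A minimal, standalone sketch of that masking with hypothetical tensor values (the identifiers reuse the names from the diff, the values are made up):

import torch

# hypothetical state for a batch of 3 sequences: 1 = still generating, 0 = already finished
unfinished_sequences = torch.tensor([1, 0, 1])
eos_token_id = torch.tensor([2])               # assumed EOS id, kept as a 1-element tensor
next_tokens = torch.tensor([523, 8871, 77])    # tokens just sampled for each sequence

# finished sequences receive EOS instead of the freshly sampled token
next_tokens = torch.where(unfinished_sequences.bool(), next_tokens, eos_token_id[0])
print(next_tokens)  # tensor([523,   2,  77]) -- row 1 was already finished, so it is forced to EOS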