RuntimeError: FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800

#18
by g-ronimo - opened

Thank you for this model!

Any idea how to resolve this when fine-tuning Gemma-7B (QLoRA) with FA2 on an RTX 3090?

/home/g/.local/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
  warnings.warn(
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
Traceback (most recent call last):
  File "/home/g/gemma-ft/qlora-OA.py", line 262, in <module>
    trainer.train()
  File "/home/g/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1624, in train
    return inner_training_loop(
  File "/home/g/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1961, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/g/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2911, in training_step
    self.accelerator.backward(loss)
  File "/home/g/accelerate_fork/src/accelerate/accelerator.py", line 1966, in backward
    loss.backward(**kwargs)
  File "/home/g/.local/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/g/.local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/g/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/home/g/.local/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 319, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/home/g/.local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/g/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/home/g/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 531, in backward
    _flash_attn_backward(
  File "/home/g/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 131, in _flash_attn_backward
    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
RuntimeError: FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800

Same error. It works on my instance for Mistral, but not for Gemma.

Google org

It's just that the head dim of this model is bigger (256 for Gemma-7B, versus 128 for Mistral-7B) 😭, so it seems another kernel is required.
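Until a fixed flash-attn build is installed, one workaround is to pick the attention backend based on head dim and GPU. The sketch below encodes only the rule stated in the error message (backward with head dim > 192 needs sm80, i.e. A100/A800, or sm90, i.e. H100/H800) and falls back to PyTorch SDPA otherwise. The helper names are hypothetical, and it assumes your transformers version supports `attn_implementation="sdpa"` for Gemma.

```python
from typing import Tuple

# Rule taken from the error message in this thread (older flash-attn builds):
# backward with head dim > 192 only runs on sm80 (A100/A800) or sm90 (H100/H800).
FA2_HEAD_DIM_LIMIT = 192
FA2_BIG_HEAD_DIM_ARCHS = {(8, 0), (9, 0)}


def fa2_backward_ok(head_dim: int, capability: Tuple[int, int]) -> bool:
    """Return True if the old FA2 backward kernel accepts this head dim on this GPU."""
    if head_dim <= FA2_HEAD_DIM_LIMIT:
        return True
    return capability in FA2_BIG_HEAD_DIM_ARCHS


def pick_attn_implementation(head_dim: int, capability: Tuple[int, int]) -> str:
    """Choose a transformers `attn_implementation` string for fine-tuning (hypothetical helper)."""
    # Fall back to PyTorch SDPA when the FA2 backward pass would raise.
    return "flash_attention_2" if fa2_backward_ok(head_dim, capability) else "sdpa"


if __name__ == "__main__":
    rtx_3090 = (8, 6)  # compute capability of an RTX 3090
    print(pick_attn_implementation(256, rtx_3090))  # Gemma-7B head dim -> sdpa
    print(pick_attn_implementation(128, rtx_3090))  # Mistral-7B head dim -> flash_attention_2
```

In a real script the capability would come from `torch.cuda.get_device_capability()` and the head dim from `config.head_dim`. Note that `hidden_size // num_attention_heads` gives 3072 // 16 = 192 for Gemma-7B, which would wrongly pass the check, so read the explicit `head_dim` field.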

Google org

It was fixed! Upgrade flash-attention: https://x.com/tri_dao/status/1760458183066472556?s=20
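To check whether the installed flash-attn is new enough before launching a run, a small sketch like the following could help. The minimum version `2.5.5` is an assumption (believed to be the first release with head dim 256 backward on consumer GPUs); confirm it against the flash-attn release notes. The helper names are hypothetical.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

# ASSUMPTION: first flash-attn release with head dim 256 backward on consumer GPUs.
MIN_FLASH_ATTN = "2.5.5"


def parse_version(text: str) -> Tuple[int, ...]:
    """Parse the leading numeric part of a version, e.g. '2.5.6+cu122' -> (2, 5, 6)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def installed_flash_attn() -> Optional[str]:
    """Return the installed flash-attn version, or None if it is not installed."""
    try:
        return version("flash_attn")
    except PackageNotFoundError:
        return None


def needs_upgrade(installed: Optional[str], minimum: str = MIN_FLASH_ATTN) -> bool:
    """True if flash-attn is missing or older than `minimum`."""
    if installed is None:
        return True
    return parse_version(installed) < parse_version(minimum)


if __name__ == "__main__":
    current = installed_flash_attn()
    if needs_upgrade(current):
        print(f"flash-attn {current or 'not installed'}: upgrade, e.g. "
              "pip install -U flash-attn --no-build-isolation")
    else:
        print(f"flash-attn {current} should include the fix")
```

Tuple comparison handles versions of different lengths, so `(2, 5)` compares as older than `(2, 5, 5)`.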

g-ronimo changed discussion status to closed
