pranjalchitale commited on
Commit
bb1e8f4
1 Parent(s): 4c16b02

Update modeling_indictrans.py

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +6 -3
modeling_indictrans.py CHANGED
@@ -54,9 +54,12 @@ logger = logging.get_logger(__name__)
54
 
55
  INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
56
 
57
- if is_flash_attn_2_available():
58
- from flash_attn import flash_attn_func, flash_attn_varlen_func
59
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
 
 
 
60
 
61
 
62
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 
54
 
55
  INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
56
 
57
+ try:
58
+ if is_flash_attn_2_available():
59
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
60
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
61
+ except:
62
+ pass
63
 
64
 
65
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data