Is 'mps' supported for Flan models?

#10
by abhishekmamdapure - opened

I'm looking for 'mps' support for the Flan models. Is it available?

Google org

Hi @abhishekmamdapure
Thanks for the issue! May I ask what you mean by MPS support? Do you mean running flan-t5 on an M1 chip?
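A quick first check is whether PyTorch can see the MPS backend at all. The sketch below relies on `torch.backends.mps.is_built()` and `torch.backends.mps.is_available()` (present in PyTorch 1.12 and later). The helper name `pick_device` is hypothetical, and falling back to the CPU is an assumed policy, not something this thread prescribes.

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical helper: return the MPS device if PyTorch can use it, else the CPU."""
    # is_built(): this PyTorch binary was compiled with MPS support.
    # is_available(): the macOS version and hardware can actually run it.
    if torch.backends.mps.is_built() and torch.backends.mps.is_available():
        return torch.device("mps")
    # Assumed fallback policy: run on the CPU when MPS is unusable.
    return torch.device("cpu")


device = pick_device()
print(f"Using device: {device}")
```

On an Apple-silicon Mac with a recent PyTorch this should report `mps`; anywhere else it reports `cpu`.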

Yes.

I tried running flan-t5-large on an M2, but it produced garbage output.

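```python
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

# Load the tokenizer and the model
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")

# Move the model to the Apple-silicon GPU through PyTorch's MPS backend
mps_device = torch.device("mps")
model.to(mps_device)

input_text = "Translate to German: My name is Abhishek"
# The inputs must live on the same device as the model
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(mps_device)

outputs = model.generate(input_ids, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))
```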

That's the code I used.
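One way to narrow down the garbage output is to run the same prompt on the CPU and on MPS and compare the results. The sketch below is a diagnostic under stated assumptions, not a fix: `compare_cpu_vs_mps` and `max_abs_diff` are hypothetical helper names, it uses greedy decoding (`do_sample=False`) so both devices should produce the same tokens if MPS computes the same numbers, and it also compares the logits of the first decoder step to show whether the numbers drift before any token is chosen.

```python
import torch


def max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    """Largest element-wise difference between two tensors, computed on the CPU in float32."""
    return (a.detach().cpu().float() - b.detach().cpu().float()).abs().max().item()


def compare_cpu_vs_mps(model_name: str, prompt: str, max_new_tokens: int = 50):
    """Hypothetical diagnostic: run the same greedy generation on CPU and on MPS.

    Returns the decoded text per device and the largest gap between the
    first-decoder-step logits computed on each device.
    """
    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS is not available in this PyTorch build or on this machine")

    # Imported here so max_abs_diff can be used without transformers installed.
    from transformers import T5Tokenizer, T5ForConditionalGeneration

    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name).eval()

    texts, first_step_logits = {}, {}
    for device_name in ("cpu", "mps"):
        device = torch.device(device_name)
        model.to(device)  # nn.Module.to moves the weights in place
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            # Greedy decoding, so any difference comes from the numbers, not from sampling.
            out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
            # Logits for the first decoder position, fed only the decoder start token.
            start = torch.full(
                (1, 1), model.config.decoder_start_token_id, dtype=torch.long, device=device
            )
            logits = model(**inputs, decoder_input_ids=start).logits
        texts[device_name] = tokenizer.decode(out[0], skip_special_tokens=True)
        first_step_logits[device_name] = logits

    return texts, max_abs_diff(first_step_logits["cpu"], first_step_logits["mps"])
```

For example, on the M2 machine: `texts, gap = compare_cpu_vs_mps("google/flan-t5-large", "Translate to German: My name is Abhishek")`. If `texts["cpu"]` is a sensible German sentence while `texts["mps"]` is not, and `gap` is far larger than float32 rounding noise, that would suggest the problem lies in the MPS backend rather than in the prompt or the model weights.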
