You should be able to make a few small changes to support "mps".

In TrainingConfig, set the device to "mps". Then run training.
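
A minimal sketch of the config change. It assumes TrainingConfig is a plain dataclass with a `device` field; the project's real class may look different. `pick_device` is a hypothetical helper that prefers CUDA, then MPS, then CPU, using PyTorch's `torch.cuda.is_available()` and `torch.backends.mps.is_available()`.

```python
from dataclasses import dataclass

try:
    import torch
except ImportError:  # let the sketch run even where PyTorch is not installed
    torch = None


def pick_device() -> str:
    """Hypothetical helper: prefer CUDA, then Apple's MPS backend, then CPU."""
    if torch is not None:
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
    return "cpu"


@dataclass
class TrainingConfig:
    # Hypothetical stand-in for the project's config; only `device` matters here.
    device: str = "mps"


# Either hard-code "mps" or pick whatever the current machine supports.
config = TrainingConfig(device=pick_device())
```

If training hits an operator that has no MPS kernel, PyTorch raises an error; launching with the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` makes such ops run on the CPU instead, at some cost in speed.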

In sample.py, modify parse_args() to accept mps as a possible value for the --device argument.
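
A sketch of what that could look like with `argparse`, assuming `--device` is declared with a `choices` list. The existing choices and the `cpu` default shown here are assumptions, not taken from the actual sample.py.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical sketch of sample.py's parser with "mps" added to --device."""
    parser = argparse.ArgumentParser(description="Sample from a trained model")
    parser.add_argument(
        "--device",
        choices=["cpu", "cuda", "mps"],  # "mps" added alongside the assumed existing options
        default="cpu",
        help="device to run sampling on",
    )
    return parser.parse_args(argv)
```

With that in place, `python sample.py --device mps` selects the MPS backend, while an unsupported value makes argparse exit with a usage error.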

Thanks! I'll try. I didn't bother before, believing that since this was developed heavily on CUDA, it was likely to use kernels that are missing in MPS.