Skip to content

feat: add MPS (Apple Silicon) support for inference#3

Open
ZimengXiong wants to merge 1 commit intoQwenLM:mainfrom
ZimengXiong:feature/mps-support
Open

feat: add MPS (Apple Silicon) support for inference#3
ZimengXiong wants to merge 1 commit intoQwenLM:mainfrom
ZimengXiong:feature/mps-support

Conversation

@ZimengXiong
Copy link
Copy Markdown

Support MPS backend for inference

  • Updated device selection logic to prefer CUDA if available, then MPS, then CPU in both src/app.py and src/tool/edit_rgba_image.py. [1] [2]
  • Set tensor data type to torch.bfloat16 for CUDA/MPS devices and torch.float32 for CPU. [1] [2]
  • Updated model pipeline initialization and .to() calls to use the new device and data type logic in both files. [1] [2]
  • Modified the generator in the infer function to use the selected device rather than hardcoding 'cuda'.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant