@qgallouedec could you test this PR (do We should probably add a warning in the doc about the minimum pytorch version? (or in the code) |
|
Not only did pytest fail, but it also caused a Python Fatal Error: I don't know what it is. I will investigate. |
|
Well, I'm pretty sure the problem comes from the fact that the observation is transposed before being passed into the CNN of the feature extractor, and this seems to cause some more bugs: pytorch/pytorch#81557

To reproduce:

from stable_baselines3 import A2C
from stable_baselines3.common.envs import FakeImageEnv
env = FakeImageEnv()
model = A2C("CnnPolicy", env).learn(250)

It causes a fatal error in this line: without traceback, but with this error message:

But more generally, there are still some features missing for SB3 to work fully on the mps device, such as support for the multinomial distribution (pytorch/pytorch#80760). So we still have to be a bit patient. |
|
Thanks for testing =) |
|
PyTorch 1.13 is out. MPS is still not fully supported and causes bugs in SB3. |
|
@qgallouedec, can you please tell us which ops are missing? |
Is this still happening in the latest nightly? cc @qgallouedec |
|
With the latest nightly:

% /Users/quentingallouedec/stable-baselines3/env/bin/python /Users/quentingallouedec/stable-baselines3/test_mps.py
[W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.
Traceback (most recent call last):
File "/Users/quentingallouedec/stable-baselines3/test_mps.py", line 5, in <module>
model = A2C("CnnPolicy", env).learn(250)
File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/a2c/a2c.py", line 193, in learn
return super().learn(
File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 248, in learn
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 166, in collect_rollouts
actions, values, log_probs = self.policy(obs_tensor)
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1427, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/policies.py", line 576, in forward
log_prob = distribution.log_prob(actions)
File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/distributions.py", line 279, in log_prob
return self.distribution.log_prob(actions)
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/categorical.py", line 123, in log_prob
self._validate_sample(value)
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/distribution.py", line 298, in _validate_sample
valid = support.check(value)
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/constraints.py", line 257, in check
return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
NotImplementedError: The operator 'aten::remainder.Tensor_out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

EDIT: tested with PyTorch 2.0.0.dev20221220 |
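For anyone who wants to try the temporary fix mentioned in the error message, here is a minimal sketch. It assumes the variable has to be in the environment before `torch` is imported, which I have not verified for every PyTorch version; `fallback_enabled` is just a hypothetical helper for this example.

```python
import os

# Assumption: PyTorch reads this variable when it is loaded, so set it
# before the first `import torch` in the process.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

try:
    import torch  # imported after the variable is set
except ImportError:
    torch = None  # the sketch still runs without PyTorch installed


def fallback_enabled() -> bool:
    """Hypothetical helper: report whether the CPU fallback was requested."""
    return os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK") == "1"


if __name__ == "__main__":
    print("MPS CPU fallback requested:", fallback_enabled())
```

Setting the variable in the shell (`PYTORCH_ENABLE_MPS_FALLBACK=1 python test_mps.py`) should be equivalent. As the message says, ops that fall back to the CPU will be slower than running natively on MPS.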
It's in a PR. Will try to prioritize the merge. |
|
@qgallouedec how is the support with PyTorch 2.1.0? |
|
The number of errors is decreasing. Here's one of them:
Is double precision a feature of SB3, or should single precision be forced systematically?
I think we don't really support float64... (mainly to avoid issues when using CUDA) |
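To illustrate the dtype question, here is a rough sketch of forcing single precision only when the target device is MPS (which has no float64 support). The helper names `safe_dtype_name` and `to_device` are hypothetical and not part of SB3 or of this PR.

```python
def safe_dtype_name(dtype_name: str, device_type: str) -> str:
    """Return the dtype name to use on the given device type.

    MPS does not support float64, so it is downgraded to float32 there;
    every other combination is left untouched.
    """
    if device_type == "mps" and dtype_name == "float64":
        return "float32"
    return dtype_name


def to_device(array, device: str):
    """Hypothetical helper: convert `array` to a tensor on `device`."""
    import torch  # deferred so the pure-Python helper works without PyTorch

    tensor = torch.as_tensor(array)
    current = str(tensor.dtype).replace("torch.", "")  # e.g. "float64"
    target = getattr(torch, safe_dtype_name(current, torch.device(device).type))
    # Cast on the CPU first, then move, so no float64 data reaches MPS.
    return tensor.to(dtype=target).to(device)
```

With this, `to_device(obs, "mps")` would cast a float64 NumPy observation to float32 before the transfer, while `to_device(obs, "cpu")` keeps the original dtype.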
|
If you need someone to test something, please tell me. I could do it with my Mac, because this PR has been open for a while now and nobody has come up with a solution or a review ... |
|
@tty666 thank you for the proposal. Feel free to test and provide your feedback if any. As far as I remember, there are still some issues related to dtype (float64 instead of float32), see #951 (comment). As soon as all the CI passes, we can consider this PR as ready to be merged |
|
Any news regarding this PR? Is someone working on it? |
|
|
Hello! I just tried this out, out of curiosity, and it seems to work. The small snippet above and another project I have been working on recently work very similarly with and without MPS. I can see the GPU going to 100% with asitop and no crashes. Performance-wise it's not as good as we might expect, but that might be related to my particular use case. |
|
Hi. I see the tests are still failing. I'll try to give a bit more details on my setup. First, I'm running a MacBook Pro M1 Pro. The test from yesterday was running with Python 3.12. This morning, I cloned the repo, switched to the feat/mps-support branch, created a Python 3.11 venv and ran |
|
hi 👋 i would like to help move this pr forward, i see there hasn't been much progress in the past few months, i have an m1 mac studio where i'm testing this branch with this setup:
can someone point me in the right direction for the changes that i need to make to get the tests to pass? i've seen that only 3 files have been changed in this pr, but i didn't find example fixes for these issues. edit: i tried my best to do things with common sense and fixed all tests, have a look at this pr #2005 |
Description
Add support for MPS device (uses it if available) and save cloudpickle version (important to debug saving/loading issues).
DO NOT MERGE: this PR must be tested on an MPS device first
closes #914
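As a rough illustration of "uses it if available", the device selection could look like the sketch below. This is not the code from this PR: `get_device_name` is a hypothetical helper, and the CUDA-before-MPS priority is an assumption.

```python
def get_device_name() -> str:
    """Pick "cuda", then "mps", then "cpu", depending on what is available."""
    try:
        import torch
    except ImportError:
        return "cpu"  # no PyTorch installed, nothing else to choose from

    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps only exists in PyTorch >= 1.12, hence the getattr guard.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


if __name__ == "__main__":
    print("Selected device:", get_device_name())
```

The `getattr` guard is one way to cope with older PyTorch versions that predate the MPS backend, which relates to the question above about warning on the minimum PyTorch version.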
Motivation and Context
Types of changes
Checklist:
- make format (required)
- make check-codestyle and make lint (required)
- make pytest and make type both pass. (required)
- make doc (required)

Note: You can run most of the checks using make commit-checks.
Note: we are using a maximum length of 127 characters per line