11 changes: 10 additions & 1 deletion docs/common/distributions.md
@@ -6,6 +6,7 @@ Probability distributions used for the different action spaces:

- `CategoricalDistribution` -> Discrete
- `DiagGaussianDistribution` -> Box (continuous actions)
- `BetaDistribution` -> Box (continuous actions) when `use_beta=True`
- `StateDependentNoiseDistribution` -> Box (continuous actions) when `use_sde=True`

% - ``MultiCategoricalDistribution`` -> MultiDiscrete
@@ -19,9 +20,17 @@ For instance, in the case of discrete actions, the policy network outputs probab
of taking each action. The `CategoricalDistribution` allows sampling from it,
computes the entropy and the log probability (`log_prob`), and backpropagates the gradient.
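
The same mechanics can be sketched with plain PyTorch, which `CategoricalDistribution` wraps (the logits below are illustrative):

```python
import torch
from torch.distributions import Categorical

# Logits produced by the policy network for one observation and 3 discrete actions
logits = torch.tensor([[2.0, 0.5, -1.0]])
dist = Categorical(logits=logits)

action = dist.sample()            # shape (1,), a value in {0, 1, 2}
log_prob = dist.log_prob(action)  # differentiable, so the gradient flows back to the logits
entropy = dist.entropy()          # strictly positive as long as no action has probability 1
```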

In the case of continuous actions, a Gaussian distribution is used. The policy network outputs
In the case of continuous actions, a Gaussian distribution is used by default. The policy network outputs
mean and (log) std of the distribution (assumed to be a `DiagGaussianDistribution`).

Alternatively, a `BetaDistribution` can be used for continuous actions in bounded spaces
by passing `policy_kwargs=dict(use_beta=True)`. The Beta distribution has bounded support on [0, 1],
so sampled actions naturally respect bounds without clipping. Actions are rescaled from [0, 1]
to the environment's action space `[low, high]`. The policy network outputs raw α and β parameters,
which are passed through a softplus activation plus 1 to ensure α, β > 1 (the unimodal regime).
The idea was first introduced by [Chou et al. (2017)](https://proceedings.mlr.press/v70/chou17a.html)
and further explored in [*The Beta Policy for Continuous Control Reinforcement Learning*](https://arxiv.org/abs/2111.02202).
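
A minimal PyTorch sketch of this parametrization, using `torch.distributions.Beta` directly (shapes and values are illustrative):

```python
import torch
from torch.distributions import Beta
from torch.nn.functional import softplus

# Raw network output for a batch of 8 observations and action_dim = 2:
# the alpha and beta heads are concatenated along the last dimension
raw = torch.randn(8, 4)
raw_alpha, raw_beta = torch.chunk(raw, 2, dim=-1)

# softplus(.) > 0, so adding 1 keeps both parameters strictly above 1 (unimodal regime)
alpha = softplus(raw_alpha) + 1.0
beta = softplus(raw_beta) + 1.0

dist = Beta(alpha, beta)
actions = dist.rsample()  # differentiable samples in [0, 1] (reparametrization trick)
mode = (alpha - 1.0) / (alpha + beta - 2.0)  # closed-form mode, the deterministic action
```

Note that the network only has to produce `2 * action_dim` raw values; unlike the Gaussian case, no separate log-std parameter is needed.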

```{eval-rst}
.. automodule:: stable_baselines3.common.distributions
:members:
1 change: 1 addition & 0 deletions docs/misc/changelog.md
@@ -12,6 +12,7 @@
### New Features:

- Added official support for Python 3.13
- Added `BetaDistribution` for continuous actions in bounded spaces, based on [Chou et al. (2017)](https://proceedings.mlr.press/v70/chou17a.html) and [arXiv:2111.02202](https://arxiv.org/abs/2111.02202). Can be enabled via `policy_kwargs=dict(use_beta=True)` for PPO and A2C. Actions are sampled from a Beta distribution (in [0, 1]) and rescaled to the action space bounds.

### Bug Fixes:

5 changes: 5 additions & 0 deletions docs/modules/a2c.md
@@ -89,6 +89,11 @@ When using A2C models trained with `use_sde=True`, the automatic noise resetting
For continuous control tasks, it is recommended to use deterministic behavior during inference (`deterministic=True`). If you need stochastic behavior during inference, you must manually reset the noise by calling `model.policy.reset_noise(env.num_envs)` at appropriate intervals based on your desired `sde_sample_freq`.
:::

## Beta Distribution for Continuous Actions

A2C supports using a **Beta distribution** for continuous actions via `policy_kwargs=dict(use_beta=True)`.
See the [PPO documentation](../modules/ppo.md#beta-distribution-for-continuous-actions) for details.

## Results

### Atari Games
31 changes: 31 additions & 0 deletions docs/modules/ppo.md
@@ -103,6 +103,37 @@ When using PPO models trained with `use_sde=True`, the automatic noise resetting
For continuous control tasks, it is recommended to use deterministic behavior during inference (`deterministic=True`). If you need stochastic behavior during inference, you must manually reset the noise by calling `model.policy.reset_noise(env.num_envs)` at appropriate intervals based on your desired `sde_sample_freq`.
:::

## Beta Distribution for Continuous Actions

For continuous action spaces (`Box`), PPO can use a **Beta distribution** instead of the default Gaussian.
The idea of using the Beta distribution for continuous control was first introduced by
[Chou et al. (2017)](https://proceedings.mlr.press/v70/chou17a.html) and further explored
in [*The Beta Policy for Continuous Control Reinforcement Learning*](https://arxiv.org/abs/2111.02202).

Unlike a Gaussian, the Beta distribution has **bounded support** on [0, 1], which means sampled actions
naturally stay within bounds without requiring clipping. Actions are then rescaled to the environment's
action space `[low, high]`. This avoids the boundary effects and bias introduced by clipping
unbounded Gaussian samples, which is especially problematic in environments where optimal
actions often lie near the edges of the action space.
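
The rescaling step is a plain affine map from [0, 1] onto the action bounds; a NumPy sketch (the bounds below are those of `Pendulum-v1`, used only for illustration):

```python
import numpy as np

low, high = np.array([-2.0]), np.array([2.0])  # Pendulum-v1 action bounds
beta_actions = np.array([0.0, 0.5, 1.0])       # samples from the Beta distribution, in [0, 1]

env_actions = low + (high - low) * beta_actions
print(env_actions)  # [-2.  0.  2.] -> the endpoints of [0, 1] land exactly on low and high
```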

The policy network outputs raw α and β parameters for each action dimension.
These are passed through a softplus activation and shifted by +1 to ensure α, β > 1,
keeping the distribution in the **unimodal regime**.

:::{note}
`use_beta=True` and `use_sde=True` are **mutually exclusive**; enabling both raises an assertion error.
:::

### Example

```python
from stable_baselines3 import PPO

# Use a Beta distribution instead of Gaussian for continuous actions
model = PPO("MlpPolicy", "Pendulum-v1", policy_kwargs=dict(use_beta=True), verbose=1)
model.learn(total_timesteps=100_000)
```

## Results

### Atari Games
105 changes: 102 additions & 3 deletions stable_baselines3/common/distributions.py
@@ -7,7 +7,7 @@
import torch as th
from gymnasium import spaces
from torch import nn
from torch.distributions import Bernoulli, Categorical, Normal
from torch.distributions import Bernoulli, Beta, Categorical, Normal
from torch.distributions import Distribution as TorchDistribution

from stable_baselines3.common.preprocessing import get_action_dim
@@ -20,6 +20,7 @@
SelfCategoricalDistribution = TypeVar("SelfCategoricalDistribution", bound="CategoricalDistribution")
SelfMultiCategoricalDistribution = TypeVar("SelfMultiCategoricalDistribution", bound="MultiCategoricalDistribution")
SelfBernoulliDistribution = TypeVar("SelfBernoulliDistribution", bound="BernoulliDistribution")
SelfBetaDistribution = TypeVar("SelfBetaDistribution", bound="BetaDistribution")
SelfStateDependentNoiseDistribution = TypeVar("SelfStateDependentNoiseDistribution", bound="StateDependentNoiseDistribution")


@@ -426,6 +427,92 @@ def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.
return actions, log_prob


class BetaDistribution(Distribution):
"""
Beta distribution for continuous actions in a bounded space.

Actions are sampled from a Beta distribution and then rescaled to [low, high].
The network outputs two heads (alpha and beta parameters), which are passed through
softplus to ensure positivity, with +1 added to keep them > 1 (unimodal regime).

References: Chou et al. "Improving Stochastic Policy Gradients in Continuous Control with
Deep Reinforcement Learning using the Beta Distribution" (ICML 2017)
https://proceedings.mlr.press/v70/chou17a.html
and https://arxiv.org/abs/2111.02202

:param action_dim: Dimension of the action space.
"""

distribution: Beta

def __init__(self, action_dim: int):
super().__init__()
self.action_dim = action_dim

def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
Create the layer that represents the distribution:
it outputs both alpha and beta parameters (2 * action_dim outputs).

:param latent_dim: Dimension of the last layer of the policy (before the action layer)
:return: action network (outputs alpha and beta concatenated)
"""
action_net = nn.Linear(latent_dim, 2 * self.action_dim)
return action_net

def proba_distribution(self: SelfBetaDistribution, alpha_beta: th.Tensor) -> SelfBetaDistribution:
"""
Create the distribution given the concatenated alpha/beta parameters.

:param alpha_beta: Concatenated raw alpha and beta outputs from the network,
shape (batch, 2 * action_dim). Will be passed through softplus + 1
to ensure alpha, beta > 1 (unimodal).
:return: self
"""
alpha, beta = th.chunk(alpha_beta, 2, dim=-1)
# Softplus + 1 ensures alpha, beta > 1 (unimodal Beta distribution)
alpha = th.nn.functional.softplus(alpha) + 1.0
beta = th.nn.functional.softplus(beta) + 1.0
self.distribution = Beta(alpha, beta)
return self

def log_prob(self, actions: th.Tensor) -> th.Tensor:
"""
Get the log probabilities of actions according to the distribution.
Actions are expected in [0, 1] (the raw Beta space).

:param actions:
:return:
"""
# Clamp actions to the open interval to avoid an infinite log density at the boundaries
log_prob = self.distribution.log_prob(actions.clamp(1e-6, 1.0 - 1e-6))
return sum_independent_dims(log_prob)

def entropy(self) -> th.Tensor:
return sum_independent_dims(self.distribution.entropy())

def sample(self) -> th.Tensor:
# Reparametrization trick to pass gradients
return self.distribution.rsample()

def mode(self) -> th.Tensor:
# Mode of Beta(alpha, beta) = (alpha - 1) / (alpha + beta - 2) when alpha, beta > 1
alpha = self.distribution.concentration1
beta = self.distribution.concentration0
return (alpha - 1.0) / (alpha + beta - 2.0)

def actions_from_params(self, alpha_beta: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(alpha_beta)
return self.get_actions(deterministic=deterministic)

def log_prob_from_params(self, alpha_beta: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(alpha_beta)
log_prob = self.log_prob(actions)
return actions, log_prob


class StateDependentNoiseDistribution(Distribution):
"""
Distribution class for using generalized State Dependent Exploration (gSDE).
@@ -668,22 +755,34 @@ def log_prob_correction(self, x: th.Tensor) -> th.Tensor:


def make_proba_distribution(
action_space: spaces.Space, use_sde: bool = False, dist_kwargs: dict[str, Any] | None = None
action_space: spaces.Space,
use_sde: bool = False,
use_beta: bool = False,
dist_kwargs: dict[str, Any] | None = None,
) -> Distribution:
"""
Return an instance of Distribution for the correct type of action space

:param action_space: the input action space
:param use_sde: Force the use of StateDependentNoiseDistribution
instead of DiagGaussianDistribution
:param use_beta: Force the use of BetaDistribution for continuous actions.
See https://arxiv.org/abs/2111.02202
:param dist_kwargs: Keyword arguments to pass to the probability distribution
:return: the appropriate Distribution object
"""
if dist_kwargs is None:
dist_kwargs = {}

if isinstance(action_space, spaces.Box):
cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
assert not (use_sde and use_beta), "use_sde and use_beta are mutually exclusive options for Box action spaces."
cls: type[BetaDistribution] | type[StateDependentNoiseDistribution] | type[DiagGaussianDistribution]
if use_beta:
cls = BetaDistribution
elif use_sde:
cls = StateDependentNoiseDistribution
else:
cls = DiagGaussianDistribution
return cls(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, spaces.Discrete):
return CategoricalDistribution(int(action_space.n), **dist_kwargs)
6 changes: 6 additions & 0 deletions stable_baselines3/common/on_policy_algorithm.py
@@ -10,6 +10,7 @@
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.distributions import BetaDistribution
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
@@ -210,6 +211,11 @@ def collect_rollouts(
# Unscale the actions to match env bounds
# if they were previously squashed (scaled in [-1, 1])
clipped_actions = self.policy.unscale_action(clipped_actions)
elif isinstance(self.policy.action_dist, BetaDistribution):
# Beta distribution outputs actions in [0, 1],
# rescale to [low, high] for the environment
low, high = self.action_space.low, self.action_space.high
clipped_actions = low + (high - low) * clipped_actions
else:
# Otherwise, clip the actions to avoid out of bound error
# as we are sampling from an unbounded Gaussian distribution
28 changes: 26 additions & 2 deletions stable_baselines3/common/policies.py
@@ -14,6 +14,7 @@

from stable_baselines3.common.distributions import (
BernoulliDistribution,
BetaDistribution,
CategoricalDistribution,
DiagGaussianDistribution,
Distribution,
@@ -373,6 +374,11 @@ def predict(
if self.squash_output:
# Rescale to proper domain when using squashing
actions = self.unscale_action(actions) # type: ignore[assignment, arg-type]
elif hasattr(self, "action_dist") and isinstance(self.action_dist, BetaDistribution):
# Beta distribution outputs actions in [0, 1],
# rescale to [low, high] for the environment
low, high = self.action_space.low, self.action_space.high
actions = low + (high - low) * actions # type: ignore[assignment]
else:
# Actions could be on arbitrary scale, so clip the actions to avoid
# out of bound error (e.g. if sampling from a Gaussian distribution)
@@ -425,6 +431,8 @@ class ActorCriticPolicy(BasePolicy):
:param activation_fn: Activation function
:param ortho_init: Whether to use or not orthogonal initialization
:param use_sde: Whether to use State Dependent Exploration or not
:param use_beta: Whether to use a Beta distribution instead of a Gaussian for continuous actions.
See https://arxiv.org/abs/2111.02202. Cannot be used together with ``use_sde``.
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE
@@ -454,6 +462,7 @@ def __init__(
activation_fn: type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
use_beta: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
use_expln: bool = False,
@@ -527,10 +536,11 @@
}

self.use_sde = use_sde
self.use_beta = use_beta
self.dist_kwargs = dist_kwargs

# Action distribution
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, use_beta=use_beta, dist_kwargs=dist_kwargs)

self._build(lr_schedule)

@@ -544,6 +554,7 @@ def _get_constructor_parameters(self) -> dict[str, Any]:
net_arch=self.net_arch,
activation_fn=self.activation_fn,
use_sde=self.use_sde,
use_beta=self.use_beta,
log_std_init=self.log_std_init,
squash_output=default_none_kwargs["squash_output"],
full_std=default_none_kwargs["full_std"],
@@ -601,7 +612,9 @@ def _build(self, lr_schedule: Schedule) -> None:
self.action_net, self.log_std = self.action_dist.proba_distribution_net(
latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init
)
elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)):
elif isinstance(
self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution, BetaDistribution)
):
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
else:
raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.")
@@ -701,6 +714,9 @@ def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
elif isinstance(self.action_dist, BernoulliDistribution):
# Here mean_actions are the logits (before rounding to get the binary actions)
return self.action_dist.proba_distribution(action_logits=mean_actions)
elif isinstance(self.action_dist, BetaDistribution):
# Here mean_actions are the raw alpha/beta parameters (2 * action_dim)
return self.action_dist.proba_distribution(alpha_beta=mean_actions)
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
else:
@@ -775,6 +791,8 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
:param activation_fn: Activation function
:param ortho_init: Whether to use or not orthogonal initialization
:param use_sde: Whether to use State Dependent Exploration or not
:param use_beta: Whether to use a Beta distribution instead of a Gaussian for continuous actions.
See https://arxiv.org/abs/2111.02202. Cannot be used together with ``use_sde``.
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE
@@ -804,6 +822,7 @@ def __init__(
activation_fn: type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
use_beta: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
use_expln: bool = False,
@@ -823,6 +842,7 @@
activation_fn,
ortho_init,
use_sde,
use_beta,
log_std_init,
full_std,
use_expln,
@@ -848,6 +868,8 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
:param activation_fn: Activation function
:param ortho_init: Whether to use or not orthogonal initialization
:param use_sde: Whether to use State Dependent Exploration or not
:param use_beta: Whether to use a Beta distribution instead of a Gaussian for continuous actions.
See https://arxiv.org/abs/2111.02202. Cannot be used together with ``use_sde``.
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE
@@ -877,6 +899,7 @@ def __init__(
activation_fn: type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
use_beta: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
use_expln: bool = False,
@@ -896,6 +919,7 @@
activation_fn,
ortho_init,
use_sde,
use_beta,
log_std_init,
full_std,
use_expln,