amago.nets.policy_dists#
Stochastic policy output distributions.
Functions

- softplus_bounded_positive – Map network output activations to a positive value (to parameterize a distribution).
- tanh_bounded_positive – Map network output activations to a positive value (to parameterize a distribution).

Classes

- Beta – Generates a Beta distribution rescaled to [-1, 1].
- Discrete – Generates a categorical distribution over actions.
- DiscreteLikeContinuous – Wrapper around Categorical used by MultiTaskAgent.
- GMM – Generates a Gaussian Mixture Model with a tanh transform.
- Multibinary – Multi-binary action space support.
- PolicyOutput – Abstract base class for mapping network outputs to a distribution over actions.
- TanhGaussian – Generates a multivariate normal with a tanh transform to sample in [-1, 1].
- class Beta(d_action, alpha_low=1.0001, alpha_high=None, beta_low=1.0001, beta_high=None, std_activation=<function softplus_bounded_positive>, clip_actions_on_log_prob=(-0.99, 0.99))[source]#
Bases: PolicyOutput
Generates a Beta distribution rescaled to [-1, 1]. (https://proceedings.mlr.press/v70/chou17a/chou17a.pdf)
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
- Parameters:
  - d_action (int) – Dimension of the action space.
  - alpha_low (float) – Minimum value of alpha. Default is 1.0001.
  - alpha_high (float | None) – Maximum value of alpha. Default is None.
  - beta_low (float) – Minimum value of beta. Default is 1.0001.
  - beta_high (float | None) – Maximum value of beta. Default is None.
  - std_activation (Callable[[Tensor, float, float], Tensor]) – Activation function to produce a valid standard deviation from the raw network output.
  - clip_actions_on_log_prob (tuple[float, float]) – Tuple of floats that clips the actions before computing dist.log_prob(action). Addresses numerical stability issues when computing log_probs at the boundary of the action space.
Note
alpha_low > 1 and beta_low > 1 keep the distribution unimodal. (https://mathlets.org/mathlets/beta-distribution/)
- property actions_differentiable: bool#
Does the output distribution have rsample?
Used to answer the question: “can we optimize -Q(s, a ~ pi) as an actor loss?”
- forward(vec, log_dict=None)[source]#
Maps the output of the actor network to a distribution over actions.
- Parameters:
  - vec (Tensor) – Output of the actor network.
  - log_dict (dict | None) – If None, this is not a log step and any log value computation can be skipped. If provided, any data added to the dict will be automatically logged. Defaults to None.
- Return type:
  _ShiftedBeta
- Returns:
  A torch.distributions.Distribution that at least has log_prob() and sample(), and would be expected to have rsample() if self.actions_differentiable is True.
- property input_dimension: int#
Required input dimension for this policy distribution.
This is used to determine the output of the actor network. How many values does the actor network need to produce to parameterize this policy distribution?
- property is_discrete: bool#
Whether the action space is discrete.
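For intuition, here is a sketch of the underlying idea in plain torch.distributions rather than amago's internal _ShiftedBeta; the concentration values and batch shapes below are illustrative placeholders.

```python
import torch
from torch import distributions as pyd

# Actor outputs would normally be mapped to alpha/beta > 1 by the bounded
# std_activation; fixed values are used here purely for illustration.
alpha = torch.full((32, 6), 2.0)
beta = torch.full((32, 6), 3.0)

base = pyd.Beta(alpha, beta)                         # supported on [0, 1]
rescaled = pyd.TransformedDistribution(
    base, pyd.AffineTransform(loc=-1.0, scale=2.0)   # [0, 1] -> [-1, 1]
)
dist = pyd.Independent(rescaled, 1)                  # treat action dims as one event

action = dist.rsample()                              # Beta supports pathwise gradients
log_prob = dist.log_prob(action.clamp(-0.99, 0.99))  # mirrors clip_actions_on_log_prob
```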
- class Discrete(d_action, clip_prob_low=0.001, clip_prob_high=0.99)[source]#
Bases: PolicyOutput
Generates a categorical distribution over actions.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
Returns a thin wrapper around torch Categorical that unsqueezes the last dimension of sample() actions to be 1.
- Parameters:
  - d_action (int) – Dimension of the action space.
  - clip_prob_low (float) – Clips action probabilities to this value before renormalizing. Default is 0.001.
  - clip_prob_high (float) – Clips action probabilities to this value before renormalizing. Default is 0.99, which is kept for backwards compatibility but is now thought to be too conservative; 0.999 or 1.0 is fine.
- property actions_differentiable: bool#
Does the output distribution have rsample?
Used to answer the question: “can we optimize -Q(s, a ~ pi) as an actor loss?”
- forward(vec, log_dict=None)[source]#
Maps the output of the actor network to a distribution over actions.
- Parameters:
  - vec (Tensor) – Output of the actor network.
  - log_dict (dict | None) – If None, this is not a log step and any log value computation can be skipped. If provided, any data added to the dict will be automatically logged. Defaults to None.
- Return type:
  _Categorical
- Returns:
  A torch.distributions.Distribution that at least has log_prob() and sample(), and would be expected to have rsample() if self.actions_differentiable is True.
- property input_dimension: int#
Required input dimension for this policy distribution.
This is used to determine the output of the actor network. How many values does the actor network need to produce to parameterize this policy distribution?
- property is_discrete: bool#
Whether the action space is discrete.
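A minimal usage sketch against the interface documented above; the feature width (128), batch size, and action count are arbitrary placeholders.

```python
import torch
from amago.nets.policy_dists import Discrete

head = Discrete(d_action=4)
actor = torch.nn.Linear(128, head.input_dimension)  # actor output layer: one value per action

features = torch.randn(32, 128)
dist = head(actor(features))   # thin wrapper around a torch Categorical
action = dist.sample()         # trailing dim unsqueezed to 1, e.g. shape (32, 1)
```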
- class DiscreteLikeContinuous(categorical, gumbel_softmax_temperature=0.5)[source]#
Bases: object
Wrapper around Categorical used by MultiTaskAgent.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
Lets us use discrete actions in a continuous actor-critic setup where the critic takes action vectors as input and outputs a scalar value.
Categorical --> OneHotCategorical + rsample() as F.gumbel_softmax with straight-through hard sampling.
- Parameters:
  - categorical (Categorical | _Categorical) – A Categorical distribution.
  - gumbel_softmax_temperature (float) – Temperature for the Gumbel-Softmax trick that enables rsample(). Default is 0.5.
- property logits: Tensor#
- property probs: Tensor#
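For intuition, here is a standalone sketch of the straight-through Gumbel-Softmax trick this wrapper relies on for rsample(); the "critic" below is a stand-in sum, not amago's critic.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(32, 4, requires_grad=True)         # categorical logits
action = F.gumbel_softmax(logits, tau=0.5, hard=True)   # one-hot in the forward pass

# The one-hot action vector can be fed to a critic Q(s, a) that expects
# continuous action inputs; gradients reach the logits via the soft relaxation.
stand_in_q = action.sum()
stand_in_q.backward()
assert logits.grad is not None
```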
- class GMM(d_action, gmm_modes=5, std_low=0.0001, std_high=None, std_activation=<function softplus_bounded_positive>)[source]#
Bases: PolicyOutput
Generates a Gaussian Mixture Model with a tanh transform.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
A more expressive policy than TanhGaussian, but the output does not support rsample() or the DPG -Q(s, a ~ pi) loss. Often used in offline or imitation learning (IL) settings. Heavily based on robomimic’s robot IL setup.
- Parameters:
  - d_action (int) – Dimension of the action space.
  - gmm_modes (int) – Number of modes in the GMM. Default is 5.
  - std_low (float) – Minimum standard deviation. Default is 1e-4.
  - std_high (float | None) – Maximum standard deviation. Default is None.
  - std_activation (Callable[[Tensor, float, float], Tensor]) – Activation function to produce a valid standard deviation from the raw network output.
- property actions_differentiable: bool#
Does the output distribution have rsample?
Used to answer the question: “can we optimize -Q(s, a ~ pi) as an actor loss?”
- forward(vec, log_dict=None)[source]#
Maps the output of the actor network to a distribution over actions.
- Parameters:
  - vec (Tensor) – Output of the actor network.
  - log_dict (dict | None) – If None, this is not a log step and any log value computation can be skipped. If provided, any data added to the dict will be automatically logged. Defaults to None.
- Return type:
  _TanhWrappedDistribution
- Returns:
  A torch.distributions.Distribution that at least has log_prob() and sample(), and would be expected to have rsample() if self.actions_differentiable is True.
- property input_dimension: int#
Required input dimension for this policy distribution.
This is used to determine the output of the actor network. How many values does the actor network need to produce to parameterize this policy distribution?
- property is_discrete: bool#
Whether the action space is discrete.
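For intuition, here is a sketch of the mixture construction in plain torch.distributions, before the tanh wrap that amago applies; mode counts and shapes are illustrative placeholders.

```python
import torch
from torch import distributions as pyd

B, modes, d_action = 32, 5, 6
mix_logits = torch.randn(B, modes)             # mixture weights per mode
means = torch.randn(B, modes, d_action)        # per-mode means
stds = torch.rand(B, modes, d_action) + 1e-4   # per-mode stds, kept positive

components = pyd.Independent(pyd.Normal(means, stds), 1)
gmm = pyd.MixtureSameFamily(pyd.Categorical(logits=mix_logits), components)

action = gmm.sample()            # (B, d_action); mixtures do not support rsample()
log_prob = gmm.log_prob(action)  # e.g. for offline / imitation learning losses
```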
- class Multibinary(d_action)[source]#
Bases: PolicyOutput
Multi-binary action space support.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
- property actions_differentiable: bool#
Does the output distribution have rsample?
Used to answer the question: “can we optimize -Q(s, a ~ pi) as an actor loss?”
- forward(vec, log_dict=None)[source]#
Maps the output of the actor network to a distribution over actions.
- Parameters:
  - vec (Tensor) – Output of the actor network.
  - log_dict (dict | None) – If None, this is not a log step and any log value computation can be skipped. If provided, any data added to the dict will be automatically logged. Defaults to None.
- Return type:
  Bernoulli
- Returns:
  A torch.distributions.Distribution that at least has log_prob() and sample(), and would be expected to have rsample() if self.actions_differentiable is True.
- property input_dimension: int#
Required input dimension for this policy distribution.
This is used to determine the output of the actor network. How many values does the actor network need to produce to parameterize this policy distribution?
- property is_discrete: bool#
Whether the action space is discrete.
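For intuition, here is a sketch of the kind of distribution forward() returns, one Bernoulli per binary action dimension; shapes are illustrative.

```python
import torch
from torch import distributions as pyd

logits = torch.randn(32, 8)        # one logit per binary action dimension
dist = pyd.Bernoulli(logits=logits)

action = dist.sample()             # (32, 8) tensor of 0s and 1s
log_prob = dist.log_prob(action)   # per-dimension log-probs
```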
- class PolicyOutput(d_action)[source]#
Bases: ABC
Abstract base class for mapping network outputs to a distribution over actions.
Actor networks use a PolicyOutput to produce a distribution over actions that is compatible with the Agent’s loss function.
Pretends to be a torch.nn.Module (forward == __call__) but is not. Has no parameters and can be swapped without breaking checkpoints.
- Parameters:
  - d_action (int) – Dimension of the action space.
- abstract property actions_differentiable: bool#
Does the output distribution have rsample?
Used to answer the question: “can we optimize -Q(s, a ~ pi) as an actor loss?”
- abstract forward(vec, log_dict=None)[source]#
Maps the output of the actor network to a distribution over actions.
- Parameters:
  - vec (Tensor) – Output of the actor network.
  - log_dict (dict | None) – If None, this is not a log step and any log value computation can be skipped. If provided, any data added to the dict will be automatically logged. Defaults to None.
- Return type:
  Distribution
- Returns:
  A torch.distributions.Distribution that at least has log_prob() and sample(), and would be expected to have rsample() if self.actions_differentiable is True.
- abstract property input_dimension: int#
Required input dimension for this policy distribution.
This is used to determine the output of the actor network. How many values does the actor network need to produce to parameterize this policy distribution?
- abstract property is_discrete: bool#
Whether the action space is discrete.
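A hypothetical subclass sketch showing the members a custom PolicyOutput must provide. This is not part of amago; it assumes the base class stores the constructor argument as self.d_action, so adjust the attribute name if it differs.

```python
from torch import distributions as pyd
from amago.nets.policy_dists import PolicyOutput

class UnboundedGaussian(PolicyOutput):
    """Hypothetical example: an unsquashed diagonal Gaussian policy head."""

    @property
    def input_dimension(self) -> int:
        # one mean and one raw-std value per action dimension
        return 2 * self.d_action  # assumes the base class exposes d_action

    @property
    def is_discrete(self) -> bool:
        return False

    @property
    def actions_differentiable(self) -> bool:
        return True  # Normal supports rsample()

    def forward(self, vec, log_dict=None):
        mean, raw_std = vec.chunk(2, dim=-1)
        std = raw_std.clamp(-5.0, 2.0).exp()
        return pyd.Independent(pyd.Normal(mean, std), 1)
```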
- class TanhGaussian(d_action, std_low=0.006737946999085467, std_high=7.38905609893065, std_activation=<function tanh_bounded_positive>, clip_actions_on_log_prob=(-0.99, 0.99))[source]#
Bases: PolicyOutput
Generates a multivariate normal with a tanh transform to sample in [-1, 1].
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
- Parameters:
  - d_action (int) – Dimension of the action space.
  - std_low (float) – Minimum standard deviation. Default is exp(-5.0).
  - std_high (float) – Maximum standard deviation. Default is exp(2.0).
  - std_activation (Callable[[Tensor, float, float], Tensor]) – Activation function to produce a valid standard deviation from the raw network output.
  - clip_actions_on_log_prob (tuple[float, float]) – Tuple of floats that clips the actions before computing dist.log_prob(action). Addresses numerical stability issues when computing log_probs at or near the saturation points of the tanh transform. Default is (-0.99, 0.99).
- property actions_differentiable: bool#
Does the output distribution have rsample?
Used to answer the question: “can we optimize -Q(s, a ~ pi) as an actor loss?”
- forward(vec, log_dict=None)[source]#
Maps the output of the actor network to a distribution over actions.
- Parameters:
  - vec (Tensor) – Output of the actor network.
  - log_dict (dict | None) – If None, this is not a log step and any log value computation can be skipped. If provided, any data added to the dict will be automatically logged. Defaults to None.
- Return type:
  _SquashedNormal
- Returns:
  A torch.distributions.Distribution that at least has log_prob() and sample(), and would be expected to have rsample() if self.actions_differentiable is True.
- property input_dimension: int#
Required input dimension for this policy distribution.
This is used to determine the output of the actor network. How many values does the actor network need to produce to parameterize this policy distribution?
- property is_discrete: bool#
Whether the action space is discrete.
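A minimal usage sketch against the interface documented above; the feature width (128), batch size, and action dimension are arbitrary placeholders.

```python
import torch
from amago.nets.policy_dists import TanhGaussian

head = TanhGaussian(d_action=6)
actor = torch.nn.Linear(128, head.input_dimension)  # actor output layer

features = torch.randn(32, 128)
dist = head(actor(features))      # _SquashedNormal over [-1, 1]
action = dist.rsample()           # differentiable sample (actions_differentiable is True)
log_prob = dist.log_prob(action)  # e.g. for entropy-regularized actor losses
```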