amago.nets.policy_dists#

Stochastic policy output distributions.

Functions

softplus_bounded_positive(x, low[, high])

Map network output activations to a positive value (to parameterize a distribution).

tanh_bounded_positive(x, low, high)

Map network output activations to a positive value (to parameterize a distribution).

Classes

Beta(d_action[, alpha_low, alpha_high, ...])

Generates a Beta distribution rescaled to [-1, 1].

Discrete(d_action[, clip_prob_low, ...])

Generates a categorical distribution over actions.

DiscreteLikeContinuous(categorical[, ...])

Wrapper around Categorical used by MultiTaskAgent.

GMM(d_action[, gmm_modes, std_low, ...])

Generates a Gaussian Mixture Model with a tanh transform.

Multibinary(d_action)

Multi-binary action space support.

PolicyOutput(d_action)

Abstract base class for mapping network outputs to a distribution over actions.

TanhGaussian(d_action[, std_low, std_high, ...])

Generates a multivariate normal with a tanh transform to sample in [-1, 1].

class Beta(d_action, alpha_low=1.0001, alpha_high=None, beta_low=1.0001, beta_high=None, std_activation=<function softplus_bounded_positive>, clip_actions_on_log_prob=(-0.99, 0.99))[source]#

Bases: PolicyOutput

Generates a Beta distribution rescaled to [-1, 1].

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

(https://proceedings.mlr.press/v70/chou17a/chou17a.pdf)

Parameters:
  • d_action (int) – Dimension of the action space.

  • alpha_low (float) – Minimum value of alpha. Default is 1.0001.

  • alpha_high (float | None) – Maximum value of alpha. Default is None.

  • beta_low (float) – Minimum value of beta. Default is 1.0001.

  • beta_high (float | None) – Maximum value of beta. Default is None.

  • std_activation (Callable[[Tensor, float, float], Tensor]) – Activation function to produce a valid standard deviation from the raw network output.

  • clip_actions_on_log_prob (tuple[float, float]) – Tuple of floats that clips the actions before computing dist.log_prob(action). Addresses numerical stability issues when computing log_probs at the boundary of the action space.

Note

Setting alpha_low > 1 and beta_low > 1 keeps the distribution unimodal. (https://mathlets.org/mathlets/beta-distribution/)
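
A minimal usage sketch (the action dimension, batch shape, and variable names below are illustrative assumptions, not part of the library):

    import torch
    from amago.nets.policy_dists import Beta

    policy_dist = Beta(d_action=6)  # hypothetical 6-dim continuous action space
    vec = torch.randn(32, policy_dist.input_dimension)  # stand-in for actor network output
    dist = policy_dist(vec)  # forward == __call__
    actions = dist.sample()  # samples rescaled to [-1, 1]
    log_probs = dist.log_prob(actions)  # actions are clipped near the boundary for stability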

property actions_differentiable: bool#

Does the output distribution have rsample?

Used to answer the question: “can we optimize -Q(s, a ~ pi) as an actor loss?”

forward(vec, log_dict=None)[source]#

Maps the output of the actor network to a distribution over actions.

Parameters:
  • vec (Tensor) – Output of the actor network

  • log_dict (dict | None) – If None, this is not a log step and any log value computation can be skipped. If provided, any data added to the dict will be automatically logged. Defaults to None.

Return type:

_ShiftedBeta

Returns:

A torch.distributions.Distribution that at least has a log_prob() and sample(), and would be expected to have rsample() if self.actions_differentiable is True.

property input_dimension: int#

Required input dimension for this policy distribution.

This is used to determine the output size of the actor network: how many values does it need to produce to parameterize this policy distribution?

property is_discrete: bool#

Whether the action space is discrete.

class Discrete(d_action, clip_prob_low=0.001, clip_prob_high=0.99)[source]#

Bases: PolicyOutput

Generates a categorical distribution over actions.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Returns a thin wrapper around torch's Categorical that unsqueezes the last dimension of sample() actions to size 1.

Parameters:
  • d_action (int) – Dimension of the action space.

  • clip_prob_low (float) – Minimum action probability; probabilities below this value are clipped up to it before renormalizing. Default is 0.001.

  • clip_prob_high (float) – Maximum action probability; probabilities above this value are clipped down to it before renormalizing. Default is 0.99, which is kept for backwards compatibility but is now considered too conservative; 0.999 or 1.0 is fine.
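
A minimal usage sketch (the action count and batch shape are illustrative assumptions):

    import torch
    from amago.nets.policy_dists import Discrete

    policy_dist = Discrete(d_action=4)  # hypothetical 4-action discrete space
    vec = torch.randn(32, policy_dist.input_dimension)  # stand-in for actor network logits
    dist = policy_dist(vec)
    actions = dist.sample()  # sampled indices with a trailing dimension of size 1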

property actions_differentiable: bool#

Does the output distribution have rsample?

Used to answer the question: “can we optimize -Q(s, a ~ pi) as an actor loss?”

forward(vec, log_dict=None)[source]#

Maps the output of the actor network to a distribution over actions.

Parameters:
  • vec (Tensor) – Output of the actor network

  • log_dict (dict | None) – If None, this is not a log step and any log value computation can be skipped. If provided, any data added to the dict will be automatically logged. Defaults to None.

Return type:

_Categorical

Returns:

A torch.distributions.Distribution that at least has a log_prob() and sample(), and would be expected to have rsample() if self.actions_differentiable is True.

property input_dimension: int#

Required input dimension for this policy distribution.

This is used to determine the output size of the actor network: how many values does it need to produce to parameterize this policy distribution?

property is_discrete: bool#

Whether the action space is discrete.

class DiscreteLikeContinuous(categorical, gumbel_softmax_temperature=0.5)[source]#

Bases: object

Wrapper around Categorical used by MultiTaskAgent.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Lets us use discrete actions in a continuous actor-critic setup where the critic takes action vectors as input and outputs a scalar value.

Converts the Categorical to a OneHotCategorical and implements rsample() via F.gumbel_softmax with straight-through hard sampling.

Parameters:
  • categorical (Categorical | _Categorical) – A Categorical distribution.

  • gumbel_softmax_temperature (float) – Temperature for the Gumbel-Softmax trick that enables rsample(). Default is 0.5.
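
A minimal usage sketch (the logits and batch shape are illustrative assumptions):

    import torch
    from torch.distributions import Categorical
    from amago.nets.policy_dists import DiscreteLikeContinuous

    base = Categorical(logits=torch.randn(32, 4))  # hypothetical batch of 4-way action logits
    dist = DiscreteLikeContinuous(base)
    one_hot = dist.rsample()  # straight-through Gumbel-Softmax sample: hard one-hot forward,
                              # differentiable surrogate gradient backward
    log_probs = dist.log_prob(one_hot)  # log-probability of the one-hot action (per the
                                        # OneHotCategorical conversion described above)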

entropy()[source]#
Return type:

Tensor

log_prob(action)[source]#
Return type:

Tensor

property logits: Tensor#
property probs: Tensor#
rsample()[source]#
Return type:

Tensor

sample(*args, **kwargs)[source]#
Return type:

Tensor

class GMM(d_action, gmm_modes=5, std_low=0.0001, std_high=None, std_activation=<function softplus_bounded_positive>)[source]#

Bases: PolicyOutput

Generates a Gaussian Mixture Model with a tanh transform.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

A more expressive policy than TanhGaussian, but its output does not support rsample() or the DPG -Q(s, a ~ pi) actor loss. Often used in offline or imitation learning (IL) settings. Heavily based on robomimic's robot IL setup.

Parameters:
  • d_action (int) – Dimension of the action space.

  • gmm_modes (int) – Number of modes in the GMM. Default is 5.

  • std_low (float) – Minimum standard deviation. Default is 1e-4.

  • std_high (float | None) – Maximum standard deviation. Default is None.

  • std_activation (Callable[[Tensor, float, float], Tensor]) – Activation function to produce a std from the raw network output.
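
A minimal usage sketch (the action dimension and batch shape are illustrative assumptions):

    import torch
    from amago.nets.policy_dists import GMM

    policy_dist = GMM(d_action=7, gmm_modes=5)  # e.g., a hypothetical 7-DoF manipulation task
    vec = torch.randn(32, policy_dist.input_dimension)  # stand-in for actor network output
    dist = policy_dist(vec)
    actions = dist.sample()  # tanh-squashed samples; rsample() is not supported
    log_probs = dist.log_prob(actions)  # suitable for likelihood-based (offline/IL) losses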

property actions_differentiable: bool#

Does the output distribution have rsample?

Used to answer the question: “can we optimize -Q(s, a ~ pi) as an actor loss?”

forward(vec, log_dict=None)[source]#

Maps the output of the actor network to a distribution over actions.

Parameters:
  • vec (Tensor) – Output of the actor network

  • log_dict (dict | None) – If None, this is not a log step and any log value computation can be skipped. If provided, any data added to the dict will be automatically logged. Defaults to None.

Return type:

_TanhWrappedDistribution

Returns:

A torch.distributions.Distribution that at least has a log_prob() and sample(), and would be expected to have rsample() if self.actions_differentiable is True.

property input_dimension: int#

Required input dimension for this policy distribution.

This is used to determine the output size of the actor network: how many values does it need to produce to parameterize this policy distribution?

property is_discrete: bool#

Whether the action space is discrete.

class Multibinary(d_action)[source]#

Bases: PolicyOutput

Multi-binary action space support.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.
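
A minimal usage sketch (the number of binary dimensions and batch shape are illustrative assumptions):

    import torch
    from amago.nets.policy_dists import Multibinary

    policy_dist = Multibinary(d_action=3)  # hypothetical: three independent on/off switches
    vec = torch.randn(32, policy_dist.input_dimension)  # stand-in for actor network output
    dist = policy_dist(vec)  # a torch.distributions.Bernoulli
    actions = dist.sample()  # 0/1 entries per action dimension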

property actions_differentiable: bool#

Does the output distribution have rsample?

Used to answer the question: “can we optimize -Q(s, a ~ pi) as an actor loss?”

forward(vec, log_dict=None)[source]#

Maps the output of the actor network to a distribution over actions.

Parameters:
  • vec (Tensor) – Output of the actor network

  • log_dict (dict | None) – If None, this is not a log step and any log value computation can be skipped. If provided, any data added to the dict will be automatically logged. Defaults to None.

Return type:

Bernoulli

Returns:

A torch.distributions.Distribution that at least has a log_prob() and sample(), and would be expected to have rsample() if self.actions_differentiable is True.

property input_dimension: int#

Required input dimension for this policy distribution.

This is used to determine the output size of the actor network: how many values does it need to produce to parameterize this policy distribution?

property is_discrete: bool#

Whether the action space is discrete.

class PolicyOutput(d_action)[source]#

Bases: ABC

Abstract base class for mapping network outputs to a distribution over actions.

Actor networks use a PolicyOutput to produce a distribution over actions that is compatible with the Agent’s loss function.

Pretends to be a torch.nn.Module (forward == __call__) but is not one: it has no parameters and can be swapped without breaking checkpoints.

Parameters:

d_action (int) – Dimension of the action space.
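
A minimal sketch of a custom subclass (an unsquashed diagonal Gaussian), shown only to illustrate the abstract interface; the class name, head layout, and epsilon are assumptions, not part of the library:

    import torch
    import torch.distributions as pyd
    import torch.nn.functional as F
    from amago.nets.policy_dists import PolicyOutput

    class DiagonalGaussian(PolicyOutput):
        """Hypothetical example: unsquashed diagonal Gaussian policy head."""

        def __init__(self, d_action: int):
            super().__init__(d_action)
            self._d_action = d_action

        @property
        def input_dimension(self) -> int:
            return 2 * self._d_action  # one mean and one raw std per action dim

        @property
        def is_discrete(self) -> bool:
            return False

        @property
        def actions_differentiable(self) -> bool:
            return True  # Normal supports rsample()

        def forward(self, vec, log_dict=None):
            mean, raw_std = vec.chunk(2, dim=-1)
            std = F.softplus(raw_std) + 1e-4  # keep the std strictly positive
            return pyd.Independent(pyd.Normal(mean, std), 1)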

abstract property actions_differentiable: bool#

Does the output distribution have rsample?

Used to answer the question: “can we optimize -Q(s, a ~ pi) as an actor loss?”

abstract forward(vec, log_dict=None)[source]#

Maps the output of the actor network to a distribution over actions.

Parameters:
  • vec (Tensor) – Output of the actor network

  • log_dict (dict | None) – If None, this is not a log step and any log value computation can be skipped. If provided, any data added to the dict will be automatically logged. Defaults to None.

Return type:

Distribution

Returns:

A torch.distributions.Distribution that at least has a log_prob() and sample(), and would be expected to have rsample() if self.actions_differentiable is True.

abstract property input_dimension: int#

Required input dimension for this policy distribution.

This is used to determine the output size of the actor network: how many values does it need to produce to parameterize this policy distribution?

abstract property is_discrete: bool#

Whether the action space is discrete.

class TanhGaussian(d_action, std_low=0.006737946999085467, std_high=7.38905609893065, std_activation=<function tanh_bounded_positive>, clip_actions_on_log_prob=(-0.99, 0.99))[source]#

Bases: PolicyOutput

Generates a multivariate normal with a tanh transform to sample in [-1, 1].

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Parameters:
  • d_action (int) – Dimension of the action space.

  • std_low (float) – Minimum standard deviation. Default is exp(-5.0).

  • std_high (float) – Maximum standard deviation. Default is exp(2.0).

  • std_activation (Callable[[Tensor, float, float], Tensor]) – Activation function to produce a valid standard deviation from the raw network output.

  • clip_actions_on_log_prob (tuple[float, float]) – Tuple of floats that clips the actions before computing dist.log_prob(action). Addresses numerical stability issues when computing log_probs at or near the saturation points of the tanh transform. Default is (-0.99, 0.99).
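
A minimal usage sketch (the action dimension and batch shape are illustrative assumptions):

    import torch
    from amago.nets.policy_dists import TanhGaussian

    policy_dist = TanhGaussian(d_action=6)  # hypothetical 6-dim continuous action space
    vec = torch.randn(32, policy_dist.input_dimension)  # stand-in for actor network output
    dist = policy_dist(vec)
    actions = dist.rsample()  # reparameterized samples in [-1, 1], so -Q(s, a ~ pi) losses work
    log_probs = dist.log_prob(actions)  # actions are clipped to (-0.99, 0.99) for stability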

property actions_differentiable: bool#

Does the output distribution have rsample?

Used to answer the question: “can we optimize -Q(s, a ~ pi) as an actor loss?”

forward(vec, log_dict=None)[source]#

Maps the output of the actor network to a distribution over actions.

Parameters:
  • vec (Tensor) – Output of the actor network

  • log_dict (dict | None) – If None, this is not a log step and any log value computation can be skipped. If provided, any data added to the dict will be automatically logged. Defaults to None.

Return type:

_SquashedNormal

Returns:

A torch.distributions.Distribution that at least has a log_prob() and sample(), and would be expected to have rsample() if self.actions_differentiable is True.

property input_dimension: int#

Required input dimension for this policy distribution.

This is used to determine the output size of the actor network: how many values does it need to produce to parameterize this policy distribution?

property is_discrete: bool#

Whether the action space is discrete.

softplus_bounded_positive(x, low, high=None)[source]#

Map network output activations to a positive value (to parameterize a distribution).

Uses softplus to output positive values > low (required) and < high (if provided).

Return type:

Tensor

tanh_bounded_positive(x, low, high)[source]#

Map network output activations to a positive value (to parameterize a distribution).

Uses tanh to scale values between low and high.

Return type:

Tensor
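
A small usage sketch of both helpers (the tensor shape and bounds below are illustrative assumptions):

    import torch
    from amago.nets.policy_dists import softplus_bounded_positive, tanh_bounded_positive

    raw = torch.randn(32, 8)  # stand-in for the raw "std" head of an actor network
    std_soft = softplus_bounded_positive(raw, low=1e-4)  # strictly greater than 1e-4, unbounded above
    std_tanh = tanh_bounded_positive(raw, low=1e-4, high=2.0)  # squashed into the (1e-4, 2.0) range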