amago.nets.actor_critic#

Actor and critic output modules.

Functions

gammas_as_input_seq(gammas, batch_size, length)

Tile gamma values into a per-timestep input sequence.

Classes

Actor(state_dim, action_dim, discrete, gammas)

Actor output head architecture.

NCritics(state_dim, action_dim, discrete, gammas)

Critic output head architecture.

NCriticsTwoHot(state_dim, action_dim, gammas)

Critic output head architecture.

PopArtLayer(gammas[, beta, init_nu, enabled])

PopArt value normalization.

class Actor(state_dim, action_dim, discrete, gammas, n_layers=2, d_hidden=256, activation='leaky_relu', dropout_p=0.0, continuous_dist_type=<class 'amago.nets.policy_dists.TanhGaussian'>)[source]#

Bases: Module

Actor output head architecture.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

A (small) MLP that maps the output of the TrajEncoder to a distribution over actions.

Parameters:
  • state_dim (int) – Dimension of the “state” space (which is the output of the TrajEncoder)

  • action_dim (int) – Dimension of the action space

  • discrete (bool) – Whether the action space is discrete

  • gammas (Tensor) – Tensor of gamma (discount) values to use for the multi-gamma actor

  • n_layers (int) – Number of layers in the MLP. Defaults to 2.

  • d_hidden (int) – Dimension of hidden layers in the MLP. Defaults to 256.

  • activation (str) – Activation function to use in the MLP. Defaults to “leaky_relu”.

  • dropout_p (float) – Dropout rate to use in the MLP. Defaults to 0.0.

  • continuous_dist_type (Type[PolicyOutput]) – Type of continuous action distribution to use, if applicable. Must be an amago.nets.policy_dists.PolicyOutput. Defaults to TanhGaussian.

forward(state, log_dict=None)[source]#

Compute an action distribution from a state representation.

Parameters:

state (Tensor) – The “state” sequence (the output of the TrajEncoder), with shape (Batch, Length, state_dim)

Return type:

Distribution

Returns:

The action distribution. The concrete type depends on the configured PolicyOutput (e.g., Discrete or TanhGaussian), but it is always a PyTorch distribution (e.g., Categorical) whose sampled actions have shape (Batch, Length, Gammas, action_dim).
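A minimal usage sketch (the dimensions, gamma values, and continuous action space below are illustrative assumptions, not library defaults):

    import torch
    from amago.nets.actor_critic import Actor

    B, L, state_dim, action_dim = 8, 32, 128, 6
    gammas = torch.tensor([0.9, 0.99, 0.999])  # multi-gamma training

    actor = Actor(state_dim=state_dim, action_dim=action_dim,
                  discrete=False, gammas=gammas)

    state = torch.randn(B, L, state_dim)  # output of the TrajEncoder
    dist = actor(state)      # a torch.distributions.Distribution
    actions = dist.sample()  # (B, L, len(gammas), action_dim)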

class NCritics(state_dim, action_dim, discrete, gammas, num_critics=4, d_hidden=256, n_layers=2, dropout_p=0.0, activation='leaky_relu')[source]#

Bases: Module

Critic output head architecture.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

A (small) ensemble of MLPs that maps the state and action to a value estimate.

Parameters:
  • state_dim (int) – Dimension of the “state” space (which is the output of the TrajEncoder)

  • action_dim (int) – Dimension of the action space

  • discrete (bool) – Whether the action space is discrete

  • gammas (Tensor) – Tensor of gamma (discount) values to use for the multi-gamma critic

  • num_critics (int) – Number of critics in the ensemble. Defaults to 4.

  • d_hidden (int) – Dimension of hidden layers in the MLP. Defaults to 256.

  • n_layers (int) – Number of layers in the MLP. Defaults to 2.

  • dropout_p (float) – Dropout rate to use in the MLP. Defaults to 0.0.

  • activation (str) – Activation function to use in the MLP. Defaults to “leaky_relu”.

forward(state, action, log_dict=None)[source]#

Compute a value estimate from a state and action.

Parameters:
  • state (Tensor) – The “state” sequence (the output of the TrajEncoder). Has shape (Batch, Length, state_dim).

  • action (Tensor) – The action sequence. Has shape (K, Batch, Length, Gammas, action_dim), where K is a dimension denoting multiple action samples from the same state (can be 1, but must exist). Discrete actions are expected to be one-hot vectors.

Return type:

Tensor

Returns:

The value estimate with shape (K, Batch, Length, num_critics, Gammas, 1).
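A minimal usage sketch of the ensemble, following the shapes documented above (all sizes are illustrative):

    import torch
    from amago.nets.actor_critic import NCritics

    B, L, K = 8, 32, 1  # K = action samples per state
    state_dim, action_dim = 128, 6
    gammas = torch.tensor([0.9, 0.99, 0.999])

    critics = NCritics(state_dim=state_dim, action_dim=action_dim,
                       discrete=False, gammas=gammas, num_critics=4)

    state = torch.randn(B, L, state_dim)
    # actions need the leading K dim and the per-gamma dim:
    action = torch.randn(K, B, L, len(gammas), action_dim)
    q = critics(state, action)  # (K, B, L, num_critics, len(gammas), 1)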

class NCriticsTwoHot(state_dim, action_dim, gammas, num_critics=4, d_hidden=256, n_layers=2, dropout_p=0.0, activation='leaky_relu', min_return=None, max_return=None, output_bins=128, use_symlog=True)[source]#

Bases: Module

Critic output head architecture.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

A (small) ensemble of MLPs that maps the state and action to a value estimate in the form of a categorical distribution over bins.

Parameters:
  • state_dim (int) – Dimension of the “state” space (which is the output of the TrajEncoder)

  • action_dim (int) – Dimension of the action space

  • gammas (Tensor) – Tensor of gamma (discount) values to use for the multi-gamma critic

  • num_critics (int) – Number of critics in the ensemble. Defaults to 4.

  • d_hidden (int) – Dimension of hidden layers in the MLP. Defaults to 256.

  • n_layers (int) – Number of layers in the MLP. Defaults to 2.

  • dropout_p (float) – Dropout rate to use in the MLP. Defaults to 0.0.

  • activation (str) – Activation function to use in the MLP. Defaults to “leaky_relu”.

  • min_return (float | None) – Minimum return value. If not set, defaults to a very negative value (-100_000).

  • max_return (float | None) – Maximum return value. If not set, defaults to a very positive value (100_000).

  • output_bins (int) – Number of bins in the categorical distribution. Defaults to 128.

  • use_symlog (bool) – Whether to use a symlog transformation on the value estimates. Defaults to True.

Note

The default bin settings (wide range, many bins, symlog transformation) follow Dreamer-V3 in picking a range that does not demand domain-specific tuning. It may be more sample-efficient to use tighter bounds, in which case use_symlog can be turned off to avoid its unintuitive bin spacing. However, note that min_return and max_return do not compensate for Agent.reward_multiplier. For example, if the highest possible return in an env is 1 and the reward multiplier is 10, then a tuned max_return might be 10, but should never be 1. See AMAGO-2 Appendix A for more discussion of bin settings.
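For reference, the symlog transform applied when use_symlog=True follows Dreamer-V3; a minimal sketch of the transform and its inverse:

    import torch

    def symlog(x: torch.Tensor) -> torch.Tensor:
        # compress large magnitudes symmetrically around zero
        return torch.sign(x) * torch.log(torch.abs(x) + 1.0)

    def symexp(x: torch.Tensor) -> torch.Tensor:
        # inverse of symlog; recovers the original scale
        return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)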

bin_dist_to_raw_vals(bin_dist)[source]#

Convert a categorical distribution over bins to a scalar.

Parameters:

bin_dist (Categorical) – The categorical distribution over bins (output of forward).

Return type:

Tensor

Returns:

The scalar value.

forward(state, action, log_dict=None)[source]#

Compute a categorical distribution over bins from a state and action.

Parameters:
  • state (Tensor) – The “state” sequence (the output of the TrajEncoder). Has shape (Batch, Length, state_dim).

  • action (Tensor) – The action sequence. Has shape (K, Batch, Length, Gammas, action_dim), where K is a dimension denoting multiple action samples from the same state (can be 1, but must exist). Discrete actions are expected to be one-hot vectors.

Return type:

Categorical

Returns:

The categorical distribution over bins with shape (K, Batch, Length, num_critics, output_bins).

raw_vals_to_labels(raw_td_target)[source]#

Convert a scalar to a categorical distribution over bins.

Just a torch port of the dreamerv3/jaxutils.py implementation.

Parameters:

raw_td_target (Tensor) – The scalar value.

Return type:

Tensor

Returns:

A two-hot encoded categorical distribution over bins.
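A self-contained sketch of the two-hot idea (a simplified stand-in for the library's implementation, which also applies symlog when enabled; the two_hot helper and its bin spacing are illustrative):

    import torch

    def two_hot(x: torch.Tensor, bins: torch.Tensor) -> torch.Tensor:
        # Split each scalar's probability mass between its two neighboring
        # bins, weighted so the expectation over bins recovers the scalar.
        x = x.clamp(bins[0], bins[-1])
        k = torch.searchsorted(bins, x.contiguous()).clamp(1, len(bins) - 1)
        lo, hi = bins[k - 1], bins[k]
        w_hi = (x - lo) / (hi - lo)
        labels = torch.zeros(*x.shape, len(bins))
        labels.scatter_(-1, (k - 1).unsqueeze(-1), (1 - w_hi).unsqueeze(-1))
        labels.scatter_(-1, k.unsqueeze(-1), w_hi.unsqueeze(-1))
        return labels

    bins = torch.linspace(-10.0, 10.0, 128)
    labels = two_hot(torch.tensor([0.3, -2.5]), bins)
    recovered = (labels * bins).sum(-1)  # ≈ tensor([0.3000, -2.5000])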

class PopArtLayer(gammas, beta=0.0005, init_nu=100.0, enabled=True)[source]#

Bases: Module

PopArt value normalization.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Normalizes value estimates according to a moving average, and rescales the critic's outputs to compensate for the distribution shift this creates. (https://arxiv.org/abs/1809.04474)

Parameters:
  • gammas (int) – Number of gamma values in the critic.

  • beta (float) – The beta parameter for the moving average. Defaults to 5e-4.

  • init_nu (float) – The initial nu parameter. Defaults to 100.0 following a recommendation in the PopArt paper.

  • enabled (bool) – If False, this layer is a no-op. Defaults to True. Cannot be configured by gin; use Agent.use_popart instead.

forward(x, normalized=True)[source]#

Modify the value estimate according to the PopArt layer.

Applies normalization or denormalization to value estimates using PopArt’s moving average statistics. When normalized=True, scales and shifts values using the current statistics to normalize them. When normalized=False, maps normalized values back to the original scale of the environment.

Parameters:
  • x (Tensor) – Value estimate to modify

  • normalized (bool) – Whether to normalize (True) or denormalize (False) the values

Return type:

Tensor

Returns:

Modified value estimate in either normalized or denormalized form
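A sketch of the (de)normalization arithmetic, assuming the standard PopArt statistics: a running mean mu and second moment nu with sigma = sqrt(nu - mu**2). Names and shapes here are illustrative:

    import torch

    mu = torch.zeros(3)           # one running mean per gamma
    nu = torch.full((3,), 100.0)  # second moment, init_nu = 100.0
    sigma = (nu - mu ** 2).clamp_min(1e-4).sqrt()

    q_raw = torch.randn(8, 3)      # unnormalized value estimates
    q_norm = (q_raw - mu) / sigma  # forward(x, normalized=True)
    q_back = q_norm * sigma + mu   # forward(x, normalized=False)
    assert torch.allclose(q_back, q_raw, atol=1e-5)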

normalize_values(val)[source]#

Get normalized (Q) values.

Return type:

Tensor

property sigma: Tensor#

The standard deviation used to (de)normalize values, derived from the moving-average statistics.

to(device)[source]#

Move to another torch device.

update_stats(val, mask)[source]#

Update the moving average statistics.

Parameters:
  • val (Tensor) – The value estimate.

  • mask (Tensor) – A mask that is 0 where value estimates should be ignored (e.g., from padded timesteps).

Return type:

None
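A sketch of a masked moving-average update in the standard PopArt style (the library's exact reduction dimensions may differ; beta, mu, and nu follow the parameters above):

    import torch

    beta = 5e-4
    mu, nu = torch.zeros(3), torch.full((3,), 100.0)

    val = torch.randn(8, 32, 3, 1)  # value targets, one per gamma
    mask = torch.ones_like(val)     # 0 where timesteps are padding
    denom = mask.sum(dim=(0, 1, 3)).clamp_min(1)
    batch_mu = (val * mask).sum(dim=(0, 1, 3)) / denom
    batch_nu = ((val ** 2) * mask).sum(dim=(0, 1, 3)) / denom

    mu = (1 - beta) * mu + beta * batch_mu  # EMA of the mean
    nu = (1 - beta) * nu + beta * batch_nu  # EMA of the second moment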

gammas_as_input_seq(gammas, batch_size, length)[source]#

Tile gamma values into a per-timestep input sequence for a batch of the given size and length.

Return type:

Tensor