amago.nets.actor_critic

Actor and critic output modules.

Functions

gammas_as_input_seq(gammas, batch_size, length)

Classes

Actor(state_dim, action_dim, discrete, gammas) – Actor output head architecture.

BaseActorHead(state_dim, action_dim, ...) – Abstract base class for actor output heads.

BaseCriticHead(state_dim, action_dim, ...) – Abstract base class for critic output heads.

NCritics(state_dim, action_dim, discrete, ...) – Critic output head architecture.

NCriticsTwoHot(state_dim, action_dim, ...[, ...]) – Critic output head architecture with two-hot categorical value estimates.

PopArtLayer(gammas[, beta, init_nu, enabled]) – PopArt value normalization.

ResidualActor(state_dim, action_dim, ...[, ...]) – Actor output head with residual blocks.

class Actor(state_dim, action_dim, discrete, gammas, n_layers=2, d_hidden=256, activation='leaky_relu', dropout_p=0.0, continuous_dist_type=TanhGaussian)

Bases: BaseActorHead

Actor output head architecture.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

A (small) MLP that maps the output of the TrajEncoder to a distribution over actions.

Parameters:
  • state_dim (int) – Dimension of the “state” space (which is the output of the TrajEncoder)

  • action_dim (int) – Dimension of the action space

  • discrete (bool) – Whether the action space is discrete

  • gammas (Tensor) – Tensor of gamma values to use for the multi-gamma actor

  • n_layers (int) – Number of layers in the MLP. Defaults to 2.

  • d_hidden (int) – Dimension of hidden layers in the MLP. Defaults to 256.

  • activation (str) – Activation function to use in the MLP. Defaults to “leaky_relu”.

  • dropout_p (float) – Dropout rate to use in the MLP. Defaults to 0.0.

  • continuous_dist_type (Type[PolicyOutput]) – Type of continuous distribution to use if applicable. Must be a PolicyOutput. Defaults to TanhGaussian.

actor_network_forward(state, log_dict=None)
Return type:

Tensor
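
A minimal usage sketch (the shapes follow the forward() documentation under BaseActorHead below; the dimensions and gamma values are illustrative):

    import torch
    from amago.nets.actor_critic import Actor

    B, L, state_dim, action_dim = 8, 32, 128, 6
    gammas = torch.tensor([0.9, 0.99, 0.999])

    actor = Actor(state_dim=state_dim, action_dim=action_dim,
                  discrete=False, gammas=gammas)
    state = torch.randn(B, L, state_dim)  # output of the TrajEncoder
    dist = actor(state)                   # a torch Distribution
    actions = dist.sample()               # (B, L, len(gammas), action_dim)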

class BaseActorHead(state_dim, action_dim, discrete, gammas, continuous_dist_type)

Bases: Module, ABC

abstract actor_network_forward(state, log_dict=None)
Return type:

Tensor

forward(state, log_dict=None)

Compute an action distribution from a state representation.

Parameters:

state (Tensor) – The “state” sequence (the output of the TrajEncoder), with shape (Batch, Length, state_dim)

Return type:

Distribution

Returns:

The action distribution. Its type varies according to the PolicyOutput (e.g., Discrete or TanhGaussian), but it is always a PyTorch distribution (e.g., Categorical) whose sampled actions have shape (Batch, Length, Gammas, action_dim).

class BaseCriticHead(state_dim, action_dim, discrete, gammas, num_critics)

Bases: Module, ABC

abstract critic_network_forward(state, action, log_dict=None)
Return type:

Tensor

forward(state, action, log_dict=None)

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class NCritics(state_dim, action_dim, discrete, gammas, num_critics, d_hidden=256, n_layers=2, dropout_p=0.0, activation='leaky_relu')

Bases: BaseCriticHead

Critic output head architecture.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

A (small) ensemble of MLPs that maps the state and action to a value estimate.

Parameters:
  • state_dim (int) – Dimension of the “state” space (which is the output of the TrajEncoder)

  • action_dim (int) – Dimension of the action space

  • discrete (bool) – Whether the action space is discrete

  • gammas (Tensor) – Tensor of gamma values to use for the multi-gamma critic

  • num_critics (int) – Number of critics in the ensemble. Defaults to 4.

  • d_hidden (int) – Dimension of hidden layers in the MLP. Defaults to 256.

  • n_layers (int) – Number of layers in the MLP. Defaults to 2.

  • dropout_p (float) – Dropout rate to use in the MLP. Defaults to 0.0.

  • activation (str) – Activation function to use in the MLP. Defaults to “leaky_relu”.

critic_network_forward(state, action, log_dict=None)

Compute a value estimate from a state and action.

Parameters:
  • state (Tensor) – The “state” sequence (the output of the TrajEncoder). Has shape (Batch, Length, state_dim).

  • action (Tensor) – The action sequence. Has shape (K, Batch, Length, Gammas, action_dim), where K is a dimension denoting multiple action samples from the same state (can be 1, but must exist). Discrete actions are expected to be one-hot vectors.

Return type:

Tensor

Returns:

The value estimate with shape (K, Batch, Length, num_critics, Gammas, 1).
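
A minimal usage sketch of the shapes above (dimensions are illustrative):

    import torch
    from amago.nets.actor_critic import NCritics

    K, B, L, G = 1, 8, 32, 3               # K: action samples per state
    state_dim, action_dim = 128, 6
    gammas = torch.tensor([0.9, 0.99, 0.999])

    critic = NCritics(state_dim=state_dim, action_dim=action_dim,
                      discrete=False, gammas=gammas, num_critics=4)
    state = torch.randn(B, L, state_dim)
    action = torch.randn(K, B, L, G, action_dim)
    values = critic(state, action)          # (K, B, L, 4, G, 1)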

class NCriticsTwoHot(state_dim, action_dim, gammas, discrete, num_critics, d_hidden=256, n_layers=2, dropout_p=0.0, activation='leaky_relu', min_return=None, max_return=None, output_bins=128, use_symlog=True)

Bases: BaseCriticHead

Critic output head architecture.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

A (small) ensemble of MLPs that maps the state and action to a value estimate in the form of a categorical distribution over bins.

Parameters:
  • state_dim (int) – Dimension of the “state” space (which is the output of the TrajEncoder)

  • action_dim (int) – Dimension of the action space

  • gammas (Tensor) – Tensor of gamma values to use for the multi-gamma critic

  • discrete (bool) – Whether the action space is discrete

  • num_critics (int) – Number of critics in the ensemble. Defaults to 4.

  • d_hidden (int) – Dimension of hidden layers in the MLP. Defaults to 256.

  • n_layers (int) – Number of layers in the MLP. Defaults to 2.

  • dropout_p (float) – Dropout rate to use in the MLP. Defaults to 0.0.

  • activation (str) – Activation function to use in the MLP. Defaults to “leaky_relu”.

  • min_return (float | None) – Minimum return value. If not set, defaults to a very negative value (-100_000).

  • max_return (float | None) – Maximum return value. If not set, defaults to a very positive value (100_000).

  • output_bins (int) – Number of bins in the categorical distribution. Defaults to 128.

  • use_symlog (bool) – Whether to use a symlog transformation on the value estimates. Defaults to True.

Note

The default bin settings (wide range, many bins, symlog transformation) follow Dreamer-V3 in picking a range that does not demand domain-specific tuning. It may be more sample efficient to use tighter bounds, in which case the unintuitive bin spacing created by use_symlog can be turned off. However, note that min_return and max_return do not compensate for Agent.reward_multiplier. For example, if the highest possible return in an env is 1 and the reward multiplier is 10, then a tuned max_return should be 10, never 1. More discussion of bin settings can be found in AMAGO-2 Appendix A.
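
For example, tighter bounds could be set through gin. A hypothetical set of bindings (the values are illustrative, not recommendations, and the binding keys assume the class is registered under its bare name):

    import gin

    # hypothetical bindings; match the env's (reward-multiplied) return range
    gin.bind_parameter("NCriticsTwoHot.min_return", -10.0)
    gin.bind_parameter("NCriticsTwoHot.max_return", 10.0)
    gin.bind_parameter("NCriticsTwoHot.use_symlog", False)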

bin_dist_to_raw_vals(bin_dist)

Convert a categorical distribution over bins to a scalar.

Parameters:

bin_dist (Categorical) – The categorical distribution over bins (output of forward).

Return type:

Tensor

Returns:

The scalar value.
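
A sketch of this direction of the conversion, assuming the standard Dreamer-V3 recipe (expected value over bin centers, then the inverse symlog transform); the names here are illustrative, not the library's exact code:

    import torch

    def symexp(x):
        # inverse of symlog(x) = sign(x) * log(1 + |x|)
        return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)

    def bins_to_scalar(probs, bin_centers, use_symlog=True):
        # expectation of the categorical distribution over bin centers...
        val = (probs * bin_centers).sum(-1)
        # ...mapped back to the raw return scale if symlog was used
        return symexp(val) if use_symlog else val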

critic_network_forward(state, action, log_dict=None)

Compute a categorical distribution over bins from a state and action.

Parameters:
  • state (Tensor) – The “state” sequence (the output of the TrajEncoder). Has shape (Batch, Length, state_dim).

  • action (Tensor) – The action sequence. Has shape (K, Batch, Length, Gammas, action_dim), where K is a dimension denoting multiple action samples from the same state (can be 1, but must exist). Discrete actions are expected to be one-hot vectors.

Return type:

Categorical

Returns:

The categorical distribution over bins with shape (K, Batch, Length, num_critics, output_bins).

raw_vals_to_labels(raw_td_target)

Convert a scalar to a categorical distribution over bins.

A PyTorch port of the dreamerv3/jaxutils.py implementation.

Parameters:

raw_td_target (Tensor) – The scalar value.

Return type:

Tensor

Returns:

A two-hot encoded categorical distribution over bins.
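
A standalone sketch of two-hot encoding in the style of dreamerv3/jaxutils.py (illustrative, not the library's exact code): a scalar that falls between two adjacent bin centers is encoded as probability mass on those two bins, weighted by proximity, so that the expectation over bin centers recovers the scalar.

    import torch

    def two_hot(x, bin_centers):
        # x: (...,) scalar targets (already symlog-transformed if use_symlog)
        # bin_centers: sorted 1D tensor of shape (num_bins,)
        x = x.clamp(bin_centers[0].item(), bin_centers[-1].item())
        above = torch.searchsorted(bin_centers, x).clamp(max=len(bin_centers) - 1)
        below = (above - 1).clamp(min=0)
        lo, hi = bin_centers[below], bin_centers[above]
        w_above = (x - lo) / (hi - lo).clamp(min=1e-8)  # weight on upper bin
        w_below = 1.0 - w_above                         # weight on lower bin
        labels = torch.zeros(*x.shape, len(bin_centers), dtype=x.dtype)
        labels.scatter_add_(-1, below.unsqueeze(-1), w_below.unsqueeze(-1))
        labels.scatter_add_(-1, above.unsqueeze(-1), w_above.unsqueeze(-1))
        return labels  # each row sums to 1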

class PopArtLayer(gammas, beta=0.0005, init_nu=100.0, enabled=True)

Bases: Module

PopArt value normalization.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Normalizes value estimates according to moving-average statistics and rescales the critic's outputs to compensate for the resulting distribution shift (https://arxiv.org/abs/1809.04474).

Parameters:
  • gammas (int) – Number of gamma values in the critic.

  • beta (float) – The beta parameter for the moving average. Defaults to 5e-4.

  • init_nu (float) – The initial nu parameter. Defaults to 100.0 following a recommendation in the PopArt paper.

  • enabled (bool) – If False, this layer is a no-op. Defaults to True. Cannot be configured by gin; use Agent.use_popart instead.

forward(x, normalized=True)

Modify the value estimate according to the PopArt layer.

Applies normalization or denormalization to value estimates using PopArt’s moving average statistics. When normalized=True, scales and shifts values using the current statistics to normalize them. When normalized=False, maps normalized values back to the original scale of the environment.

Parameters:
  • x (Tensor) – Value estimate to modify

  • normalized (bool) – Whether to normalize (True) or denormalize (False) the values

Return type:

Tensor

Returns:

Modified value estimate in either normalized or denormalized form
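
Both directions amount to the standard PopArt affine transform. A sketch using the layer's running mean mu and scale sigma (names illustrative, not the exact amago code):

    def popart_forward(x, mu, sigma, normalized=True):
        # normalized=True:  map raw values into the normalized critic space
        # normalized=False: map normalized values back to the env return scale
        return (x - mu) / sigma if normalized else x * sigma + mu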

normalize_values(val)

Get normalized (Q) values.

Return type:

Tensor

property sigma: Tensor

to(device)

Move to another torch device.

update_stats(val, mask)

Update the moving average statistics.

Parameters:
  • val (Tensor) – The value estimate.

  • mask (Tensor) – A mask that is 0 where value estimates should be ignored (e.g., from padded timesteps).

Return type:

None
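
The statistics follow PopArt's exponential moving averages of the first and second moments, from which sigma is derived. A sketch per the paper (the exact masked, per-gamma bookkeeping in amago may differ):

    import torch

    def update_stats_sketch(mu, nu, val, mask, beta=5e-4):
        # EMA of first/second moments over unmasked entries (PopArt paper);
        # amago presumably keeps separate stats per gamma, this uses scalars.
        m = mask.bool()
        mu = (1 - beta) * mu + beta * val[m].mean()
        nu = (1 - beta) * nu + beta * (val[m] ** 2).mean()
        sigma = (nu - mu ** 2).clamp(min=1e-6).sqrt()
        return mu, nu, sigma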

class ResidualActor(state_dim, action_dim, discrete, gammas, feature_dim=256, residual_ff_dim=512, residual_blocks=2, activation='leaky_relu', normalization='layer', dropout_p=0.0, continuous_dist_type=TanhGaussian)

Bases: BaseActorHead

Actor output head with residual blocks.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Based on BRO (https://arxiv.org/pdf/2405.16158v1), which recommends hyperparameters similar to our existing defaults.

Parameters:
  • state_dim (int) – Dimension of the “state” space (which is the output of the TrajEncoder)

  • action_dim (int) – Dimension of the action space

  • discrete (bool) – Whether the action space is discrete

  • gammas (Tensor) – Tensor of gamma values to use for the multi-gamma actor

  • feature_dim (int) – Dimension of the embedding between residual blocks (analogous to d_model in a Transformer). Defaults to 256.

  • residual_ff_dim (int) – Hidden dimension of residual blocks (analogous to d_ff in a Transformer). Defaults to 512.

  • residual_blocks (int) – Number of residual blocks. Defaults to 2.

  • activation (str) – Activation function to use in the MLPs. Defaults to “leaky_relu”.

  • normalization (str) – Normalization to use in the residual blocks. Defaults to “layer” (LayerNorm).

  • dropout_p (float) – Dropout rate to use in the initial linear layers. Defaults to 0.0.

  • continuous_dist_type (Type[PolicyOutput]) – Type of continuous distribution to use if applicable. Must be a PolicyOutput. Defaults to TanhGaussian.

actor_network_forward(state, log_dict=None)
Return type:

Tensor
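
A BRO-style residual block roughly follows the pattern below. This is a sketch under the assumption that the block matches the cited paper; it is not necessarily the exact amago implementation:

    import torch.nn as nn

    class ResidualBlockSketch(nn.Module):
        # illustrative BRO-style block (https://arxiv.org/pdf/2405.16158v1)
        def __init__(self, feature_dim=256, residual_ff_dim=512):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(feature_dim, residual_ff_dim),
                nn.LayerNorm(residual_ff_dim),
                nn.LeakyReLU(),
                nn.Linear(residual_ff_dim, feature_dim),
            )

        def forward(self, x):
            return x + self.net(x)  # residual connection around the MLP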

gammas_as_input_seq(gammas, batch_size, length)
Return type:

Tensor