amago.nets.actor_critic#
Actor and critic output modules.
Classes

- Actor: Actor output head architecture.
- NCritics: Critic output head architecture.
- NCriticsTwoHot: Critic output head architecture.
- PopArtLayer: PopArt value normalization.
- class Actor(state_dim, action_dim, discrete, gammas, n_layers=2, d_hidden=256, activation='leaky_relu', dropout_p=0.0, continuous_dist_type=<class 'amago.nets.policy_dists.TanhGaussian'>)[source]#
Bases: Module
Actor output head architecture.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
A (small) MLP that maps the output of the TrajEncoder to a distribution over actions.
- Parameters:
  - state_dim (int) – Dimension of the “state” space (which is the output of the TrajEncoder)
  - action_dim (int) – Dimension of the action space
  - discrete (bool) – Whether the action space is discrete
  - gammas (Tensor) – List of gamma values to use for the multi-gamma actor
  - n_layers (int) – Number of layers in the MLP. Defaults to 2.
  - d_hidden (int) – Dimension of hidden layers in the MLP. Defaults to 256.
  - activation (str) – Activation function to use in the MLP. Defaults to “leaky_relu”.
  - dropout_p (float) – Dropout rate to use in the MLP. Defaults to 0.0.
  - continuous_dist_type (Type[PolicyOutput]) – Type of continuous distribution to use if applicable. Must be an amago.nets.policy_dists.PolicyOutput. Defaults to TanhGaussian.
- forward(state, log_dict=None)[source]#
Compute an action distribution from a state representation.
- Parameters:
  - state (Tensor) – The “state” sequence (the output of the TrajEncoder) with shape (Batch, Length, state_dim)
- Return type:
  Distribution
- Returns:
  The action distribution. Type varies according to the output of PolicyOutput (e.g., Discrete or TanhGaussian). Always a pytorch distribution (e.g., Categorical) where sampled actions would have shape (Batch, Length, Gammas, action_dim).
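The shape convention above can be illustrated with a minimal standalone sketch (not amago's implementation): an MLP head whose output is reshaped to add a Gammas dimension, producing one-hot discrete action samples of shape (Batch, Length, Gammas, action_dim).

```python
import torch

# Minimal sketch of the multi-gamma actor shape convention (assumed dims for
# illustration; this is not the amago.nets.actor_critic.Actor code).
B, L, G, state_dim, action_dim = 8, 10, 3, 32, 4
mlp = torch.nn.Sequential(
    torch.nn.Linear(state_dim, 256),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(256, G * action_dim),
)
state = torch.randn(B, L, state_dim)              # (Batch, Length, state_dim)
logits = mlp(state).reshape(B, L, G, action_dim)  # one head per gamma
dist = torch.distributions.OneHotCategorical(logits=logits)
action = dist.sample()                            # (Batch, Length, Gammas, action_dim)
```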
- class NCritics(state_dim, action_dim, discrete, gammas, num_critics=4, d_hidden=256, n_layers=2, dropout_p=0.0, activation='leaky_relu')[source]#
Bases: Module
Critic output head architecture.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
A (small) ensemble of MLPs that maps the state and action to a value estimate.
- Parameters:
  - state_dim (int) – Dimension of the “state” space (which is the output of the TrajEncoder)
  - action_dim (int) – Dimension of the action space
  - discrete (bool) – Whether the action space is discrete
  - gammas (Tensor) – List of gamma values to use for the multi-gamma critic
  - num_critics (int) – Number of critics in the ensemble. Defaults to 4.
  - d_hidden (int) – Dimension of hidden layers in the MLP. Defaults to 256.
  - n_layers (int) – Number of layers in the MLP. Defaults to 2.
  - dropout_p (float) – Dropout rate to use in the MLP. Defaults to 0.0.
  - activation (str) – Activation function to use in the MLP. Defaults to “leaky_relu”.
- forward(state, action, log_dict=None)[source]#
Compute a value estimate from a state and action.
- Parameters:
  - state (Tensor) – The “state” sequence (the output of the TrajEncoder). Has shape (Batch, Length, state_dim).
  - action (Tensor) – The action sequence. Has shape (K, Batch, Length, Gammas, action_dim), where K is a dimension denoting multiple action samples from the same state (can be 1, but must exist). Discrete actions are expected to be one-hot vectors.
- Return type:
Tensor
- Returns:
The value estimate with shape (K, Batch, Length, num_critics, Gammas, 1).
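The input and output shapes above can be demonstrated with a toy ensemble (assumed dims, not amago's implementation): the state is broadcast across the K and Gammas dims, concatenated with the action, and evaluated by num_critics independent heads.

```python
import torch

# Toy sketch of the documented NCritics shapes (illustrative only).
K, B, L, G, state_dim, action_dim, num_critics = 2, 8, 10, 3, 32, 4, 4
state = torch.randn(B, L, state_dim)
action = torch.randn(K, B, L, G, action_dim)
# broadcast state across the K and Gammas dims, then concat with the action
state_exp = state.unsqueeze(0).unsqueeze(3).expand(K, B, L, G, state_dim)
inp = torch.cat([state_exp, action], dim=-1)
heads = [torch.nn.Linear(state_dim + action_dim, 1) for _ in range(num_critics)]
# stack ensemble members into a new dim after Length:
# (K, Batch, Length, num_critics, Gammas, 1)
values = torch.stack([h(inp) for h in heads], dim=3)
```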
- class NCriticsTwoHot(state_dim, action_dim, gammas, num_critics=4, d_hidden=256, n_layers=2, dropout_p=0.0, activation='leaky_relu', min_return=None, max_return=None, output_bins=128, use_symlog=True)[source]#
Bases: Module
Critic output head architecture.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
A (small) ensemble of MLPs that maps the state and action to a value estimate in the form of a categorical distribution over bins.
- Parameters:
  - state_dim (int) – Dimension of the “state” space (which is the output of the TrajEncoder)
  - action_dim (int) – Dimension of the action space
  - gammas (Tensor) – List of gamma values to use for the multi-gamma critic
  - num_critics (int) – Number of critics in the ensemble. Defaults to 4.
  - d_hidden (int) – Dimension of hidden layers in the MLP. Defaults to 256.
  - n_layers (int) – Number of layers in the MLP. Defaults to 2.
  - dropout_p (float) – Dropout rate to use in the MLP. Defaults to 0.0.
  - activation (str) – Activation function to use in the MLP. Defaults to “leaky_relu”.
  - min_return (float | None) – Minimum return value. If not set, defaults to a very negative value (-100_000).
  - max_return (float | None) – Maximum return value. If not set, defaults to a very positive value (100_000).
  - output_bins (int) – Number of bins in the categorical distribution. Defaults to 128.
  - use_symlog (bool) – Whether to use a symlog transformation on the value estimates. Defaults to True.
Note
The default bin settings (wide range, many bins, symlog transformation) follow Dreamer-V3 in picking a range that does not demand domain-specific tuning. It may be more sample-efficient to use tighter bounds, in which case the unintuitive bin spacing created by use_symlog can be turned off. However, note that min_return and max_return do not compensate for Agent.reward_multiplier. For example, if the highest possible return in an env is 1 and the reward multiplier is 10, then a tuned max_return might be 10 but should never be 1. More discussion on bin settings can be found in AMAGO-2 Appendix A.
- bin_dist_to_raw_vals(bin_dist)[source]#
Convert a categorical distribution over bins to a scalar.
- Parameters:
  - bin_dist (Categorical) – The categorical distribution over bins (output of forward).
- Return type:
  Tensor
- Returns:
  The scalar value.
- forward(state, action, log_dict=None)[source]#
Compute a categorical distribution over bins from a state and action.
- Parameters:
  - state (Tensor) – The “state” sequence (the output of the TrajEncoder). Has shape (Batch, Length, state_dim).
  - action (Tensor) – The action sequence. Has shape (K, Batch, Length, Gammas, action_dim), where K is a dimension denoting multiple action samples from the same state (can be 1, but must exist). Discrete actions are expected to be one-hot vectors.
- Return type:
Categorical
- Returns:
The categorical distribution over bins with shape (K, Batch, Length, num_critics, output_bins).
- raw_vals_to_labels(raw_td_target)[source]#
Convert a scalar to a categorical distribution over bins.
Just a torch port of the dreamerv3/jaxutils.py implementation.
- Parameters:
  - raw_td_target (Tensor) – The scalar value.
- Return type:
  Tensor
- Returns:
  A two-hot encoded categorical distribution over bins.
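The idea behind two-hot encoding can be sketched in a few lines (a simplified illustration, not amago's or DreamerV3's exact code): a scalar target is represented as probability mass split between its two nearest bins, so that the label's expectation recovers the original value.

```python
import torch

# Simplified two-hot encoding sketch; assumes the value lies strictly
# inside the bin range and ignores symlog for clarity.
bins = torch.linspace(-5.0, 5.0, steps=11)  # 11 bins spaced 1.0 apart
value = 2.3
idx = int(torch.searchsorted(bins, value))  # first bin >= value (here: 3.0)
lo, hi = bins[idx - 1], bins[idx]
w_hi = (value - lo) / (hi - lo)             # fraction of mass on the upper bin
label = torch.zeros(len(bins))
label[idx - 1], label[idx] = 1.0 - w_hi, w_hi
# label sums to 1, and its expectation (label * bins).sum() recovers ~2.3
```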
- class PopArtLayer(gammas, beta=0.0005, init_nu=100.0, enabled=True)[source]#
Bases: Module
PopArt value normalization.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
Shifts value estimates according to a moving average and helps the outputs of the critic compensate for the distribution shift (https://arxiv.org/abs/1809.04474).
- Parameters:
  - gammas (int) – Number of gamma values in the critic.
  - beta (float) – The beta parameter for the moving average. Defaults to 5e-4.
  - init_nu (float) – The initial nu parameter. Defaults to 100.0, following a recommendation in the PopArt paper.
  - enabled (no gin) – If False, this layer is a no-op. Defaults to True. Cannot be configured by gin; use Agent.use_popart instead.
- forward(x, normalized=True)[source]#
Modify the value estimate according to the PopArt layer.
Applies normalization or denormalization to value estimates using PopArt’s moving average statistics. When normalized=True, scales and shifts values using the current statistics to normalize them. When normalized=False, maps normalized values back to the original scale of the environment.
- Parameters:
  - x (Tensor) – Value estimate to modify
  - normalized (bool) – Whether to normalize (True) or denormalize (False) the values
- Return type:
Tensor
- Returns:
Modified value estimate in either normalized or denormalized form
- property sigma: Tensor#
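The normalize/denormalize round trip described in forward can be sketched with fixed statistics (mu and sigma here are assumed example values, not the layer's actual running state):

```python
import torch

# PopArt-style (de)normalization sketch, not the amago layer itself:
# the real layer tracks running statistics of the return distribution.
mu = torch.tensor(10.0)    # assumed running mean of returns
sigma = torch.tensor(4.0)  # assumed running scale of returns

def popart(x, normalized=True):
    # normalized=True: map raw values into normalized space;
    # normalized=False: map normalized critic outputs back to env scale.
    return (x - mu) / sigma if normalized else x * sigma + mu

raw = torch.tensor([6.0, 10.0, 18.0])
norm = popart(raw, normalized=True)          # [-1., 0., 2.]
round_trip = popart(norm, normalized=False)  # recovers raw
```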