amago.nets.traj_encoders#
Seq2Seq models for long-term memory.
Classes

| FFTrajEncoder | Feed-forward (memory-free) trajectory encoder. |
| GRUTrajEncoder | RNN (GRU) Trajectory Encoder. |
| MambaTrajEncoder | Mamba Trajectory Encoder. |
| TformerTrajEncoder | Transformer Trajectory Encoder. |
| TrajEncoder | Abstract base class for trajectory encoders. |
- class FFTrajEncoder(tstep_dim, max_seq_len, d_model=256, d_ff=None, n_layers=1, dropout=0.0, activation='leaky_relu', norm='layer')[source]#
Bases: TrajEncoder
Feed-forward (memory-free) trajectory encoder.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.

A useful tool for applying AMAGO to standard MDPs and benchmarking general RL details/hyperparameters on common benchmarks. The feed-forward architecture is designed to be close to an attention-less Transformer (residual blocks, norm, dropout, etc.). This makes it easy to create perfect 1:1 ablations of “memory vs. no memory” by changing only the TrajEncoder, without touching max_seq_len (which would have the side-effect of changing the effective batch size of actor-critic learning).
- Parameters:
  - tstep_dim – Dimension of the input timestep representation (last dim of the input sequence). Defined by the output of the TstepEncoder.
  - max_seq_len – Maximum sequence length of the model. Any inputs will have been trimmed to this length before reaching the TrajEncoder.
  - d_model (int) – Dimension of the main residual stream and output. 1:1 with how this would be defined in a Transformer. Defaults to 256.
  - d_ff (int | None) – Hidden dim of the feed-forward network in each residual block. 1:1 with how this would be defined in a Transformer. Defaults to 4 * d_model.
  - n_layers (int) – Number of residual feed-forward blocks. 1:1 with how this would be defined in a Transformer. Defaults to 1.
  - dropout (float) – Dropout rate. Equivalent to the dropout parameter of feed-forward blocks in a Transformer, but also applied to the first and last linear layers (inp –> d_model and d_model –> out). Defaults to 0.0.
  - activation (str) – Activation function. Defaults to “leaky_relu”.
  - norm (str) – Normalization function. Defaults to “layer” (LayerNorm).
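As a concrete illustration, the defaults above can be overridden with standard gin bindings. The sketch below is an assumption-laden example: it references the class by its bare registered name and omits the agent-construction code that would consume the bindings.

```python
import gin

# Hedged sketch: rebind FFTrajEncoder's @gin.configurable defaults
# before the encoder is constructed. Values are illustrative.
gin.parse_config(
    """
    FFTrajEncoder.d_model = 512
    FFTrajEncoder.n_layers = 2
    FFTrajEncoder.dropout = 0.1
    """
)
```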
- property emb_dim#
Defines the expected output dim of this model.
Used to infer the input dim of actor/critics.
- Returns:
The embedding dimension.
- Return type:
int
- forward(seq, time_idxs=None, hidden_state=None, log_dict=None)[source]#
Sequence model forward pass.
- Parameters:
  - seq – [Batch, Num Timesteps, TstepDim]. TstepDim is defined by the output of the TstepEncoder.
  - time_idxs – [Batch, Num Timesteps, 1]. A sequence of ints tying the input seq to the number of steps that have passed since the start of the episode. Can be used to compute position embeddings or other temporal features.
  - hidden_state – Architecture-specific hidden state. Defaults to None.
- Returns:
  A tuple containing:
  - output_seq: [Batch, Timestep, self.emb_dim]. Output of our seq2seq model.
  - new_hidden_state: Architecture-specific hidden state. Expected to be None if the input hidden_state is None. Otherwise, we assume we are at inference time and that this forward method has handled any updates to the hidden state that were needed.
- Return type:
  Tuple[torch.Tensor, Optional[Any]]
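A minimal forward-pass sketch based on the documented signatures (all dims are illustrative; at training time no hidden state is passed in, so none is returned):

```python
import torch
from amago.nets.traj_encoders import FFTrajEncoder

encoder = FFTrajEncoder(tstep_dim=64, max_seq_len=128)
seq = torch.randn(4, 128, 64)  # [Batch, Num Timesteps, TstepDim]
time_idxs = torch.arange(128).view(1, -1, 1).expand(4, -1, -1)  # [Batch, Num Timesteps, 1]

out, hidden = encoder(seq, time_idxs=time_idxs)
print(out.shape)       # expected: [4, 128, encoder.emb_dim]
assert hidden is None  # hidden_state was None on the way in
```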
- class GRUTrajEncoder(tstep_dim, max_seq_len, d_hidden=256, n_layers=2, d_output=256, norm='layer')[source]#
Bases: TrajEncoder
RNN (GRU) Trajectory Encoder.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
- Parameters:
  - tstep_dim (int) – Dimension of the input timestep representation (last dim of the input sequence). Defined by the output of the TstepEncoder.
  - max_seq_len (int) – Maximum sequence length of the model. Any inputs will have been trimmed to this length before reaching the TrajEncoder.
  - d_hidden (int) – Dimension of the hidden state of the GRU. Defaults to 256.
  - n_layers (int) – Number of layers in the GRU. Defaults to 2.
  - d_output (int) – Dimension of the output linear layer after the GRU. Defaults to 256.
  - norm (str) – Normalization applied after the final linear layer. Defaults to “layer” (LayerNorm).
- property emb_dim#
Defines the expected output dim of this model.
Used to infer the input dim of actor/critics.
- Returns:
The embedding dimension.
- Return type:
int
- forward(seq, time_idxs=None, hidden_state=None, log_dict=None)[source]#
Sequence model forward pass.
- Parameters:
  - seq – [Batch, Num Timesteps, TstepDim]. TstepDim is defined by the output of the TstepEncoder.
  - time_idxs – [Batch, Num Timesteps, 1]. A sequence of ints tying the input seq to the number of steps that have passed since the start of the episode. Can be used to compute position embeddings or other temporal features.
  - hidden_state – Architecture-specific hidden state. Defaults to None.
- Returns:
  A tuple containing:
  - output_seq: [Batch, Timestep, self.emb_dim]. Output of our seq2seq model.
  - new_hidden_state: Architecture-specific hidden state. Expected to be None if the input hidden_state is None. Otherwise, we assume we are at inference time and that this forward method has handled any updates to the hidden state that were needed.
- Return type:
  Tuple[torch.Tensor, Optional[Any]]
- reset_hidden_state(hidden_state, dones)[source]#
Hook to implement architecture-specific hidden state reset.
- Parameters:
  - hidden_state – We only expect to see hidden states that were created by self.init_hidden_state().
  - dones – A bool array of shape (num_parallel_envs,) where True indicates the agent loop has finished this episode and expects the hidden state for this batch index to be reset.
- Returns:
  Architecture-specific hidden state. Defaults to a no-op: new_hidden_state = hidden_state.
- Return type:
  Optional[Any]
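Together with init_hidden_state (inherited from TrajEncoder), this hook implies an inference loop along these lines. This is a hedged sketch: the per-step inputs and the dones flags are placeholders, and it assumes the GRU maintains a recurrent state at inference time.

```python
import numpy as np
import torch
from amago.nets.traj_encoders import GRUTrajEncoder

num_envs, tstep_dim = 8, 64
encoder = GRUTrajEncoder(tstep_dim=tstep_dim, max_seq_len=128)

hidden = encoder.init_hidden_state(batch_size=num_envs, device=torch.device("cpu"))
for t in range(10):
    seq_t = torch.randn(num_envs, 1, tstep_dim)  # placeholder TstepEncoder output
    time_t = torch.full((num_envs, 1, 1), t)     # steps since episode start
    out_t, hidden = encoder(seq_t, time_idxs=time_t, hidden_state=hidden)
    dones = np.zeros(num_envs, dtype=bool)       # placeholder episode terminations
    hidden = encoder.reset_hidden_state(hidden, dones)  # reset finished envs
```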
- class MambaTrajEncoder(tstep_dim, max_seq_len, d_model=256, d_state=16, d_conv=4, expand=2, n_layers=3, norm='layer')[source]#
Bases: TrajEncoder
Mamba Trajectory Encoder.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Implementation of the Mamba architecture from “Mamba: Linear-Time Sequence Modeling with Selective State Spaces” (https://arxiv.org/abs/2312.00752).
- Parameters:
  - tstep_dim (int) – Dimension of the input timestep representation (last dim of the input sequence). Defined by the output of the TstepEncoder.
  - max_seq_len (int) – Maximum sequence length of the model. The max context length of the model during training.
  - d_model (int) – Dimension of the main residual stream and output, analogous to the d_model of a Transformer. Defaults to 256.
  - d_state (int) – Dimension of the SSM in Mamba blocks. Defaults to 16.
  - d_conv (int) – Dimension of the convolution layer in Mamba blocks. Defaults to 4.
  - expand (int) – Expansion factor of the SSM. Defaults to 2.
  - n_layers (int) – Number of Mamba blocks. Defaults to 3.
  - norm (str) – Normalization function. Defaults to “layer” (LayerNorm).
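A hedged constructor sketch using the parameters above (all values are illustrative, and it assumes the optional Mamba dependency is installed):

```python
from amago.nets.traj_encoders import MambaTrajEncoder

# Hedged sketch: a selective-SSM encoder with a larger state dim
# than the default. All values are illustrative.
encoder = MambaTrajEncoder(
    tstep_dim=64,     # matches the TstepEncoder output dim
    max_seq_len=512,  # max context length during training
    d_model=256,
    d_state=32,       # SSM state dimension (default: 16)
    expand=2,         # SSM expansion factor
    n_layers=3,
)
```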
- property emb_dim#
Defines the expected output dim of this model.
Used to infer the input dim of actor/critics.
- Returns:
The embedding dimension.
- Return type:
int
- forward(seq, time_idxs=None, hidden_state=None, log_dict=None)[source]#
Sequence model forward pass.
- Parameters:
  - seq (Tensor) – [Batch, Num Timesteps, TstepDim]. TstepDim is defined by the output of the TstepEncoder.
  - time_idxs (Tensor | None) – [Batch, Num Timesteps, 1]. A sequence of ints tying the input seq to the number of steps that have passed since the start of the episode. Can be used to compute position embeddings or other temporal features.
  - hidden_state (_MambaHiddenState | None) – Architecture-specific hidden state. Defaults to None.
- Returns:
  A tuple containing:
  - output_seq: [Batch, Timestep, self.emb_dim]. Output of our seq2seq model.
  - new_hidden_state: Architecture-specific hidden state. Expected to be None if the input hidden_state is None. Otherwise, we assume we are at inference time and that this forward method has handled any updates to the hidden state that were needed.
- Return type:
  Tuple[torch.Tensor, Optional[Any]]
- init_hidden_state(batch_size, device)[source]#
Hook to create an architecture-specific hidden state.
Return value is passed as TrajEncoder.forward(..., hidden_state=self.init_hidden_state(...)) when the agent begins to interact with the environment.
- Parameters:
  - batch_size (int) – Number of parallel environments.
  - device (device) – Device to store hidden state tensors (if applicable).
- Returns:
  Some hidden state object, or None if not applicable. Defaults to None.
- Return type:
  Optional[Any]
- reset_hidden_state(hidden_state, dones)[source]#
Hook to implement architecture-specific hidden state reset.
- Parameters:
  - hidden_state (_MambaHiddenState | None) – We only expect to see hidden states that were created by self.init_hidden_state().
  - dones (ndarray) – A bool array of shape (num_parallel_envs,) where True indicates the agent loop has finished this episode and expects the hidden state for this batch index to be reset.
- Returns:
  Architecture-specific hidden state. Defaults to a no-op: new_hidden_state = hidden_state.
- Return type:
  Optional[Any]
- class TformerTrajEncoder(tstep_dim, max_seq_len, d_model=256, n_heads=8, d_ff=1024, n_layers=3, dropout_ff=0.05, dropout_emb=0.05, dropout_attn=0.0, dropout_qkv=0.0, activation='leaky_relu', norm='layer', pos_emb='fixed', sigma_reparam=True, normformer_norms=True, head_scaling=True, attention_type=<class 'amago.nets.transformer.FlashAttention'>)[source]#
Bases: TrajEncoder
Transformer Trajectory Encoder.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.

A pre-norm, decoder-only Transformer that processes sequences of timestep embeddings.
- Parameters:
  - tstep_dim (int) – Dimension of the input timestep representation (last dim of the input sequence). Defined by the output of the TstepEncoder.
  - max_seq_len (int) – Maximum sequence length of the model. The max context length of the model during training.
  - d_model (int) – Dimension of the main residual stream and output. Defaults to 256.
  - n_heads (int) – Number of self-attention heads. Each head has dimension d_model / n_heads. Defaults to 8.
  - d_ff (int) – Dimension of the feed-forward network in residual blocks. Defaults to 4 * d_model.
  - n_layers (int) – Number of Transformer layers. Defaults to 3.
  - dropout_ff (float) – Dropout rate for linear layers within the Transformer. Defaults to 0.05.
  - dropout_emb (float) – Dropout rate for the input embedding (combined input sequence and position embeddings passed to the Transformer). Defaults to 0.05.
  - dropout_attn (float) – Dropout rate for the attention matrix. Defaults to 0.0.
  - dropout_qkv (float) – Dropout rate for query/key/value projections. Defaults to 0.0.
  - activation (str) – Activation function. Defaults to “leaky_relu”.
  - norm (str) – Normalization function. Defaults to “layer” (LayerNorm).
  - pos_emb (str) – Position embedding type. “fixed” (default) uses sinusoidal embeddings; “learned” uses trainable embeddings per timestep.
  - causal – Whether to use a causal attention mask. Defaults to True.
  - sigma_reparam (bool) – Whether to use σ-reparam feed-forward layers from https://arxiv.org/abs/2303.06296. Defaults to True.
  - normformer_norms (bool) – Whether to use the extra norm layers from NormFormer (https://arxiv.org/abs/2110.09456). Always uses a pre-norm Transformer. Defaults to True.
  - head_scaling (bool) – Whether to use head scaling from NormFormer. Defaults to True.
  - attention_type (type[SelfAttention]) – Attention layer type. Defaults to transformer.FlashAttention. transformer.VanillaAttention is provided as a backup. New types can inherit from transformer.SelfAttention.
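For example, a hedged sketch of swapping the attention implementation (dims are illustrative; VanillaAttention is the documented fallback for hardware where FlashAttention is unavailable):

```python
from amago.nets import transformer
from amago.nets.traj_encoders import TformerTrajEncoder

encoder = TformerTrajEncoder(
    tstep_dim=64,
    max_seq_len=256,
    d_model=256,
    n_heads=8,
    attention_type=transformer.VanillaAttention,  # default: transformer.FlashAttention
)
```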
- property emb_dim: int#
Defines the expected output dim of this model.
Used to infer the input dim of actor/critics.
- Returns:
The embedding dimension.
- Return type:
int
- forward(seq, time_idxs, hidden_state=None, log_dict=None)[source]#
Sequence model forward pass.
- Parameters:
  - seq (Tensor) – [Batch, Num Timesteps, TstepDim]. TstepDim is defined by the output of the TstepEncoder.
  - time_idxs (Tensor) – [Batch, Num Timesteps, 1]. A sequence of ints tying the input seq to the number of steps that have passed since the start of the episode. Can be used to compute position embeddings or other temporal features.
  - hidden_state (TformerHiddenState | None) – Architecture-specific hidden state. Defaults to None.
- Returns:
  A tuple containing:
  - output_seq: [Batch, Timestep, self.emb_dim]. Output of our seq2seq model.
  - new_hidden_state: Architecture-specific hidden state. Expected to be None if the input hidden_state is None. Otherwise, we assume we are at inference time and that this forward method has handled any updates to the hidden state that were needed.
- Return type:
  Tuple[torch.Tensor, Optional[Any]]
- init_hidden_state(batch_size, device)[source]#
Hook to create an architecture-specific hidden state.
Return value is passed as TrajEncoder.forward(..., hidden_state=self.init_hidden_state(...)) when the agent begins to interact with the environment.
- Parameters:
  - batch_size (int) – Number of parallel environments.
  - device (device) – Device to store hidden state tensors (if applicable).
- Returns:
  Some hidden state object, or None if not applicable. Defaults to None.
- Return type:
  Optional[Any]
- reset_hidden_state(hidden_state, dones)[source]#
Hook to implement architecture-specific hidden state reset.
- Parameters:
  - hidden_state (TformerHiddenState | None) – We only expect to see hidden states that were created by self.init_hidden_state().
  - dones (ndarray) – A bool array of shape (num_parallel_envs,) where True indicates the agent loop has finished this episode and expects the hidden state for this batch index to be reset.
- Returns:
  Architecture-specific hidden state. Defaults to a no-op: new_hidden_state = hidden_state.
- Return type:
  Optional[Any]
- class TrajEncoder(tstep_dim, max_seq_len)[source]#
Bases: Module, ABC
Abstract base class for trajectory encoders.
An agent’s “TrajEncoder” is the sequence model in charge of mapping the output of the “TstepEncoder” for each timestep of the trajectory to the latent dimension where actor-critic learning takes place. Because the actor and critic are feed-forward networks, this is the place to add long-term memory over previous timesteps.
Note
It would not make sense for the sequence model defined here to be bi-directional or non-causal.
- Parameters:
  - tstep_dim (int) – Dimension of the input timestep representation (last dim of the input sequence). Defined by the output of the TstepEncoder.
  - max_seq_len (int) – Maximum sequence length of the model. Any inputs will have been trimmed to this length before reaching the TrajEncoder.
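A minimal sketch of a custom subclass, assuming only the abstract API documented on this page. ConvTrajEncoder is hypothetical; it implements causal memory over a fixed window with a left-padded 1D convolution, and keeps the default (no-op) hidden-state hooks:

```python
import torch
from torch import nn
import torch.nn.functional as F
from amago.nets.traj_encoders import TrajEncoder

class ConvTrajEncoder(TrajEncoder):
    """Hypothetical causal 1D-conv trajectory encoder."""

    def __init__(self, tstep_dim, max_seq_len, d_model=128, kernel=8):
        super().__init__(tstep_dim, max_seq_len)
        self.kernel = kernel
        self.d_model = d_model
        self.conv = nn.Conv1d(tstep_dim, d_model, kernel_size=kernel)

    @property
    def emb_dim(self) -> int:
        return self.d_model

    def forward(self, seq, time_idxs, hidden_state=None, log_dict=None):
        # left-pad by (kernel - 1) so each output sees only past timesteps
        x = F.pad(seq.transpose(1, 2), (self.kernel - 1, 0))
        out = self.conv(x).transpose(1, 2)  # [Batch, Num Timesteps, emb_dim]
        return out, hidden_state  # hidden-state hooks keep their no-op defaults
```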
- abstract property emb_dim: int#
Defines the expected output dim of this model.
Used to infer the input dim of actor/critics.
- Returns:
The embedding dimension.
- Return type:
int
- abstract forward(seq, time_idxs, hidden_state=None, log_dict=None)[source]#
Sequence model forward pass.
- Parameters:
  - seq (Tensor) – [Batch, Num Timesteps, TstepDim]. TstepDim is defined by the output of the TstepEncoder.
  - time_idxs (Tensor) – [Batch, Num Timesteps, 1]. A sequence of ints tying the input seq to the number of steps that have passed since the start of the episode. Can be used to compute position embeddings or other temporal features.
  - hidden_state (Any | None) – Architecture-specific hidden state. Defaults to None.
- Returns:
  A tuple containing:
  - output_seq: [Batch, Timestep, self.emb_dim]. Output of our seq2seq model.
  - new_hidden_state: Architecture-specific hidden state. Expected to be None if the input hidden_state is None. Otherwise, we assume we are at inference time and that this forward method has handled any updates to the hidden state that were needed.
- Return type:
  Tuple[torch.Tensor, Optional[Any]]
- init_hidden_state(batch_size, device)[source]#
Hook to create an architecture-specific hidden state.
Return value is passed as TrajEncoder.forward(..., hidden_state=self.init_hidden_state(...)) when the agent begins to interact with the environment.
- Parameters:
  - batch_size (int) – Number of parallel environments.
  - device (device) – Device to store hidden state tensors (if applicable).
- Returns:
  Some hidden state object, or None if not applicable. Defaults to None.
- Return type:
  Optional[Any]
- reset_hidden_state(hidden_state, dones)[source]#
Hook to implement architecture-specific hidden state reset.
- Parameters:
  - hidden_state (Any | None) – We only expect to see hidden states that were created by self.init_hidden_state().
  - dones (ndarray) – A bool array of shape (num_parallel_envs,) where True indicates the agent loop has finished this episode and expects the hidden state for this batch index to be reset.
- Returns:
  Architecture-specific hidden state. Defaults to a no-op: new_hidden_state = hidden_state.
- Return type:
  Optional[Any]