Configure#
The Experiment has lots of other kwargs to control things like the update:data ratio, optimization, and logging. For additional control over the training process, use gin. First, note that the general format of an AMAGO training script is:
from amago import Experiment
from amago.envs import AMAGOEnv
# define the env
make_env = lambda: AMAGOEnv(...)
# create a dataset
dataset = DiskTrajDataset(...)
# make the main conceptual choices
tstep_encoder_type = CNNTstepEncoder
traj_encoder_type = TformerTrajEncoder
agent_type = Agent
exploration_wrapper_type = EpsilonGreedy
experiment = Experiment(
    dataset=dataset,
    # assign lots of *callables*, not instances, to the experiment
    make_train_env=make_env,
    make_val_env=make_env,
    tstep_encoder_type=tstep_encoder_type,
    traj_encoder_type=traj_encoder_type,
    agent_type=agent_type,
    exploration_wrapper_type=exploration_wrapper_type,
    ...
)
experiment.start()
experiment.learn()
experiment.evaluate_test(make_env)
We choose the AMAGOEnv, RLDataset, and TstepEncoder because they are problem-specific. We also pick the TrajEncoder because it is the key feature of a sequence-model agent. start() will create an Agent based on the environment and our other choices. This follows a strict rule:
Important
Anytime AMAGO needs to initialize or call a class/method, it infers the positional args (based on the environment and our other choices) but leaves every keyword argument at its default value. gin lets us edit those defaults without editing the source code, and keeps track of the settings we used on wandb.
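For a concrete picture of this rule, here is a hypothetical sketch (not AMAGO's actual internal code) of what start() effectively does for the timestep encoder:
# hypothetical sketch, not AMAGO's internal code:
# positional args are inferred from the environment...
tstep_encoder = tstep_encoder_type(
    obs_space=env.observation_space,
    rl2_space=env.rl2_space,
)
# ...while every keyword argument (cnn_type, d_output, ...) keeps its default
# unless a gin binding overrides it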
The Examples show how almost every application of AMAGO looks the same aside from some minor gin configuration. gin can be complicated, but AMAGO tries to make it hard to get wrong:
Tip
If something is @gin.configurable (there will be a note at the top of its documentation), it means that Experiment will only ever call/construct it with default keyword arguments, and there is no other place where those values should be set or will be overridden. The only exceptions are Experiment and RLDataset, which are explicitly constructed by the user before training begins but are configurable for convenience.
For example, let’s say we want to switch the CNNTstepEncoder to use a larger IMPALA architecture with twice as many channels as usual. The API reference for CNNTstepEncoder looks like this:
- class CNNTstepEncoder(obs_space, rl2_space, cnn_type=<class 'amago.nets.cnn.NatureishCNN'>, channels_first=False, img_features=256, rl2_features=12, d_output=256, out_norm='layer', activation='leaky_relu', hide_rl2s=False, drqv2_aug=False, aug_pct_of_batch=0.75, obs_key='observation')[source]
Bases: TstepEncoder
A simple CNN-based TstepEncoder. Useful for pixel-based environments. Currently only supports the case where observations are a single image without additional state arrays.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
- Parameters:
obs_space (Space) – Environment observation space.
rl2_space (Space) – A gym space declaring the shape of previous action and reward features. This is created by the AMAGOEnv wrapper.
cnn_type (Type[CNN]) – The type of nets.cnn.CNN to use. Defaults to nets.cnn.NatureishCNN (the small DQN CNN).
channels_first (bool) – Whether the image is in channels-first format. Defaults to False.
img_features (int) – Linear map the output of the CNN to this many features. Defaults to 256.
rl2_features (int) – Linear map the previous action and reward to this many features. Defaults to 12.
d_output (int) – The output dimension of a layer that fuses the img_features and rl2_features. Defaults to 256.
out_norm (str) – The normalization layer to use. See nets.ff.Normalization for options. Defaults to “layer”.
activation (str) – The activation function to use. See nets.utils.activation_switch for options. Defaults to “leaky_relu”.
hide_rl2s (bool) – Whether to ignore the previous action and reward features (but otherwise keep the same parameter count and layer dimensions).
drqv2_aug (bool) – Quick-apply the default DrQv2 image augmentation. Applies random crops to `aug_pct_of_batch`% of every batch during training. Currently requires square images. Defaults to False.
aug_pct_of_batch (float) – The percentage of every batch to apply DrQv2 augmentation to, if drqv2_aug is True. Defaults to 0.75.
obs_key (str) – The key in the observation space that contains the image. Defaults to “observation”, which is the default created by AMAGOEnv when the original observation space is not a dict.
Following our rule, obs_space and rl2_space will be determined for us, but nothing will try to set cnn_type, so it will default to NatureishCNN. The IMPALAishCNN looks like this:
- class IMPALAishCNN(img_shape, channels_first, activation, cnn_block_depths=[16, 32, 32], post_group_norm=True)[source]
Bases: CNN
CNN architecture from IMPALA.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
- Parameters:
img_shape (tuple[int]) – Shape of the image (H, W, C) or (C, H, W).
channels_first (bool) – Whether the image is in channels-first format.
activation (str) – Activation function to use. See amago.nets.utils.activation_switch.
cnn_block_depths (list[int]) – List of ints representing the number of output channels for each convolutional block. Length defines the number of residual blocks. Defaults to [16, 32, 32].
post_group_norm (bool) – Whether to use group normalization after each convolutional block. Defaults to True.
So we can change cnn_block_depths and post_group_norm by editing these values, but this would not be the place to change the activation, since activation is a positional arg here (it is filled in by the CNNTstepEncoder).
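For instance, activation is itself a kwarg of CNNTstepEncoder, so a hypothetical override of the activation function would target the TstepEncoder instead:
# sketch: configure activation on CNNTstepEncoder, not on the CNN itself
# ("gelu" is a hypothetical value; see nets.utils.activation_switch for supported names)
config = {"amago.nets.tstep_encoders.CNNTstepEncoder.activation": "gelu"}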
The most organized way to set gin values is with .gin config files (an example file is sketched after the Python version below). But we can also do this directly in Python with:
from amago import Experiment
from amago.nets.tstep_encoders import CNNTstepEncoder
from amago.nets.cnn import IMPALAishCNN
from amago.cli_utils import use_config

config = {
    "amago.nets.tstep_encoders.CNNTstepEncoder.cnn_type": IMPALAishCNN,
    "IMPALAishCNN.cnn_block_depths": [32, 64, 64],
}
# changes the default values
use_config(config)

experiment = Experiment(
    tstep_encoder_type=CNNTstepEncoder,
    ...
)
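The equivalent .gin file would look something like this sketch (hypothetical file name; @ marks a reference to another configurable):
# impala_encoder.gin (hypothetical)
import amago.nets.cnn
import amago.nets.tstep_encoders

amago.nets.tstep_encoders.CNNTstepEncoder.cnn_type = @amago.nets.cnn.IMPALAishCNN
amago.nets.cnn.IMPALAishCNN.cnn_block_depths = [32, 64, 64]
Passing files like this to use_config is shown at the end of this section.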
As a more complicated example, let’s say we want to use a TformerTrajEncoder with a model dimension of 512 (feed-forward dimension 2048), 16 attention heads, and sliding-window attention with a window size of 128. The API reference for TformerTrajEncoder looks like this:
- class TformerTrajEncoder(tstep_dim, max_seq_len, d_model=256, n_heads=8, d_ff=1024, n_layers=3, dropout_ff=0.05, dropout_emb=0.05, dropout_attn=0.0, dropout_qkv=0.0, activation='leaky_relu', norm='layer', pos_emb='fixed', sigma_reparam=True, normformer_norms=True, head_scaling=True, attention_type=<class 'amago.nets.transformer.FlashAttention'>)[source]
Bases: TrajEncoder
Transformer Trajectory Encoder. A pre-norm Transformer decoder-only model that processes sequences of timestep embeddings.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
- Parameters:
tstep_dim (int) – Dimension of the input timestep representation (last dim of the input sequence). Defined by the output of the TstepEncoder.
max_seq_len (int) – Maximum sequence length of the model. The max context length of the model during training.
d_model (int) – Dimension of the main residual stream and output. Defaults to 256.
n_heads (int) – Number of self-attention heads. Each head has dimension d_model/n_heads. Defaults to 8.
d_ff (int) – Dimension of feed-forward network in residual blocks. Defaults to 4*d_model.
n_layers (int) – Number of Transformer layers. Defaults to 3.
dropout_ff (float) – Dropout rate for linear layers within Transformer. Defaults to 0.05.
dropout_emb (float) – Dropout rate for input embedding (combined input sequence and position embeddings passed to Transformer). Defaults to 0.05.
dropout_attn (float) – Dropout rate for attention matrix. Defaults to 0.00.
dropout_qkv (float) – Dropout rate for query/key/value projections. Defaults to 0.00.
activation (str) – Activation function. Defaults to “leaky_relu”.
norm (str) – Normalization function. Defaults to “layer” (LayerNorm).
pos_emb (str) – Position embedding type. “fixed” (default) uses sinusoidal embeddings, “learned” uses trainable embeddings per timestep.
causal – Whether to use causal attention mask. Defaults to True.
sigma_reparam (bool) – Whether to use \(\sigma\)-reparam feed-forward layers from https://arxiv.org/abs/2303.06296. Defaults to True.
normformer_norms (bool) – Whether to use extra norm layers from NormFormer (https://arxiv.org/abs/2110.09456). Always uses pre-norm Transformer.
head_scaling (bool) – Whether to use head scaling from NormFormer. Defaults to True.
attention_type (type[SelfAttention]) – Attention layer type. Defaults to transformer.FlashAttention. transformer.VanillaAttention provided as backup. New types can inherit from transformer.SelfAttention.
The attention_type we want is SlidingWindowFlexAttention:
- class SlidingWindowFlexAttention(causal, dropout, window_size=gin.REQUIRED)[source]
Bases: FlexAttention
A more useful test of FlexAttention that implements a sliding window pattern for long context lengths.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
gin.REQUIRED is reserved for settings that are not commonly used but would be so important and task-specific that there is no good default. You’ll get an error if you use one but forget to configure it.
from amago import Experiment
from amago.nets.traj_encoders import TformerTrajEncoder
from amago.nets.transformer import SlidingWindowFlexAttention
from amago.cli_utils import use_config

config = {
    "TformerTrajEncoder.n_heads": 16,
    "TformerTrajEncoder.d_model": 512,
    "TformerTrajEncoder.d_ff": 2048,
    "TformerTrajEncoder.attention_type": SlidingWindowFlexAttention,
    "SlidingWindowFlexAttention.window_size": 128,
}
use_config(config)

experiment = Experiment(
    traj_encoder_type=TformerTrajEncoder,
    ...
)
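As an aside, gin.REQUIRED is a standard gin feature you can reuse in your own configurable code. A minimal hypothetical sketch (not part of AMAGO):
import gin

@gin.configurable
class MySlidingAttention:
    # window_size has no sensible default, so gin.REQUIRED forces a gin binding;
    # constructing this class without "MySlidingAttention.window_size = ..." in the
    # active gin config raises an error instead of silently picking a value
    def __init__(self, causal: bool, dropout: float, window_size: int = gin.REQUIRED):
        self.causal = causal
        self.dropout = dropout
        self.window_size = window_size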
Customizing the built-in TstepEncoder, TrajEncoder, Agent, and ExplorationWrapper is so common that amago.cli_utils provides shortcut helpers that pair with use_config. For example, we could have made the changes from all of the previous examples at the same time with:
from amago import Experiment
from amago.cli_utils import (
    switch_traj_encoder,
    switch_tstep_encoder,
    switch_agent,
    switch_exploration,
    use_config,
)
from amago.nets.transformer import SlidingWindowFlexAttention
from amago.nets.cnn import IMPALAishCNN

config = {
    # these are niche changes customized a level below the `TstepEncoder` / `TrajEncoder`,
    # so we still have to specify them by hand
    "amago.nets.transformer.SlidingWindowFlexAttention.window_size": 128,
    "amago.nets.cnn.IMPALAishCNN.cnn_block_depths": [32, 64, 64],
}
tstep_encoder_type = switch_tstep_encoder(config, arch="cnn", cnn_type=IMPALAishCNN)
traj_encoder_type = switch_traj_encoder(config, arch="transformer", d_model=512, d_ff=2048, n_heads=16, attention_type=SlidingWindowFlexAttention)
exploration_wrapper_type = switch_exploration(config, strategy="egreedy", eps_start=1.0, eps_end=0.01, steps_anneal=200_000)
# also customize some unrelated RL details as an example
agent_type = switch_agent(config, agent="multitask", num_critics=6, gamma=0.998)
use_config(config)

experiment = Experiment(
    tstep_encoder_type=tstep_encoder_type,
    traj_encoder_type=traj_encoder_type,
    agent_type=agent_type,
    exploration_wrapper_type=exploration_wrapper_type,
    ...
)
If we want to combine hardcoded changes like these with genuine .gin config files, use_config() will also take their paths:
# these changes are applied in order from left to right. if the same param is overridden
# in multiple configs, the last one wins. making gin this complicated is usually a bad idea.
use_config(config, gin_configs=["environment_config.gin", "rl_config.gin"])
Tip
You can view your active gin config (all the active hyperparameters used by an experiment) in the checkpoint directory as config.txt, or on wandb in the Config section.
A full API reference is available in the API Reference section.