amago.experiment#
Start and launch training runs (main Experiment).
Classes

- Experiment – Build, train, and evaluate an Agent.
- class Experiment(run_name, ckpt_base_dir, max_seq_len, dataset, tstep_encoder_type, traj_encoder_type, agent_type, val_timesteps_per_epoch, make_train_env, make_val_env, parallel_actors=12, env_mode='async', async_env_mp_context=None, exploration_wrapper_type=<class 'amago.envs.exploration.EpsilonGreedy'>, sample_actions=True, force_reset_train_envs_every=None, log_to_wandb=False, wandb_project="os.environ['AMAGO_WANDB_PROJECT']", wandb_entity="os.environ['AMAGO_WANDB_ENTITY']", wandb_group_name=None, verbose=True, log_interval=300, traj_save_len=10000000000.0, has_dset_edit_rights=True, stagger_traj_file_lengths=True, save_trajs_as='npz', padded_sampling='none', dloader_workers=6, epochs=1000, start_learning_at_epoch=0, start_collecting_at_epoch=0, train_timesteps_per_epoch=1000, train_batches_per_epoch=1000, val_interval=20, ckpt_interval=50, always_save_latest=True, always_load_latest=False, batch_size=24, batches_per_update=1, learning_rate=0.0001, critic_loss_weight=10.0, lr_warmup_steps=500, grad_clip=1.0, l2_coeff=0.001, mixed_precision='no')[source]#
Bases: object
Build, train, and evaluate an Agent.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Required
- Parameters:
  - run_name (str) – Name of the experiment. Used to create checkpoint and log directories.
  - ckpt_base_dir (str) – Base directory to store checkpoints and logs. Checkpoints are saved to ckpt_base_dir/run_name.
  - max_seq_len (int) – Maximum sequence length for training. Determines the effective batch size (batch size × sequence length).
  - dataset (RLDataset) – RLDataset for loading training sequences.
  - tstep_encoder_type (type[TstepEncoder]) – A type of TstepEncoder (will be created with default kwargs; edit via gin).
  - traj_encoder_type (type[TrajEncoder]) – A type of TrajEncoder (will be created with default kwargs; edit via gin).
  - agent_type (type[Agent]) – A type of Agent (will be created with default kwargs; edit via gin).
  - make_train_env (callable | Iterable[callable]) – Callable returning an AMAGOEnv. If not a list, it is repeated parallel_actors times. A list gives manual assignment across actors.
  - make_val_env (callable | Iterable[callable]) – Like make_train_env, but only used for evaluation (trajectories are never saved).
  - val_timesteps_per_epoch (int) – Number of steps per parallel environment for evaluation. Determines the metric sample size. Should allow enough time for at least one episode to finish per actor.
Environment

- Parameters:
  - parallel_actors (int) – Number of parallel envs for batched inference. Default: 12.
  - env_mode (str) – "async" wraps envs in an async pool. "already_vectorized" is for jax/gpu batched envs. "sync" is for debugging. Default: "async".
  - exploration_wrapper_type (type[ExplorationWrapper] | None) – Exploration wrapper for training envs. Default: EpsilonGreedy.
  - sample_actions (bool) – Whether to sample from the stochastic actor during eval or take the argmax/mean. Default: True.
  - force_reset_train_envs_every (int | None) – If set, forces a call to reset every N epochs for already_vectorized envs. Default: None.
  - async_env_mp_context (str | None) – Multiprocessing spawn method for AsyncVectorEnv (e.g., "forkserver"). Only relevant for env_mode="async". Set to None for the default method. Default: None.
Logging

- Parameters:
  - log_to_wandb (bool) – Enable or disable wandb logging. Default: False.
  - wandb_project (str) – wandb project. Default: the AMAGO_WANDB_PROJECT env var.
  - wandb_entity (str) – wandb entity (username/team). Default: the AMAGO_WANDB_ENTITY env var.
  - wandb_group_name (str | None) – Group runs on the wandb dashboard. Default: None.
  - verbose (bool) – Print tqdm bars and info to the console. Default: True.
  - log_interval (int) – Log extra metrics every N batches. Default: 300.
  - padded_sampling (str) – Padding mode for sampling training subsequences: "none", "left", "right", or "both". Default: "none".
  - dloader_workers (int) – Number of DataLoader workers for loading trajectories from disk. Increase for compressed/large trajectory files.
Note

The parameters below are only relevant when doing online data collection. They determine how parallel environments write finished trajectories to disk. The DiskTrajDataset reads these files for training.

- Parameters:
  - traj_save_len (int) – Save a trajectory on episode end or after this many steps, whichever comes first. Larger values save whole trajectories/episodes. Default: a very large value (effectively save on episode end).
  - has_dset_edit_rights (bool) – Turn off for collect-only runs where another process manages the replay buffer. Default: True.
  - stagger_traj_file_lengths (bool) – Randomizes file lengths when traj_save_len creates short snippets. Default: True.
  - save_trajs_as (str) – Format for saved trajectories: "npz", "npz-compressed", or "traj". Default: "npz".
Learning Schedule

- Parameters:
  - epochs (int) – Number of epochs (each epoch = one data collection round + one training round). Default: 1000.
  - start_learning_at_epoch (int) – Number of epochs to skip before starting gradient updates (for replay buffer warmup). Default: 0.
  - start_collecting_at_epoch (int) – Number of epochs to skip data collection (for offline→online finetuning or fully offline training). Default: 0.
  - train_timesteps_per_epoch (int) – Number of steps taken in each parallel env per epoch. Default: 1000.
  - train_batches_per_epoch (int) – Number of training batches per epoch. Default: 1000.
  - val_interval (int | None) – Number of epochs between evaluation rollouts. Default: 20.
  - ckpt_interval (int | None) – Number of epochs between saved checkpoints. Default: 50.
  - always_save_latest (bool) – Whether to always save the latest weights (for distributed usage). Default: True.
  - always_load_latest (bool) – Whether to always load the latest weights (for distributed usage). Default: False.
Optimization

- Parameters:
  - batch_size (int) – Batch size per GPU (in sequences). Default: 24.
  - batches_per_update (int) – Number of batches to accumulate gradients over before each optimizer update. Default: 1.
  - learning_rate (float) – Optimizer learning rate (the default optimizer is AdamW). Default: 1e-4.
  - critic_loss_weight (float) – Weight of the critic loss vs. the actor loss in updates to the shared encoders. Default: 10.
  - lr_warmup_steps (int) – Number of warmup steps for the learning rate scheduler. Default: 500.
  - grad_clip (float) – Gradient norm clipping value. Default: 1.0.
  - l2_coeff (float) – L2 regularization coefficient (AdamW weight decay). Default: 1e-3.
  - mixed_precision (str) – Mixed precision mode for accelerate ("no", "fp16", "bf16"). Default: "no".
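A minimal end-to-end sketch of how these pieces fit together. The helpers `make_my_env`, `my_dataset`, `MyTstepEncoder`, `MyTrajEncoder`, and `MyAgent` are hypothetical placeholders defined elsewhere; exact types and any gin bindings depend on your setup.

```python
from amago.experiment import Experiment

# Hypothetical helpers: replace with your own AMAGOEnv constructor, RLDataset,
# and TstepEncoder / TrajEncoder / Agent types.
experiment = Experiment(
    run_name="my_run",
    ckpt_base_dir="./ckpts",
    max_seq_len=128,
    dataset=my_dataset,                 # an RLDataset (e.g., a DiskTrajDataset)
    tstep_encoder_type=MyTstepEncoder,  # type[TstepEncoder]
    traj_encoder_type=MyTrajEncoder,    # type[TrajEncoder]
    agent_type=MyAgent,                 # type[Agent]
    make_train_env=make_my_env,         # callable returning an AMAGOEnv
    make_val_env=make_my_env,
    val_timesteps_per_epoch=1_000,
    parallel_actors=12,
    epochs=100,
)

experiment.start()  # finalize setup after gin configuration (see start())
experiment.learn()  # main training loop (see learn())
```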
- property DEVICE#
Return the device (cpu/gpu) that the experiment is running on.
- always_load_latest: bool = False#
- always_save_latest: bool = True#
- async_env_mp_context: str | None = None#
- batch_size: int = 24#
- batches_per_update: int = 1#
- ckpt_base_dir: str#
- ckpt_interval: int | None = 50#
- collect_new_training_data()[source]#
Generate train_timesteps_per_epoch * parallel_actors timesteps of new environment interaction, which will be saved to the replay buffer when the rollouts finish.
- Return type:
None
- compute_loss(batch, log_step)[source]#
Core computation of the actor and critic RL loss terms from a Batch of data.
- Parameters:
  - batch (Batch) – The batch of data.
  - log_step (bool) – Whether to compute extra metrics for wandb logging.
- Returns:
  Loss terms and any logging metrics. "Actor Loss", "Critic Loss", "Sequence Length", "Batch Size (in Timesteps)", and "Unmasked Batch Size (in Timesteps)" are always provided. Additional keys are determined by what is logged in the Agent.forward method.
- Return type:
dict
- critic_loss_weight: float = 10.0#
- delete_buffer_from_disk()[source]#
Clear the replay buffer from disk (mainly for examples/). Calls self.dataset.delete() if the current process is the main process.
- Return type:
None
- dloader_workers: int = 6#
- edit_actor_mask(batch, actor_loss, pad_mask)[source]#
Customize the actor loss mask.
- Parameters:
  - batch (Batch) – The batch of data.
  - actor_loss (FloatTensor) – The unmasked actor loss. Shape: (Batch, Length, Num Gammas, 1).
  - pad_mask (BoolTensor) – The default mask. True where the sequence was not padded by the dataloader.
- Return type:
BoolTensor
- Returns:
The mask. True where the actor loss should count, False where it should be ignored.
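As a sketch of how this hook can be used, a hypothetical subclass that ignores the actor loss on the first few timesteps of every sequence (a made-up burn-in rule, not part of the library) might look like:

```python
from amago.experiment import Experiment


class BurnInExperiment(Experiment):
    """Hypothetical: skip the actor loss on the first `burn_in` timesteps."""

    burn_in: int = 8

    def edit_actor_mask(self, batch, actor_loss, pad_mask):
        # pad_mask is True where the loss should count (i.e., not padding).
        mask = pad_mask.clone()
        # actor_loss / pad_mask are (Batch, Length, Num Gammas, 1); drop early timesteps.
        mask[:, : self.burn_in] = False
        return mask
```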
- edit_critic_mask(batch, critic_loss, pad_mask)[source]#
Customize the critic loss mask.
- Parameters:
  - batch (Batch) – The batch of data.
  - critic_loss (FloatTensor) – The unmasked critic loss. Shape: (Batch, Length, Num Critics, Num Gammas, 1).
  - pad_mask (BoolTensor) – The default mask. True where the sequence was not padded by the dataloader.
- Return type:
BoolTensor
- Returns:
The mask. True where the critic loss should count, False where it should be ignored.
- env_mode: str = 'async'#
- epochs: int = 1000#
- evaluate_test(make_test_env, timesteps, render=False, save_trajs_to=None, episodes=None)[source]#
One-off evaluation of a new environment callable for testing.
- Parameters:
  - make_test_env (callable | Iterable[callable]) – A callable or iterable of callables that make and return a test environment. If an iterable, it must have length Experiment.parallel_actors.
  - timesteps (int) – The number of timesteps to interact with each environment.
  - render (bool) – Whether to render the environment. Defaults to False.
  - save_trajs_to (str | None) – The directory to save trajectories. Useful when using evaluate_test to gather demonstration data for another run. If None, no data is saved. Defaults to None.
  - episodes (int | None) – The number of episodes to interact with the environment. If provided, the loop terminates after this many episodes have been completed OR we hit the timesteps limit, whichever comes first. Defaults to None.
- Returns:
A dictionary of evaluation metrics.
- Return type:
dict[str, float]
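For example, a one-off test on a held-out task after training, optionally saving the rollouts as demonstration data for another run. Here `experiment` is an already-trained Experiment and `make_test_env` is a hypothetical callable returning an AMAGOEnv:

```python
metrics = experiment.evaluate_test(
    make_test_env,
    timesteps=5_000,         # per parallel environment
    episodes=10,             # stop early once 10 episodes finish (or at the timestep limit)
    save_trajs_to="demos/",  # optional: keep rollouts as demonstration data
)
print(metrics)  # dict[str, float] of evaluation metrics
```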
- evaluate_val()[source]#
Evaluate the current policy without exploration noise on the validation environments, and average the performance metrics across accelerate processes.
- Return type:
None
- exploration_wrapper_type#
alias of EpsilonGreedy
- force_reset_train_envs_every: int | None = None#
- grad_clip: float = 1.0#
- has_dset_edit_rights: bool = True#
- init_checkpoints()[source]#
Create the ckpts/training_states, ckpts/policy_weights, and ckpts/latest directories.
- Return type:
None
- init_dloaders()[source]#
Create pytorch dataloaders to batch trajectories in parallel.
- Return type:
DataLoader
- init_dsets()[source]#
Modifies the provided RLDataset (in place) to use important info configured by the experiment.
- Return type:
- init_envs()[source]#
Construct parallel training and validation environments.
- Returns:
  A description of the environment setup, printed to the console when Experiment.verbose is True.
- Return type:
str
- init_optimizer(policy)[source]#
Defines the optimizer.
Override to switch from AdamW.
- Return type:
Optimizer
- Returns:
  A torch.optim.Optimizer in charge of updating the Agent's parameters (Agent.trainable_params).
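For instance, a subclass could swap the default AdamW for plain SGD. This is only a sketch: it assumes `Agent.trainable_params` (referenced in the return description above) can be passed directly to a torch optimizer.

```python
import torch

from amago.experiment import Experiment


class SGDExperiment(Experiment):
    def init_optimizer(self, policy):
        # Replace the default AdamW with SGD, reusing the Experiment's hyperparameters.
        # Assumption: policy.trainable_params yields the parameters to optimize.
        return torch.optim.SGD(
            policy.trainable_params,
            lr=self.learning_rate,
            momentum=0.9,
            weight_decay=self.l2_coeff,
        )
```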
- interact(envs, timesteps, hidden_state=None, render=False, save_on_done=False, episodes=None)[source]#
Main policy loop for interacting with the environment.
- Parameters:
  - envs – The (parallel) environments to interact with.
  - timesteps (int) – The number of timesteps to interact with each environment.
  - hidden_state – The hidden state of the policy. If None, a fresh hidden state is initialized. Defaults to None.
  - render (bool) – Whether to render the environment. Defaults to False.
  - save_on_done (bool) – If True, save completed trajectory sequences to disk as soon as they are finished. If False, wait until all rollouts are completed. Only applicable if the provided envs are configured to save rollouts to disk. Defaults to False.
  - episodes (int | None) – The number of episodes to interact with the environment. If provided, the loop terminates after this many episodes have been completed OR we hit the timesteps limit, whichever comes first. Defaults to None.
- Returns:
  Objects that keep track of standard eval stats (average returns) and any additional eval metrics the envs have been configured to record.
- Return type:
tuple[ReturnHistory, SpecialMetricHistory]
- l2_coeff: float = 0.001#
- learn()[source]#
Main training loop for the experiment.
- Return type:
None
For every epoch, we:
1. Load the latest policy checkpoint if always_load_latest is True.
2. Evaluate the policy on the validation set if val_interval is not None and the current epoch is divisible by val_interval.
3. Collect new training data if train_timesteps_per_epoch is not None and the current epoch is >= start_collecting_at_epoch.
4. Train the policy on the training data for train_batches_per_epoch batches if self.dataset.ready_for_training is True.
5. Save the policy checkpoint if ckpt_interval is not None and the current epoch is divisible by ckpt_interval.
6. Write the latest policy checkpoint if always_save_latest is True.

The Experiment can be configured so that processes do some or all of the above. For example, an actor process might only do steps 1, 2, and 3, while a learner process might only do steps 4, 5, and 6 (see the sketch below).
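A rough sketch of that actor/learner split, reusing the hypothetical helpers from the constructor example above. The zero values are assumptions about how to disable a phase (zero batches / zero collection timesteps); the actual distributed launch mechanism is up to you.

```python
from amago.experiment import Experiment

# Shared required kwargs (hypothetical helpers, as in the constructor sketch above).
common = dict(
    run_name="my_run",
    ckpt_base_dir="./ckpts",
    max_seq_len=128,
    dataset=my_dataset,
    tstep_encoder_type=MyTstepEncoder,
    traj_encoder_type=MyTrajEncoder,
    agent_type=MyAgent,
    make_train_env=make_my_env,
    make_val_env=make_my_env,
    val_timesteps_per_epoch=1_000,
)

# Actor process: steps 1-3 (load latest weights, evaluate, collect).
actor = Experiment(
    **common,
    train_batches_per_epoch=0,   # assumption: zero batches skips the training phase
    has_dset_edit_rights=False,  # another process owns the replay buffer
    always_load_latest=True,     # pull weights written by the learner
)

# Learner process: steps 4-6 (train, checkpoint, write latest weights).
learner = Experiment(
    **common,
    train_timesteps_per_epoch=0,  # assumption: zero timesteps skips data collection
    always_save_latest=True,      # publish weights for actors via write_latest_policy()
)
```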
- learning_rate: float = 0.0001#
- load_checkpoint(epoch, resume_training_state=True)[source]#
Load a historical checkpoint from the ckpts directory of this experiment.
- Parameters:
  - epoch (int) – The epoch number of the checkpoint to load.
  - resume_training_state (bool) – Whether to resume the entire training process (True) or only the policy weights (False). Defaults to True.
- Return type:
None
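For example, to restore the policy weights from epoch 100 without resuming the optimizer/scheduler state and then re-run validation (sketch; `experiment` is an already-constructed and started Experiment):

```python
experiment.load_checkpoint(epoch=100, resume_training_state=False)  # weights only
experiment.evaluate_val()  # validation rollouts with the restored policy
```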
- load_checkpoint_from_path(path, is_accelerate_state=True)[source]#
Load a checkpoint from a given path.
- Parameters:
  - path (str) – Full path to the checkpoint file to load.
  - is_accelerate_state (bool) – Whether the checkpoint is a full accelerate state (True) or pytorch weights only (False). Defaults to True.
- Return type:
None
- log(metrics_dict, key)[source]#
Log a dict of metrics to the key panel of the wandb console alongside current x-axis metrics.
- Return type:
None
- log_interval: int = 300#
- log_to_wandb: bool = False#
- lr_warmup_steps: int = 500#
- make_train_env: callable | Iterable[callable]#
- make_val_env: callable | Iterable[callable]#
- max_seq_len: int#
- mixed_precision: str = 'no'#
- padded_sampling: str = 'none'#
- parallel_actors: int = 12#
- policy_metrics(returns, specials)[source]#
Gather policy performance metrics across parallel environments.
- Parameters:
  - returns (Iterable[ReturnHistory]) – The return history loggers from the environments.
  - specials (Iterable[SpecialMetricHistory]) – The special metrics history loggers from the environments.
- Returns:
A dictionary of policy performance metrics.
- Return type:
dict
- read_latest_policy()[source]#
Read the latest policy; used to communicate weight updates between learning/collecting processes.
- Return type:
None
- run_name: str#
- sample_actions: bool = True#
- save_checkpoint()[source]#
Save both the training state and the policy weights to the ckpt_dir.
- Return type:
None
- save_trajs_as: str = 'npz'#
- stagger_traj_file_lengths: bool = True#
- start()[source]#
Manual initialization after __init__ to give time for gin configuration. Call before Experiment.learn().
- start_collecting_at_epoch: int = 0#
- start_learning_at_epoch: int = 0#
- train_batches_per_epoch: int = 1000#
- train_step(batch, log_step)[source]#
Take a single training step on a batch of data.
- Parameters:
  - batch (Batch) – The batch of data.
  - log_step (bool) – Whether to compute extra metrics for wandb logging.
- Returns:
  Metrics from the compute_loss method.
- Return type:
dict
- train_timesteps_per_epoch: int = 1000#
- traj_encoder_type: type[TrajEncoder]#
- traj_save_len: int = 10000000000.0#
- tstep_encoder_type: type[TstepEncoder]#
- val_interval: int | None = 20#
- val_timesteps_per_epoch: int#
- verbose: bool = True#
- wandb_entity: str = "os.environ['AMAGO_WANDB_ENTITY']"#
- wandb_group_name: str = None#
- wandb_project: str = "os.environ['AMAGO_WANDB_PROJECT']"#
- write_latest_policy()[source]#
Write the absolute latest policy to a hardcoded location used by read_latest_policy.
- Return type:
None