Create a Dataset
AMAGO trains on sequence data loaded from an RLDataset, which inherits from the PyTorch Dataset class.
Standard online RL can just use DiskTrajDataset,
which tells the envs where to save sequences and deletes the oldest data when full (like a normal replay buffer).
from amago.loading import DiskTrajDataset
dataset = DiskTrajDataset(
dset_root="plenty_of_space",
dset_name="give_this_replay_buffer_a_name",
dset_max_size=10_000, # measured in *sequences*
)
# creates a directory structure like:
# dset_root/
#     dset_name/
#         buffer/
#             protected/
#                 # optional place to move data you want to sample from but never delete
#             fifo/
#                 # envs write files here and dset deletes them when full
experiment = amago.Experiment(
...,
dataset=dataset,
# optional control over the way all datasets sample from seqs longer than the policy's max input length
padded_sampling="none",
# optional control over the way envs write to the dataset:
traj_save_len=1000, # write sequences after this many timesteps even if the episode hasn't finished
)