from argparse import ArgumentParser

import amago
from amago.envs.builtin.toy_gym import MetaFrozenLake
from amago.envs import AMAGOEnv
from amago.loading import DiskTrajDataset
from amago.cli_utils import *

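# Example invocation (adjust the script filename to match wherever you saved this file):
#   python meta_frozen_lake.py --seq_model transformer --run_name demo --buffer_dir ./buffers --log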
def add_cli(parser):
    parser.add_argument(
        "--seq_model",
        type=str,
        choices=["ff", "transformer", "rnn", "mamba"],
        required=True,
    )
    parser.add_argument("--run_name", type=str, required=True)
    parser.add_argument("--buffer_dir", type=str, required=True)
    parser.add_argument("--log", action="store_true")
    parser.add_argument("--trials", type=int, default=1)
    parser.add_argument("--lake_size", type=int, default=5)
    parser.add_argument("--k_episodes", type=int, default=15)
    parser.add_argument("--hard_mode", action="store_true")
    parser.add_argument("--recover_mode", action="store_true")
    parser.add_argument("--max_rollout_length", type=int, default=512)
    parser.add_argument("--max_seq_len", type=int, default=512)
    return parser

if __name__ == "__main__":
    parser = ArgumentParser()
    add_cli(parser)
    args = parser.parse_args()

    if args.log:
        import wandb

    config = {}
    # configure trajectory encoder (seq2seq memory model)
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.seq_model,
        memory_size=128,
        layers=3,
    )
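    # note: the switch_* helpers from amago.cli_utils share a pattern: they record config
    # overrides for the chosen architecture in the `config` dict (applied later by use_config)
    # and return the class that gets passed to the Experiment below.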
    # configure timestep encoder
    tstep_encoder_type = switch_tstep_encoder(
        config, arch="ff", n_layers=1, d_hidden=128, d_output=64, normalize_inputs=False
    )
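    # the timestep encoder embeds each individual timestep (the observation plus extras like
    # the previous action/reward) before the sequence model above processes the full context;
    # "ff" is a small feedforward/MLP encoder.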

    # we use the default exploration strategy (epsilon-greedy), but spell out its
    # arguments explicitly for the sake of the example
    exploration_wrapper_type = switch_exploration(
        config,
        strategy="egreedy",
        eps_start=1.0,
        eps_end=0.05,
        steps_anneal=1_000_000,
        randomize_eps=True,
    )
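    # as configured, epsilon should anneal from eps_start to eps_end over steps_anneal
    # environment steps, and randomize_eps varies the schedule across parallel actors so
    # exploration is not identical everywhere (see amago's exploration wrappers for details).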
    # apply the accumulated config overrides before building anything
    use_config(config)

    group_name = f"{args.run_name}_{args.seq_model}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"

        # create a dataset on disk. envs will write finished episodes here
        dset = DiskTrajDataset(
            dset_root=args.buffer_dir, dset_name=run_name, dset_max_size=12_500
        )
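        # completed trajectories are written as files in a subdirectory of buffer_dir named
        # after this run; when the buffer grows past dset_max_size trajectories, old files
        # are deleted to keep the on-disk replay buffer bounded.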
        # save checkpoints alongside the buffer
        ckpt_dir = args.buffer_dir

        # wrap environment
        make_env = lambda: AMAGOEnv(
            MetaFrozenLake(
                k_episodes=args.k_episodes,
                size=args.lake_size,
                hard_mode=args.hard_mode,
                recover_mode=args.recover_mode,
            ),
            env_name=f"meta_frozen_lake_k{args.k_episodes}_{args.lake_size}x{args.lake_size}"
            + ("_hard" if args.hard_mode else "_easy")
            + ("_recover" if args.recover_mode else "_reset"),
        )
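        # AMAGOEnv adapts the underlying gymnasium environment to amago's interface, and
        # env_name labels the rollout/evaluation metrics logged for this environment variant.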

        # create `Experiment`
        experiment = amago.Experiment(
            make_train_env=make_env,
            make_val_env=make_env,
            max_seq_len=args.max_seq_len,
            traj_save_len=args.max_rollout_length,
            dataset=dset,
            ckpt_base_dir=ckpt_dir,
            agent_type=amago.agent.Agent,
            exploration_wrapper_type=exploration_wrapper_type,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            run_name=run_name,
            dloader_workers=10,
            log_to_wandb=args.log,
            wandb_group_name=group_name,
            epochs=700 if not args.hard_mode else 900,
            parallel_actors=24,
            train_timesteps_per_epoch=350,
            train_batches_per_epoch=800,
            val_interval=20,
            val_timesteps_per_epoch=args.max_rollout_length * 2,
            ckpt_interval=50,
            env_mode="sync",
        )
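        # a few of the less obvious settings: max_seq_len caps the context length the policy is
        # trained on, while traj_save_len caps the length of each trajectory file written to the
        # buffer. each epoch alternates environment collection (train_timesteps_per_epoch) with
        # train_batches_per_epoch gradient updates; evaluation runs every val_interval epochs and
        # checkpoints are saved every ckpt_interval epochs.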

        # start experiment (build envs, policies, etc.)
        experiment.start()
        # run training
        experiment.learn()
        # final evaluation rollouts on freshly created envs
        experiment.evaluate_test(make_env, timesteps=10_000)
        # clean up this trial's replay buffer files
        experiment.delete_buffer_from_disk()
        if args.log:
            wandb.finish()