1from argparse import ArgumentParser
2
3import amago
4from amago.envs.builtin.toy_gym import MetaFrozenLake
5from amago.envs import AMAGOEnv
6from amago.loading import DiskTrajDataset
7from amago import cli_utils
8
9
def add_cli(parser):
    """Register this example's command-line options on ``parser``.

    Options are added in a fixed order, so ``--help`` lists them in that order:
    the sequence-model architecture, run/buffer identifiers, a wandb logging
    switch, the trial count, MetaFrozenLake settings, and the rollout and
    sequence length limits.

    Args:
        parser: an ``argparse.ArgumentParser`` that is mutated in place.

    Returns:
        The same ``parser`` object, so the call can be chained.
    """
    # (flag, add_argument keyword arguments) -- list order == registration order.
    option_specs = [
        (
            "--seq_model",
            dict(
                type=str,
                choices=["ff", "transformer", "rnn", "mamba"],
                required=True,
            ),
        ),
        ("--run_name", dict(type=str, required=True)),
        ("--buffer_dir", dict(type=str, required=True)),
        ("--log", dict(action="store_true")),
        ("--trials", dict(type=int, default=1)),
        ("--lake_size", dict(type=int, default=5)),
        ("--k_episodes", dict(type=int, default=15)),
        ("--hard_mode", dict(action="store_true")),
        ("--recover_mode", dict(action="store_true")),
        ("--max_rollout_length", dict(type=int, default=512)),
        ("--max_seq_len", dict(type=int, default=512)),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser
28
29
if __name__ == "__main__":
    # Entry point: parse the CLI, configure the AMAGO model components, then
    # run `args.trials` independent train -> test -> cleanup cycles.
    parser = ArgumentParser()
    add_cli(parser)
    args = parser.parse_args()

    if args.log:
        # Import wandb only when logging is requested. Every later use of
        # `wandb` must therefore be guarded by `args.log` as well.
        import wandb

    config = {}
    # configure trajectory encoder (seq2seq memory model)
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.seq_model,
        memory_size=128,
        layers=3,
    )
    # configure timestep encoder
    tstep_encoder_type = cli_utils.switch_tstep_encoder(
        config, arch="ff", n_layers=1, d_hidden=128, d_output=64, normalize_inputs=False
    )

    # we're using the default exploration strategy but being overly verbose about it for the example
    exploration_wrapper_type = cli_utils.switch_exploration(
        config,
        strategy="egreedy",
        eps_start=1.0,
        eps_end=0.05,
        steps_anneal=1_000_000,
        randomize_eps=True,
    )
    cli_utils.use_config(config)

    # The environment depends only on CLI args (not on the trial index), so its
    # name and factory are built once rather than once per trial.
    env_name = (
        f"meta_frozen_lake_k{args.k_episodes}_{args.lake_size}x{args.lake_size}"
        + ("_hard" if args.hard_mode else "_easy")
        + ("_recover" if args.recover_mode else "_reset")
    )

    def make_env():
        """Build one MetaFrozenLake instance wrapped in `AMAGOEnv`."""
        return AMAGOEnv(
            MetaFrozenLake(
                k_episodes=args.k_episodes,
                size=args.lake_size,
                hard_mode=args.hard_mode,
                recover_mode=args.recover_mode,
            ),
            env_name=env_name,
        )

    group_name = f"{args.run_name}_{args.seq_model}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"

        # create a dataset on disk. envs will write finished episodes here
        dset = DiskTrajDataset(
            dset_root=args.buffer_dir, dset_name=run_name, dset_max_size=12_500
        )
        # save checkpoints alongside the buffer
        ckpt_dir = args.buffer_dir

        # create `Experiment`
        experiment = amago.Experiment(
            make_train_env=make_env,
            make_val_env=make_env,
            max_seq_len=args.max_seq_len,
            traj_save_len=args.max_rollout_length,
            dataset=dset,
            ckpt_base_dir=ckpt_dir,
            agent_type=amago.agent.Agent,
            exploration_wrapper_type=exploration_wrapper_type,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            run_name=run_name,
            dloader_workers=10,
            log_to_wandb=args.log,
            wandb_group_name=group_name,
            # hard mode gets a longer training budget
            epochs=700 if not args.hard_mode else 900,
            parallel_actors=24,
            train_timesteps_per_epoch=350,
            train_batches_per_epoch=800,
            val_interval=20,
            val_timesteps_per_epoch=args.max_rollout_length * 2,
            ckpt_interval=50,
            env_mode="sync",
        )

        # start experiment (build envs, policies, etc.)
        experiment.start()
        # run training
        experiment.learn()
        experiment.evaluate_test(make_env, timesteps=10_000)
        experiment.delete_buffer_from_disk()
        if args.log:
            # `wandb` is only bound when --log was passed. Calling finish()
            # unconditionally raised NameError after the first un-logged trial.
            wandb.finish()