1from argparse import ArgumentParser
2
3import amago
4from amago.envs.builtin.toy_gym import MetaFrozenLake
5from amago.envs import AMAGOEnv
6from amago.loading import DiskTrajDataset
7from amago import cli_utils
8
9
def add_cli(parser):
    """Register this experiment's command-line flags on *parser*.

    Flags are declared as (name, options) pairs and attached in order so the
    generated --help output matches the declaration order below. Returns the
    same parser to allow chaining.
    """
    flag_specs = [
        (
            "--seq_model",
            dict(type=str, choices=["ff", "transformer", "rnn", "mamba"], required=True),
        ),
        ("--run_name", dict(type=str, required=True)),
        ("--buffer_dir", dict(type=str, required=True)),
        ("--log", dict(action="store_true")),
        ("--trials", dict(type=int, default=1)),
        ("--lake_size", dict(type=int, default=5)),
        ("--k_episodes", dict(type=int, default=10)),
        ("--hard_mode", dict(action="store_true")),
        ("--recover_mode", dict(action="store_true")),
        ("--slip_chance", dict(type=float, default=0.0)),
        (
            "--max_episode_steps",
            dict(
                type=int,
                default=None,
                help="Max steps per attempt. Default: N² (standard) or 2*N² (hard).",
            ),
        ),
        (
            "--hide_k_progress",
            dict(
                action="store_true",
                help="Hide current_k/k_episodes from observations (for length extrapolation tests).",
            ),
        ),
        (
            "--max_seq_len",
            dict(
                type=int,
                default=None,
                help="Training sequence length. Default: max_episode_steps * k_episodes (full trajectory).",
            ),
        ),
    ]
    for flag, options in flag_specs:
        parser.add_argument(flag, **options)
    return parser
44
45
if __name__ == "__main__":
    parser = ArgumentParser()
    add_cli(parser)
    args = parser.parse_args()

    if args.log:
        # Deferred import: wandb is only required when logging is requested.
        import wandb

    # Environment kwargs shared by every env instance in every trial.
    lake_kwargs = dict(
        size=args.lake_size,
        k_episodes=args.k_episodes,
        hard_mode=args.hard_mode,
        recover_mode=args.recover_mode,
        max_episode_steps=args.max_episode_steps,
        show_k_progress=not args.hide_k_progress,
        slip_chance=args.slip_chance,
    )
    # Build a throwaway env once to resolve the (possibly defaulted)
    # per-attempt episode length; a rollout spans k back-to-back episodes.
    max_ep_steps = MetaFrozenLake(**lake_kwargs).max_episode_steps
    max_rollout_length = max_ep_steps * args.k_episodes
    # Default: train on the full k-episode trajectory.
    max_seq_len = args.max_seq_len or max_rollout_length

    config = {}
    # configure trajectory encoder (seq2seq memory model)
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.seq_model,
        memory_size=128,
        layers=3,
    )
    # configure timestep encoder
    tstep_encoder_type = cli_utils.switch_tstep_encoder(
        config, arch="ff", n_layers=1, d_hidden=128, d_output=64, normalize_inputs=False
    )
    # we're using the default exploration strategy but being overly verbose about it for the example
    exploration_wrapper_type = cli_utils.switch_exploration(
        config,
        strategy="egreedy",
        eps_start=1.0,
        eps_end=0.05,
        steps_anneal=1_000_000,
        randomize_eps=True,
    )
    agent_type = cli_utils.switch_agent(config, "agent", tau=0.004)
    cli_utils.use_config(config)

    group_name = f"{args.run_name}_{args.seq_model}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"

        # create a dataset on disk. envs will write finished episodes here
        dset = DiskTrajDataset(
            dset_root=args.buffer_dir, dset_name=run_name, dset_max_size=12_500
        )
        # save checkpoints alongside the buffer
        ckpt_dir = args.buffer_dir

        def make_env():
            # Construct a fresh MetaFrozenLake and wrap it for AMAGO; the
            # env_name encodes the task settings for logging/eval bookkeeping.
            env_name = (
                f"meta_frozen_lake_k{args.k_episodes}_{args.lake_size}x{args.lake_size}"
                + ("_hard" if args.hard_mode else "_easy")
                + ("_recover" if args.recover_mode else "_reset")
            )
            return AMAGOEnv(MetaFrozenLake(**lake_kwargs), env_name=env_name)

        experiment = amago.Experiment(
            make_train_env=make_env,
            make_val_env=make_env,
            max_seq_len=max_seq_len,
            traj_save_len=max_rollout_length,
            dataset=dset,
            ckpt_base_dir=ckpt_dir,
            agent_type=agent_type,
            exploration_wrapper_type=exploration_wrapper_type,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            run_name=run_name,
            dloader_workers=10,
            log_to_wandb=args.log,
            wandb_group_name=group_name,
            # hard mode gets extra training epochs
            epochs=900 if args.hard_mode else 700,
            parallel_actors=32,
            # one full rollout of env interaction per epoch
            train_timesteps_per_epoch=max_rollout_length,
            train_batches_per_epoch=1000,
            val_interval=20,
            val_timesteps_per_epoch=max_rollout_length * 2,
            ckpt_interval=200,
            env_mode="sync",
        )

        experiment.start()
        experiment.learn()
        experiment.evaluate_test(make_env, timesteps=10_000)
        # free the trial's on-disk replay buffer before the next trial
        experiment.delete_buffer_from_disk()
        if args.log:
            wandb.finish()