1from argparse import ArgumentParser
2
3import wandb
4
5from amago.envs.builtin.toy_gym import RoomKeyDoor
6from amago.envs import AMAGOEnv
7from amago import cli_utils
8
9
def add_cli(parser):
    """Register this script's command-line options on ``parser``.

    Three integer options (``--k_episodes``, ``--room_size``,
    ``--episode_length``) are added first, followed by three boolean
    switches (``--light_room_observation``, ``--randomize_actions``,
    ``--finite_horizon``) that default to ``False`` and become ``True``
    when passed.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same ``parser`` object, so calls can be chained.
    """
    # (flag, default, help text) for every integer-valued option.
    integer_options = (
        (
            "--k_episodes",
            8,
            "Number of episodes per meta-rollout. Effective sequence length = k_episodes * episode_length.",
        ),
        (
            "--room_size",
            8,
            "Size of the room. Exploration is sparse and difficulty scales quickly with room size.",
        ),
        (
            "--episode_length",
            50,
            "Maximum length of a single episode in the environment.",
        ),
    )
    for flag, default_value, description in integer_options:
        parser.add_argument(flag, type=int, default=default_value, help=description)

    # (flag, help text) for every on/off switch.
    boolean_switches = (
        (
            "--light_room_observation",
            "Demonstrate how meta-RL relies on partial observability by revealing the goal location as part of the observation. This version of the environment can be solved without memory!",
        ),
        (
            "--randomize_actions",
            "Randomize the agent's action space to make the task harder.",
        ),
        (
            "--finite_horizon",
            "Use finite-horizon mode: include time in observations and signal meta-done as terminated. Default is infinite-horizon (no time in obs, meta-done as truncated).",
        ),
    )
    for flag, description in boolean_switches:
        parser.add_argument(flag, action="store_true", help=description)

    return parser
45
46
if __name__ == "__main__":
    # Command line: the shared AMAGO options first, then this script's own.
    cli = ArgumentParser()
    cli_utils.add_common_cli(cli)
    add_cli(cli)
    args = cli.parse_args()

    # Every switch_* helper below receives this same dict; it is finally
    # handed to use_config together with the config files named on the CLI.
    cfg = {}
    tstep_encoder_settings = dict(
        arch="ff",
        n_layers=2,
        d_hidden=128,
        d_output=64,
        specify_obs_keys=["observed", "prev_action", "prev_reward"],
        hide_rl2s=True,
        normalize_inputs=False,
    )
    tstep_encoder_type = cli_utils.switch_tstep_encoder(cfg, **tstep_encoder_settings)
    traj_encoder_type = cli_utils.switch_traj_encoder(
        cfg,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
        pos_emb="rope",
    )
    agent_type = cli_utils.switch_agent(
        cfg, args.agent_type, reward_multiplier=100.0
    )

    if args.finite_horizon:
        horizon_type = "finite"
    else:
        horizon_type = "infinite"

    # Constructor arguments shared by the throwaway probe environment and by
    # every environment the experiment creates.
    shared_env_kwargs = dict(
        size=args.room_size,
        max_episode_steps=args.episode_length,
        k_episodes=args.k_episodes,
        horizon_type=horizon_type,
    )

    # Build one environment only to read its meta-rollout length, which then
    # replaces the parsed timesteps_per_epoch value.
    probe_env = RoomKeyDoor(**shared_env_kwargs)
    meta_horizon = probe_env.meta_horizon
    args.timesteps_per_epoch = meta_horizon

    # The more elaborate exploration schedule described in the appendix can
    # help when the domain is a genuine meta-RL problem, i.e. when the
    # "horizon" time limit computed above actually matters for resetting
    # the task.
    exploration_type = cli_utils.switch_exploration(
        cfg, "bilevel", steps_anneal=500_000, rollout_horizon=meta_horizon
    )
    cli_utils.use_config(cfg, args.configs)

    env_label = f"Dark-Key-To-Door-{args.room_size}x{args.room_size}-{horizon_type}"

    def make_train_env():
        """Return a fresh RoomKeyDoor environment wrapped in AMAGOEnv."""
        room = RoomKeyDoor(
            dark=not args.light_room_observation,
            randomize_actions=args.randomize_actions,
            **shared_env_kwargs,
        )
        return AMAGOEnv(env=room, env_name=env_label)

    group_name = f"{args.run_name}_dark_key_door"
    for trial in range(args.trials):
        experiment = cli_utils.create_experiment_from_cli(
            args,
            agent_type=agent_type,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            make_train_env=make_train_env,
            # Validation draws from the same environment factory as training.
            make_val_env=make_train_env,
            max_seq_len=meta_horizon,
            traj_save_len=meta_horizon * 10,
            group_name=group_name,
            run_name=f"{group_name}_trial_{trial}",
            val_timesteps_per_epoch=meta_horizon * 4,
            exploration_wrapper_type=exploration_type,
            stagger_traj_file_lengths=False,
            wandb_project="z-room-key-door",
        )
        experiment = cli_utils.switch_async_mode(experiment, args.mode)
        experiment.start()
        # A checkpoint is only loaded when one was given on the command line.
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        # The final test evaluation also uses the training env factory.
        experiment.evaluate_test(make_train_env, timesteps=20_000, render=False)
        experiment.delete_buffer_from_disk()
        # Close this trial's wandb run before the next trial begins.
        wandb.finish()