from argparse import ArgumentParser

import wandb

from amago.envs.builtin.toy_gym import RoomKeyDoor
from amago.envs import AMAGOEnv
from amago import cli_utils


def add_cli(parser):
    parser.add_argument(
        "--meta_horizon",
        type=int,
        default=500,
        help="Total meta-adaptation timestep budget for the agent to explore the same room layout.",
    )
    parser.add_argument(
        "--room_size",
        type=int,
        default=8,
        help="Side length of the (square) room. Exploration is sparse and difficulty scales quickly with room size.",
    )
    parser.add_argument(
        "--episode_length",
        type=int,
        default=50,
        help="Maximum length of a single episode in the environment.",
    )
    parser.add_argument(
        "--light_room_observation",
        action="store_true",
        help="Reveal the goal location as part of the observation. This version of the environment can be solved without memory, which demonstrates how meta-RL relies on partial observability.",
    )
    parser.add_argument(
        "--randomize_actions",
        action="store_true",
        help="Randomize the agent's action space to make the task harder.",
    )
    return parser


if __name__ == "__main__":
    parser = ArgumentParser()
    cli_utils.add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    config = {}
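    # each `switch_*` helper below fills `config` with overrides that
    # `cli_utils.use_config` applies once they have all been collected.
    # the timestep encoder embeds each raw observation before it reaches the
    # sequence model; a small feedforward net is plenty for this toy task.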
    tstep_encoder_type = cli_utils.switch_tstep_encoder(
        config, arch="ff", n_layers=2, d_hidden=128, d_output=64
    )
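    # the trajectory encoder is the agent's memory: the sequence model that
    # processes timestep embeddings across the meta-horizon (architecture and
    # size come from the common CLI flags).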
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
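    # reward_multiplier rescales the env's sparse rewards (here by 100x),
    # which can help value learning at this scale of returns.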
    agent_type = cli_utils.switch_agent(
        config, args.agent_type, reward_multiplier=100.0
    )
    # the fancier "bilevel" exploration schedule mentioned in the appendix can
    # help when the domain is a true meta-RL problem and the meta-horizon time
    # limit (above) genuinely marks a task reset.
    exploration_type = cli_utils.switch_exploration(
        config, "bilevel", steps_anneal=500_000, rollout_horizon=args.meta_horizon
    )
    cli_utils.use_config(config, args.configs)

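    # run `--trials` independent training runs, grouped together on the
    # wandb dashboard under a shared group name.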
    group_name = f"{args.run_name}_dark_key_door"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
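        # AMAGOEnv wraps the underlying gym env in AMAGO's interface.
        # `dark=True` hides the goal location, so the agent must remember what
        # it discovers across episodes within the same meta-rollout.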
        make_train_env = lambda: AMAGOEnv(
            env=RoomKeyDoor(
                size=args.room_size,
                max_episode_steps=args.episode_length,
                meta_rollout_horizon=args.meta_horizon,
                dark=not args.light_room_observation,
                randomize_actions=args.randomize_actions,
            ),
            env_name=f"Dark-Key-To-Door-{args.room_size}x{args.room_size}",
        )
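        # train on full meta-rollouts: the max sequence length and the length
        # of saved trajectories both match the adaptation budget, so memory
        # can span every episode in a rollout. Validation collects roughly
        # four meta-rollouts per epoch.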
        experiment = cli_utils.create_experiment_from_cli(
            args,
            agent_type=agent_type,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            make_train_env=make_train_env,
            make_val_env=make_train_env,
            max_seq_len=args.meta_horizon,
            traj_save_len=args.meta_horizon,
            group_name=group_name,
            run_name=run_name,
            val_timesteps_per_epoch=args.meta_horizon * 4,
            exploration_wrapper_type=exploration_type,
        )
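        # pick sync/async collection based on `--mode`, then initialize
        # (optionally resuming from a checkpoint) and train.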
        experiment = cli_utils.switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
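        # final evaluation on fresh rollouts, then clean up the replay buffer
        # and close the wandb run before the next trial.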
        experiment.evaluate_test(make_train_env, timesteps=20_000, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()