from argparse import ArgumentParser

import gin
import wandb

from amago.envs import AMAGOEnv
from amago.envs.builtin.tmaze import TMazeAltPassive, TMazeAltActive
from amago.envs.exploration import EpsilonGreedy
from amago.cli_utils import *
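# Example: long-horizon recall on the Passive T-Maze benchmark
# (https://github.com/twni2016/Memory-RL). `--horizon` sets the corridor length;
# all other command-line flags come from `add_common_cli`.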


def add_cli(parser):
    parser.add_argument("--horizon", type=int, required=True)
    return parser


@gin.configurable
class TMazeExploration(EpsilonGreedy):
    """
    The T-Maze environment is meant to evaluate recall over long context lengths
    without testing exploration, but it does so by requiring `horizon - 1` deterministic
    actions to create a gap between the timestep that reveals the correct final action
    and the timestep where it must be taken. This unintentionally creates a worst-case
    scenario for epsilon-greedy exploration. This schedule answers the central memory
    question while fixing the sample-efficiency problems naive epsilon-greedy creates.

    https://github.com/twni2016/Memory-RL/issues/1
    """

    def __init__(
        self,
        env,
        start_window=0,
        end_window=3,
        horizon: int = gin.REQUIRED,
        eps_start=1.0,
        eps_end=0.01,
        steps_anneal=100_000,
    ):
        self.start_window = start_window
        self.end_window = end_window
        self.horizon = horizon
        super().__init__(
            env,
            eps_start=eps_start,
            eps_end=eps_end,
            steps_anneal=steps_anneal,
        )

    def current_eps(self, local_step: int):
        current = super().current_eps(local_step)
        if (
            local_step > self.start_window
            and local_step < self.horizon - self.end_window
        ):
            # low exploration during the easy corridor section; regular during the
            # interesting early and late timesteps.
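            # Why 0.5 / horizon: a sketch of the arithmetic, assuming the corridor is
            # roughly `horizon` steps long. The expected number of random actions per
            # traversal is then ~0.5, so about (1 - 0.5 / H)^H ~ e^-0.5 ~ 60% of
            # episodes cross the corridor without a random detour no matter how long
            # it is, whereas a fixed eps succeeds with probability (1 - eps)^H, which
            # vanishes as H grows.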
            current[:] = 0.5 / self.horizon
        return current


if __name__ == "__main__":
    parser = ArgumentParser()
    add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

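    # `TMazeExploration.horizon` is declared gin.REQUIRED above, so it must be
    # supplied here as a gin binding rather than as a constructor argument.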
    config = {
        "TMazeExploration.horizon": args.horizon,
    }
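    # The trajectory encoder selected by `args.traj_encoder` is the sequence model
    # that serves as the agent's memory; recalling the early clue across the
    # corridor is entirely its job.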
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
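    # T-Maze observations are tiny, so a small feedforward timestep encoder is enough.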
    tstep_encoder_type = switch_tstep_encoder(
        config,
        arch="ff",
        n_layers=2,
        d_hidden=128,
        d_output=128,
        normalize_inputs=False,
    )
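    # The informative reward only arrives at the end of the episode, so it is scaled
    # up (reward_multiplier=100) and barely discounted (gamma=0.9999) so the terminal
    # signal is not washed out over long horizons.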
    agent_type = switch_agent(
        config, args.agent_type, reward_multiplier=100.0, gamma=0.9999
    )
    use_config(config, args.configs)

    group_name = f"{args.run_name}_TMazePassive_H{args.horizon}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        make_env = lambda: AMAGOEnv(
            env=TMazeAltPassive(
                corridor_length=args.horizon, penalty=-1.0 / args.horizon
            ),
            env_name=f"TMazePassive-H{args.horizon}",
        )
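        # The action revealed at the start of the episode must stay in context until
        # the final step, so max_seq_len covers the full episode of horizon + 1
        # timesteps (and trajectories are saved at that same length).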
        experiment = create_experiment_from_cli(
            args,
            make_train_env=make_env,
            make_val_env=make_env,
            max_seq_len=args.horizon + 1,
            traj_save_len=args.horizon + 1,
            group_name=group_name,
            run_name=run_name,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            agent_type=agent_type,
            val_timesteps_per_epoch=args.horizon + 1,
            sample_actions=False,  # even softmax prob .999 isn't good enough for this env...
            exploration_wrapper_type=TMazeExploration,
        )
        switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_env, timesteps=args.horizon * 5, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()