1from argparse import ArgumentParser
2
3import wandb
4
5import amago
6from amago.envs import AMAGOEnv
7from amago.envs.builtin.half_cheetah_v4_vel import HalfCheetahV4_MetaVelocity
8from amago import cli_utils
9
10
def add_cli(parser):
    """Register this example's command-line options on ``parser``.

    Options are added in a fixed order: ``--policy_seq_len``,
    ``--eval_episodes_per_actor``, ``--task_min_velocity``,
    ``--task_max_velocity``, ``--inner_episode_steps``, ``--k_train_episodes``
    and ``--k_eval_episodes``.

    Args:
        parser: An ``argparse.ArgumentParser`` (or any object exposing a
            compatible ``add_argument`` method).

    Returns:
        The same ``parser`` object that was passed in.
    """
    # One row per option: (flag, type, default, help text).
    option_table = (
        ("--policy_seq_len", int, 32, "Policy sequence length."),
        (
            "--eval_episodes_per_actor",
            int,
            1,
            "Validation episodes per parallel actor.",
        ),
        (
            "--task_min_velocity",
            float,
            0.0,
            "Min running velocity the cheetah needs to be capable of to solve the meta-learning problem. Original benchmark used 0.",
        ),
        (
            "--task_max_velocity",
            float,
            3.0,
            "Max running velocity the cheetah needs to be capable of to solve the meta-learning problem. Original benchmark used 3. Agents in the default locomotion env (no reward randomization) reach > 10.",
        ),
        (
            "--inner_episode_steps",
            int,
            200,
            "Step horizon of each inner episode. Default 200 (combined with the default --k_train_episodes=3) keeps total trial length at 600 steps. Set 1000 with --k_train_episodes=1 to recover the unwrapped task.",
        ),
        (
            "--k_train_episodes",
            int,
            3,
            "Inner episodes per meta-trial during training. Default 3 makes the env a true meta-RL trial: a single hidden target velocity persists across 3 inner episodes (soft resets between). Set 1 to recover the unwrapped task.",
        ),
        (
            # A default of None lets the caller detect "not given".
            "--k_eval_episodes",
            int,
            None,
            "Inner episodes per meta-trial at eval time. Defaults to --k_train_episodes. Larger values (e.g. 10) probe how well the agent keeps adapting beyond its training horizon.",
        ),
    )
    for flag, value_type, default_value, description in option_table:
        parser.add_argument(
            flag, type=value_type, default=default_value, help=description
        )
    return parser
52
53
54"""
Because this task is so similar to the other gymnasium examples, this example script is
more verbose than strictly necessary, in order to show how you could customize the
environment and create a train/test split.

If you don't edit anything, this is just a longer way to train/test on the default task
distribution, which samples a velocity uniformly from [args.task_min_velocity, args.task_max_velocity].
60"""
61
62
class MyCustomHalfCheetahTrain(HalfCheetahV4_MetaVelocity):
    """Training-time variant of the meta-velocity HalfCheetah task.

    This subclass is a customization hook: as written it defers entirely to
    the parent class, so the training task distribution is unchanged.
    """

    def sample_target_velocity(self) -> float:
        """Return the target velocity for a task.

        Currently returns whatever the parent implementation samples. Replace
        the body to define a custom training distribution.
        """
        # When customizing: be sure to use `random`, or be careful about
        # np default_rng, so tasks differ across async parallel actors!
        # Per the original author's note, the parent draws
        # random.uniform(min_vel, max_vel).
        return super().sample_target_velocity()
69
70
class MyCustomHalfCheetahEval(HalfCheetahV4_MetaVelocity):
    """Evaluation-time variant of the meta-velocity HalfCheetah task.

    This subclass is a customization hook: as written it defers entirely to
    the parent class, so evaluation uses the same task distribution as the
    parent.
    """

    def sample_target_velocity(self) -> float:
        """Return the target velocity for an evaluation task.

        Currently returns whatever the parent implementation samples. Replace
        the body to evaluate on a different (e.g. out-of-distribution) set of
        velocities.
        """
        # To create OOD eval tasks, return something like one of these instead
        # (NOTE: `random` is not imported in the visible part of this file):
        #   random.uniform(self.task_min_velocity, self.task_max_velocity * 10.0)
        #   random.choice([0., 1., self.task_max_velocity * 1.2]), etc.
        return super().sample_target_velocity()
78
79
if __name__ == "__main__":
    # Build the CLI: the shared AMAGO options first, then this example's flags.
    parser = ArgumentParser()
    cli_utils.add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # Eval trials reuse the training episode count unless --k_eval_episodes
    # was given explicitly.
    if args.k_eval_episodes is None:
        k_eval = args.k_train_episodes
    else:
        k_eval = args.k_eval_episodes

    def _wrap_cheetah(env_cls, k_episodes):
        """Instantiate ``env_cls`` from the CLI settings and wrap it in ``AMAGOEnv``."""
        cheetah = env_cls(
            task_min_velocity=args.task_min_velocity,
            task_max_velocity=args.task_max_velocity,
            max_episode_steps=args.inner_episode_steps,
            k_episodes=k_episodes,
        )
        return AMAGOEnv(cheetah, env_name="HalfCheetahV4Velocity")

    def make_train_env():
        """Create one training environment (uses --k_train_episodes)."""
        return _wrap_cheetah(MyCustomHalfCheetahTrain, args.k_train_episodes)

    def make_val_env():
        """Create one validation/test environment (uses the resolved eval episode count)."""
        return _wrap_cheetah(MyCustomHalfCheetahEval, k_eval)

    # Base config; the switch_* helpers below receive this same dict before it
    # is handed to use_config, so their order is kept as-is.
    config = {"amago.nets.traj_encoders.TformerTrajEncoder.pos_emb": "rope"}

    # switch sequence model
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )

    # switch agent
    agent_type = cli_utils.switch_agent(
        config,
        args.agent_type,
        reward_multiplier=1.0,  # gym locomotion returns are already large
        gamma=0.99,  # locomotion policies don't need long horizons - fall back to the default
        tau=0.005,
    )

    # "egreedy" exploration in continuous control is just the epsilon-scheduled
    # random (normal) noise from most TD3/DDPG implementations.
    exploration_type = cli_utils.switch_exploration(
        config, "egreedy", steps_anneal=500_000
    )
    cli_utils.use_config(config, args.configs)

    group_name = args.run_name
    for trial in range(args.trials):
        run_name = group_name + "_trial_" + str(trial)

        # One extra step on top of a full trial (k_eval inner episodes), per
        # eval episode requested for each actor.
        steps_per_val_trial = args.inner_episode_steps * k_eval + 1
        val_timesteps = args.eval_episodes_per_actor * steps_per_val_trial
        saved_traj_len = args.policy_seq_len * 6

        experiment = cli_utils.create_experiment_from_cli(
            args,
            make_train_env=make_train_env,  # different train/val envs
            make_val_env=make_val_env,
            max_seq_len=args.policy_seq_len,
            traj_save_len=saved_traj_len,
            run_name=run_name,
            tstep_encoder_type=amago.nets.tstep_encoders.FFTstepEncoder,
            traj_encoder_type=traj_encoder_type,
            exploration_wrapper_type=exploration_type,
            agent_type=agent_type,
            group_name=group_name,
            val_timesteps_per_epoch=val_timesteps,
            grad_clip=2.0,
            learning_rate=3e-4,
        )
        experiment.start()
        # Optionally resume from a checkpoint before training.
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_val_env, timesteps=10_000, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()