1from argparse import ArgumentParser
2import math
3import random
4
5import wandb
6
7import amago
8from amago.envs import AMAGOEnv
9from amago.envs.builtin.half_cheetah_v4_vel import HalfCheetahV4_MetaVelocity
10from amago import cli_utils
11
12
def add_cli(parser):
    """Attach this script's task-specific flags to *parser* and return it.

    Flags added: --policy_seq_len, --eval_episodes_per_actor,
    --task_min_velocity, --task_max_velocity.
    """
    # (flag, type, default, help) — help strings kept verbatim for --help output.
    arg_specs = [
        ("--policy_seq_len", int, 32, "Policy sequence length."),
        (
            "--eval_episodes_per_actor",
            int,
            1,
            "Validation episodes per parallel actor.",
        ),
        (
            "--task_min_velocity",
            float,
            0.0,
            "Min running velocity the cheetah needs to be capable of to solve the meta-learning problem. Original benchmark used 0.",
        ),
        (
            "--task_max_velocity",
            float,
            3.0,
            "Max running velocity the cheetah needs to be capable of to solve the meta-learning problem. Original benchmark used 3. Agents in the default locomotion env (no reward randomization) reach > 10.",
        ),
    ]
    for flag, flag_type, default, help_text in arg_specs:
        parser.add_argument(flag, type=flag_type, default=default, help=help_text)
    return parser
36
37
"""
Because this task is so similar to the other gymnasium examples, this example script is
deliberately verbose about showing how you could customize the environment and create a
train/test split.

If you don't edit anything, this is just a longer way to train/test on the default task
distribution (which samples a velocity uniformly from [args.task_min_velocity, args.task_max_velocity]).
"""
45
46
class MyCustomHalfCheetahTrain(HalfCheetahV4_MetaVelocity):
    """Training-task distribution; defaults to the parent's uniform sampling."""

    def sample_target_velocity(self) -> float:
        # Stick with `random` (or seed np default_rng per process) so async
        # parallel actors draw different tasks from each other.
        target = super().sample_target_velocity()  # random.uniform(min_vel, max_vel)
        return target
53
54
class MyCustomHalfCheetahEval(HalfCheetahV4_MetaVelocity):
    """Eval-task distribution; edit the sampler below to create a test split."""

    def sample_target_velocity(self) -> float:
        target = super().sample_target_velocity()
        # To create OOD eval tasks instead, try e.g.:
        #   random.uniform(self.task_min_velocity, self.task_max_velocity * 10.0)
        #   random.choice([0., 1., self.task_max_velocity * 1.2])
        return target
62
63
class AMAGOEnvWithVelocityName(AMAGOEnv):
    """
    Every eval metric gets logged based on the current
    `env_name`. You could use this to log metrics for
    different tasks separately. They get averaged over
    all the evals with the same name, so you want a discrete
    number of names that will get sample sizes > 1.
    """

    @property
    def env_name(self) -> str:
        # The continuous target velocity must be discretized somehow to keep
        # the set of names finite; unit-wide bins are just one example.
        target_vel = self.env.unwrapped.target_velocity
        low = math.floor(target_vel)
        high = math.ceil(target_vel)
        return f"HalfCheetahVelocity-Vel-[{low}, {high}]"
79
80
if __name__ == "__main__":
    # Combine the shared AMAGO CLI with this script's task-specific flags.
    parser = ArgumentParser()
    cli_utils.add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # Setup environments. Train and eval wrap different subclasses so you can
    # create a train/test task split by editing the classes above.
    make_train_env = lambda: AMAGOEnvWithVelocityName(
        MyCustomHalfCheetahTrain(
            task_min_velocity=args.task_min_velocity,
            task_max_velocity=args.task_max_velocity,
        ),
        # the env_name is totally arbitrary and only impacts logging / data filenames
        # (no placeholders, so a plain string literal — not an f-string — is correct)
        env_name="HalfCheetahV4Velocity",
    )

    make_val_env = lambda: AMAGOEnvWithVelocityName(
        MyCustomHalfCheetahEval(
            task_min_velocity=args.task_min_velocity,
            task_max_velocity=args.task_max_velocity,
        ),
        # this would get replaced by the env_name property
        # defined above.
        env_name="HalfCheetahV4VelocityEval",
    )

    # Build the gin-style config dict that the cli_utils switches populate.
    config = {}
    # switch sequence model
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    # switch agent
    agent_type = cli_utils.switch_agent(
        config,
        args.agent_type,
        reward_multiplier=1.0,  # gym locomotion returns are already large
        gamma=0.99,  # locomotion policies don't need long horizons - fall back to the default
        tau=0.005,
    )
    # "egreedy" exploration in continuous control is just the epsilon-scheduled random (normal)
    # noise from most TD3/DPPG implementations.
    exploration_type = cli_utils.switch_exploration(
        config, "egreedy", steps_anneal=500_000
    )
    cli_utils.use_config(config, args.configs)

    # Run `--trials` independent repetitions under one wandb group name.
    group_name = args.run_name
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = cli_utils.create_experiment_from_cli(
            args,
            make_train_env=make_train_env,  # different train/val envs
            make_val_env=make_val_env,
            max_seq_len=args.policy_seq_len,
            traj_save_len=args.policy_seq_len * 6,
            run_name=run_name,
            tstep_encoder_type=amago.nets.tstep_encoders.FFTstepEncoder,
            traj_encoder_type=traj_encoder_type,
            exploration_wrapper_type=exploration_type,
            agent_type=agent_type,
            group_name=group_name,
            # NOTE(review): 1001 presumably covers one full HalfCheetah episode
            # (1000 steps) plus reset — confirm against the env's time limit.
            val_timesteps_per_epoch=args.eval_episodes_per_actor * 1001,
            grad_clip=2.0,
            learning_rate=3e-4,
        )
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_val_env, timesteps=10_000, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()