1from argparse import ArgumentParser
2import math
3import random
4
5import wandb
6
7import amago
8from amago.envs import AMAGOEnv
9from amago.envs.builtin.half_cheetah_v4_vel import HalfCheetahV4_MetaVelocity
10from amago import cli_utils
11
12
13def add_cli(parser):
14 parser.add_argument(
15 "--policy_seq_len", type=int, default=32, help="Policy sequence length."
16 )
17 parser.add_argument(
18 "--eval_episodes_per_actor",
19 type=int,
20 default=1,
21 help="Validation episodes per parallel actor.",
22 )
23 parser.add_argument(
24 "--task_min_velocity",
25 type=float,
26 default=0.0,
27 help="Min running velocity the cheetah needs to be capable of to solve the meta-learning problem. Original benchmark used 0.",
28 )
29 parser.add_argument(
30 "--task_max_velocity",
31 type=float,
32 default=3.0,
33 help="Max running velocity the cheetah needs to be capable of to solve the meta-learning problem. Original benchmark used 3. Agents in the default locomotion env (no reward randomization) reach > 10.",
34 )
35 return parser
36
37
38"""
39Because this task is so similar to the other gymnasium examples, this example script is overly
40verbose about showing how you could customize the environment and create a train/test split.
41
42If you don't edit anything, this only becomes a longer way to train/test on the default task
43distribution (which is to sample a velocity uniformly between: [args.task_min_velocity, args.task_max_velocity])
44"""
45
46
47class MyCustomHalfCheetahTrain(HalfCheetahV4_MetaVelocity):
48
49 def sample_target_velocity(self) -> float:
50 # be sure to use `random` or be careful about np default_rng to ensure
51 # tasks are different across async parallel actors!
52 vel = super().sample_target_velocity() # random.uniform(min_vel, max_vel)
53 return vel
54
55
56class MyCustomHalfCheetahEval(HalfCheetahV4_MetaVelocity):
57
58 def sample_target_velocity(self) -> float:
59 vel = super().sample_target_velocity()
60 # or, to create OOD eval tasks:
61 # vel = random.uniform(self.task_min_velocity, self.task_max_velocity * 10.0)
62 # or random.choice([0., 1., self.task_max_velocity * 1.2]), etc.
63 return vel
64
65
66class AMAGOEnvWithVelocityName(AMAGOEnv):
67 """
68 Every eval metric gets logged based on the current
69 `env_name`. You could use this to log metrics for
70 different tasks separately. They get averaged over
71 all the evals with the same name, so you want a discrete
72 number of names that will get sample sizes > 1.
73 """
74
75 @property
76 def env_name(self) -> str:
77 current_task_vel = self.env.unwrapped.target_velocity
78 # need to discretize this somehow; just one example
79 low, high = math.floor(current_task_vel), math.ceil(current_task_vel)
80 return f"HalfCheetahVelocity-Vel-[{low}, {high}]"
81
82
83if __name__ == "__main__":
84 parser = ArgumentParser()
85 cli_utils.add_common_cli(parser)
86 add_cli(parser)
87 args = parser.parse_args()
88
89 # setup environment
90 make_train_env = lambda: AMAGOEnvWithVelocityName(
91 MyCustomHalfCheetahTrain(
92 task_min_velocity=args.task_min_velocity,
93 task_max_velocity=args.task_max_velocity,
94 ),
95 # the env_name is totally arbitrary and only impacts logging / data filenames
96 env_name=f"HalfCheetahV4Velocity",
97 )
98
99 make_val_env = lambda: AMAGOEnvWithVelocityName(
100 MyCustomHalfCheetahEval(
101 task_min_velocity=args.task_min_velocity,
102 task_max_velocity=args.task_max_velocity,
103 ),
104 # this would get replaced by the env_name property
105 # defined above.
106 env_name=f"HalfCheetahV4VelocityEval",
107 )
108
109 config = {}
110 # switch sequence model
111 traj_encoder_type = cli_utils.switch_traj_encoder(
112 config,
113 arch=args.traj_encoder,
114 memory_size=args.memory_size,
115 layers=args.memory_layers,
116 )
117 # switch agent
118 agent_type = cli_utils.switch_agent(
119 config,
120 args.agent_type,
121 reward_multiplier=1.0, # gym locomotion returns are already large
122 gamma=0.99, # locomotion policies don't need long horizons - fall back to the default
123 tau=0.005,
124 )
125 # "egreedy" exploration in continuous control is just the epsilon-scheduled random (normal)
126 # noise from most TD3/DPPG implementations.
127 exploration_type = cli_utils.switch_exploration(
128 config, "egreedy", steps_anneal=500_000
129 )
130 cli_utils.use_config(config, args.configs)
131
132 group_name = args.run_name
133 for trial in range(args.trials):
134 run_name = group_name + f"_trial_{trial}"
135 experiment = cli_utils.create_experiment_from_cli(
136 args,
137 make_train_env=make_train_env, # different train/val envs
138 make_val_env=make_val_env,
139 max_seq_len=args.policy_seq_len,
140 traj_save_len=args.policy_seq_len * 6,
141 run_name=run_name,
142 tstep_encoder_type=amago.nets.tstep_encoders.FFTstepEncoder,
143 traj_encoder_type=traj_encoder_type,
144 exploration_wrapper_type=exploration_type,
145 agent_type=agent_type,
146 group_name=group_name,
147 val_timesteps_per_epoch=args.eval_episodes_per_actor * 1001,
148 grad_clip=2.0,
149 learning_rate=3e-4,
150 )
151 experiment.start()
152 if args.ckpt is not None:
153 experiment.load_checkpoint(args.ckpt)
154 experiment.learn()
155 experiment.evaluate_test(make_val_env, timesteps=10_000, render=False)
156 experiment.delete_buffer_from_disk()
157 wandb.finish()