from argparse import ArgumentParser
import math
import random

import wandb

import amago
from amago.envs import AMAGOEnv
from amago.envs.builtin.half_cheetah_v4_vel import HalfCheetahV4_MetaVelocity
from amago.cli_utils import *


def add_cli(parser):
    parser.add_argument(
        "--policy_seq_len", type=int, default=32, help="Policy sequence length."
    )
    parser.add_argument(
        "--eval_episodes_per_actor",
        type=int,
        default=1,
        help="Validation episodes per parallel actor.",
    )
    parser.add_argument(
        "--task_min_velocity",
        type=float,
        default=0.0,
        help="Minimum running velocity the cheetah needs to be capable of to solve the meta-learning problem. The original benchmark used 0.",
    )
    parser.add_argument(
        "--task_max_velocity",
        type=float,
        default=3.0,
        help="Maximum running velocity the cheetah needs to be capable of to solve the meta-learning problem. The original benchmark used 3. Agents in the default locomotion env (no reward randomization) reach > 10.",
    )
    return parser


"""
Because this task is so similar to the other gymnasium examples, this script is deliberately
verbose about showing how you could customize the environment and create a train/test split.

If you don't edit anything, it is just a longer way to train/test on the default task
distribution, which samples a target velocity uniformly from
[args.task_min_velocity, args.task_max_velocity].
"""


class MyCustomHalfCheetahTrain(HalfCheetahV4_MetaVelocity):

    def sample_target_velocity(self) -> float:
        # be sure to use `random` (or reseed numpy's default_rng per process) to ensure
        # tasks are different across async parallel actors!
        vel = super().sample_target_velocity()  # random.uniform(min_vel, max_vel)
        return vel
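

# A hedged sketch of the numpy caveat above: if you prefer `np.random` over the
# `random` module, give each process its own generator so async parallel actors
# don't sample identical task sequences. This class is hypothetical and unused by
# the rest of the script; swap it into `make_train_env` below if you want it.
class MyNumpyHalfCheetahTrain(HalfCheetahV4_MetaVelocity):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # local imports keep this optional sketch self-contained
        import os

        import numpy as np

        # seeding from the process id keeps each parallel actor's task draws distinct
        self._rng = np.random.default_rng(seed=os.getpid())

    def sample_target_velocity(self) -> float:
        return float(self._rng.uniform(self.task_min_velocity, self.task_max_velocity))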


class MyCustomHalfCheetahEval(HalfCheetahV4_MetaVelocity):

    def sample_target_velocity(self) -> float:
        vel = super().sample_target_velocity()
        # or, to create OOD eval tasks:
        # vel = random.uniform(self.task_min_velocity, self.task_max_velocity * 10.0)
        # or random.choice([0., 1., self.task_max_velocity * 1.2]), etc.
        return vel
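

# A concrete (hypothetical) version of the OOD ideas above: evaluate on a fixed,
# discrete grid of target velocities instead of resampling uniformly. Unused unless
# you swap it into `make_val_env` below; velocities above args.task_max_velocity
# (default 3.0) are out-of-distribution.
class MyGridHalfCheetahEval(HalfCheetahV4_MetaVelocity):

    def sample_target_velocity(self) -> float:
        return random.choice([0.5, 1.5, 2.5, 3.5])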


class AMAGOEnvWithVelocityName(AMAGOEnv):
    """
    Every eval metric is logged under the current `env_name`.
    You can use this to log metrics for different tasks separately.
    Metrics are averaged over all evals that share a name, so you want
    a discrete set of names that will each get sample sizes > 1.
    """

    @property
    def env_name(self) -> str:
        current_task_vel = self.env.unwrapped.target_velocity
        # discretize the continuous velocity somehow; this is just one example
        low, high = math.floor(current_task_vel), math.ceil(current_task_vel)
        return f"HalfCheetahVelocity-Vel-[{low}, {high}]"
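
# A hedged alternative to the integer-wide bins above (hypothetical, nothing in the
# wrapper requires it): coarser bins trade per-velocity detail for larger sample
# sizes per name, e.g. replacing the last two lines of the property with:
#
#     low = 2 * math.floor(current_task_vel / 2)
#     return f"HalfCheetahVelocity-Vel-[{low}, {low + 2}]"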


if __name__ == "__main__":
    parser = ArgumentParser()
    add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # set up the train/eval environments
    make_train_env = lambda: AMAGOEnvWithVelocityName(
        MyCustomHalfCheetahTrain(
            task_min_velocity=args.task_min_velocity,
            task_max_velocity=args.task_max_velocity,
        ),
        # the env_name is totally arbitrary and only impacts logging / data filenames
        env_name="HalfCheetahV4Velocity",
    )

    make_val_env = lambda: AMAGOEnvWithVelocityName(
        MyCustomHalfCheetahEval(
            task_min_velocity=args.task_min_velocity,
            task_max_velocity=args.task_max_velocity,
        ),
        # this gets overridden by the env_name property defined above
        env_name="HalfCheetahV4VelocityEval",
    )

    config = {}
    # switch sequence model
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    # switch agent
    agent_type = switch_agent(
        config,
        args.agent_type,
        reward_multiplier=1.0,  # gym locomotion returns are already large
        gamma=0.99,  # locomotion policies don't need long horizons - fall back to the default
        tau=0.005,
    )
    # "egreedy" exploration in continuous control is just the epsilon-scheduled random
    # (normal) noise from most TD3/DDPG implementations.
    exploration_type = switch_exploration(config, "egreedy", steps_anneal=500_000)
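    # Roughly, that means (a sketch, not AMAGO's exact implementation):
    #   action = clip(policy(obs) + eps_t * gaussian_noise, action_low, action_high)
    # with eps_t annealed toward a small final value over `steps_anneal` steps.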
    use_config(config, args.configs)

    group_name = args.run_name
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = create_experiment_from_cli(
            args,
            make_train_env=make_train_env,  # different train/val envs
            make_val_env=make_val_env,
            max_seq_len=args.policy_seq_len,
            traj_save_len=args.policy_seq_len * 6,
            run_name=run_name,
            tstep_encoder_type=amago.nets.tstep_encoders.FFTstepEncoder,
            traj_encoder_type=traj_encoder_type,
            exploration_wrapper_type=exploration_type,
            agent_type=agent_type,
            group_name=group_name,
            val_timesteps_per_epoch=args.eval_episodes_per_actor * 1001,  # HalfCheetah episodes last 1000 steps
            grad_clip=2.0,
            learning_rate=3e-4,
        )
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_val_env, timesteps=10_000, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()