from argparse import ArgumentParser
import random

import torch
import gym as og_gym
import d4rl
import gymnasium as gym
import wandb
import numpy as np

import amago
from amago.envs import AMAGOEnv
from amago import cli_utils
from amago.loading import RLData, RLDataset, DiskTrajDataset, MixtureOfDatasets
from amago.nets.policy_dists import TanhGaussian, GMM, Beta
from amago.nets.actor_critic import ResidualActor, Actor
from amago.agent import binary_filter, exp_filter


def add_cli(parser):
    """Add the D4RL-specific command-line arguments for this example."""
    parser.add_argument(
        "--env", type=str, required=True, help="Environment/Dataset name"
    )
    parser.add_argument(
        "--max_seq_len", type=int, default=32, help="Policy sequence length."
    )
    parser.add_argument(
        "--policy_dist",
        type=str,
        default="Beta",
        help="Policy distribution type",
        choices=["TanhGaussian", "GMM", "Beta"],
    )
    parser.add_argument(
        "--actor_type",
        type=str,
        default="Actor",
        help="Actor head type",
        choices=["ResidualActor", "Actor"],
    )
    parser.add_argument(
        "--online_after_epoch",
        type=int,
        default=float("inf"),
        help="Number of epochs after which to start collecting online data. Defaults to never collecting (pure offline RL).",
    )
    parser.add_argument(
        "--eval_timesteps",
        type=int,
        default=1000,
        help="Number of timesteps to evaluate for each actor. Will be overridden if the environment has a known time limit.",
    )
    return parser


class D4RLDataset(RLDataset):
    """Expose a D4RL dataset dict as an AMAGO RLDataset.

    Episode boundaries are recovered from the `terminals`/`timeouts` flags, and
    each sample returns one full episode converted to an RLData sequence.
    """

    def __init__(self, d4rl_dset: dict[str, np.ndarray]):
        super().__init__()
        self.d4rl_dset = d4rl_dset
        self.episode_ends = np.where(d4rl_dset["terminals"] | d4rl_dset["timeouts"])[0]
        self.ep_lens = self.episode_ends[1:] - self.episode_ends[:-1]
        self.max_ep_len = self.ep_lens.max()

    def get_description(self) -> str:
        return "D4RL"

    @property
    def save_new_trajs_to(self):
        # disables saving new amago trajectories to disk
        return None

    def sample_random_trajectory(self):
        # pick a random episode
        episode_idx = random.randrange(0, len(self.episode_ends) - 1)
        return self._sample_trajectory(episode_idx)

    def _sample_trajectory(self, episode_idx: int):
        # slice the chosen episode out of the flat dataset arrays
        s = self.episode_ends[episode_idx] + 1
        e = self.episode_ends[episode_idx + 1] + 1
        traj_len = e - s
        obs_np = self.d4rl_dset["observations"][s : e + 1]
        actions_np = self.d4rl_dset["actions"][s:e]
        rewards_np = self.d4rl_dset["rewards"][s:e]
        terminals_np = self.d4rl_dset["terminals"][s:e]
        timeouts_np = self.d4rl_dset["timeouts"][s:e]

        # convert to torch, adding time_idxs
        obs = {"observation": torch.from_numpy(obs_np)}
        actions = torch.from_numpy(actions_np).float()
        rewards = torch.from_numpy(rewards_np).float().unsqueeze(-1)
        time_idxs = torch.arange(traj_len).unsqueeze(-1).long()
        dones = torch.from_numpy(terminals_np).bool().unsqueeze(-1)
        return RLData(
            obs=obs,
            actions=actions,
            rews=rewards,
            dones=dones,
            time_idxs=time_idxs,
        )
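
# Standalone usage sketch for D4RLDataset (not executed by this script; assumes a
# registered D4RL task id such as "hopper-medium-v2" is installed):
#   dset = D4RLDataset(og_gym.make("hopper-medium-v2").get_dataset())
#   episode = dset.sample_random_trajectory()  # -> RLData holding one full episode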

from amago.envs.env_utils import space_convert
from amago.envs.amago_env import AMAGO_ENV_LOG_PREFIX


class D4RLGymEnv(gym.Env):
    """
    Light wrapper that logs the D4RL normalized return and handles
    the gym/gymnasium conversion while we're at it.
    """

    def __init__(self, env_name: str):
        # hack fix: reseed numpy so parallel envs don't share identical seeds
        np.random.seed(random.randrange(1_000_000))
        self.env_name = env_name
        self.env = og_gym.make(env_name)
        self.action_space = space_convert(self.env.action_space)
        self.observation_space = space_convert(self.env.observation_space)
        if isinstance(self.env, og_gym.wrappers.TimeLimit):
            # this time limit is apparently not consistent with the datasets
            self.time_limit = self.env._max_episode_steps
        else:
            self.time_limit = None
        self.max_return = d4rl.infos.REF_MAX_SCORE[self.env_name]
        self.min_return = d4rl.infos.REF_MIN_SCORE[self.env_name]

    @property
    def dset(self):
        return self.env.get_dataset()

    def reset(self, *args, **kwargs):
        self.episode_return = 0
        return self.env.reset(), {}

    def step(self, action):
        s, r, d, i = self.env.step(action)
        truncated = i.get("TimeLimit.truncated", False)
        terminated = d and not truncated
        self.episode_return += r
        if terminated or truncated:
            # d4rl's normalized score: (return - REF_MIN_SCORE) / (REF_MAX_SCORE - REF_MIN_SCORE)
            i[f"{AMAGO_ENV_LOG_PREFIX} D4RL Normalized Return"] = (
                d4rl.get_normalized_score(self.env_name, self.episode_return)
            )
        return s, r, terminated, truncated, i


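# Example launch (hypothetical filename; --env takes a standard D4RL dataset id,
# while shared flags such as --run_name, --trials, and --buffer_dir are added by
# cli_utils.add_common_cli):
#   python d4rl_example.py --env hopper-medium-v2 --policy_dist Beta --run_name demo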
if __name__ == "__main__":
    parser = ArgumentParser()
    cli_utils.add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # use the env to set some args
    env_name = args.env
    example_env = D4RLGymEnv(args.env)
    assert isinstance(
        example_env.action_space, gym.spaces.Box
    ), "Only supports continuous action spaces"

    # override eval/collection horizons when the env has a known time limit
    if example_env.time_limit is not None:
        args.eval_timesteps = example_env.time_limit + 1
        args.timesteps_per_epoch = example_env.time_limit

    # setup environment
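    # AMAGOEnv adapts the gymnasium-style env to AMAGO's agent interface;
    # batched_envs=1 because each D4RLGymEnv steps a single MuJoCo instance.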
    make_train_env = lambda: AMAGOEnv(
        D4RLGymEnv(args.env),
        env_name=env_name,
        batched_envs=1,
    )

    # agent architecture: drop everything down to standard small sizes
    config = {
        "amago.nets.actor_critic.NCritics.d_hidden": 128,
        "amago.nets.actor_critic.NCriticsTwoHot.d_hidden": 128,
        "amago.nets.actor_critic.NCriticsTwoHot.output_bins": 64,
        "amago.nets.actor_critic.Actor.d_hidden": 128,
        "amago.nets.actor_critic.Actor.continuous_dist_type": eval(args.policy_dist),
        "amago.nets.actor_critic.ResidualActor.feature_dim": 128,
        "amago.nets.actor_critic.ResidualActor.residual_ff_dim": 256,
        "amago.nets.actor_critic.ResidualActor.residual_blocks": 2,
        "amago.nets.actor_critic.ResidualActor.continuous_dist_type": eval(
            args.policy_dist
        ),
    }
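    # a small feedforward timestep encoder is enough here: D4RL MuJoCo observations
    # are flat state vectors, not images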
    tstep_encoder_type = cli_utils.switch_tstep_encoder(
        config,
        arch="ff",
        d_hidden=128,
        d_output=128,
        n_layers=1,
    )
    exploration_wrapper_type = cli_utils.switch_exploration(
        config,
        strategy="egreedy",
        eps_start=0.05,
        eps_end=0.01,
        steps_anneal=15_000,
    )
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
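    # online_coeff=0.0 / offline_coeff=1.0 puts all of the actor loss weight on the
    # filtered behavior cloning term (with exp_filter as the advantage filter), a
    # natural choice before any online data has been collected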
    agent_type = cli_utils.switch_agent(
        config,
        args.agent_type,
        online_coeff=0.0,
        offline_coeff=1.0,
        gamma=0.997,
        reward_multiplier=100.0 if example_env.max_return <= 10.0 else 1,
        num_actions_for_value_in_critic_loss=3,
        num_actions_for_value_in_actor_loss=5,
        num_critics=5,
        actor_type=eval(args.actor_type),
        fbc_filter_func=exp_filter,
    )
    cli_utils.use_config(config, args.configs)

    group_name = f"{args.run_name}_{env_name}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"

        # create dataset
        d4rl_dataset = D4RLDataset(d4rl_dset=example_env.dset)
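        # DiskTrajDataset serves as the online replay buffer: trajectories collected
        # once --online_after_epoch is reached are written under --buffer_dir and
        # sampled from disk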
        online_dset = DiskTrajDataset(
            dset_root=args.buffer_dir,
            dset_name=run_name,
            dset_min_size=250,
            dset_max_size=args.dset_max_size,
        )
        combined_dset = MixtureOfDatasets(
            datasets=[d4rl_dataset, online_dset],
            # skew sampling towards the demos 80/20
            sampling_weights=[0.8, 0.2],
            # gradually increase the weight of the online dset over the first
            # 50 epochs *after online collection starts*
            smooth_sudden_starts=50,
        )

        experiment = cli_utils.create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            make_val_env=make_train_env,
            max_seq_len=args.max_seq_len,
            run_name=run_name,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            agent_type=agent_type,
            group_name=group_name,
            val_timesteps_per_epoch=args.eval_timesteps,
            learning_rate=1e-4,
            dataset=combined_dset,
            padded_sampling="right",
            start_collecting_at_epoch=args.online_after_epoch,
            stagger_traj_file_lengths=False,
            traj_save_len=args.eval_timesteps + 1,
            sample_actions=False,
            exploration_wrapper_type=exploration_wrapper_type,
        )
        # configure the experiment's collection/learning mode based on --mode
        experiment = cli_utils.switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_train_env, timesteps=10_000, render=False)
        wandb.finish()