Add a New Algorithm¶
In this tutorial, you will learn how to implement and integrate a new algorithm for learning a coordination policy.
We will implement a simple algorithm called AskEveryK, which learns a policy that asks for help every K steps. The algorithm searches for the best value of K from a set of candidates.
The code for this tutorial is provided at examples/procgen_ask_every_k.py. Try running it with:
python examples/procgen_ask_every_k.py --config configs/procgen_ask_every_k.yaml overwrite=1
1. Implement the Algorithm¶
We first implement the AskEveryKAlgorithm class, which is a subclass of duo_ai.core.Algorithm, along with its configuration dataclass:
from duo_ai.core import Algorithm
@dataclass
class AskEveryKAlgorithmConfig:
    """Configuration for the AskEveryK search algorithm."""

    # Identifier of this algorithm.
    name: str = "ask_every_k"
    # Values of K to try; a factory is used so each config instance gets
    # its own list rather than sharing one mutable default.
    candidates: List[int] = field(default_factory=lambda: [5, 10, 15, 20])
class AskEveryKAlgorithm(Algorithm):
    """Search over candidate values of K for an "ask every K steps" policy.

    Every candidate in ``config.candidates`` is loaded into the policy with
    ``set_params`` and scored by each validator. Whenever a candidate beats
    the best ``reward_mean`` seen so far on a split, the policy parameters
    are checkpointed as ``best_<split>.ckpt``.
    """

    config_cls = AskEveryKAlgorithmConfig

    def __init__(self, config):
        """Store the algorithm configuration.

        Args:
            config: Configuration object; ``train`` reads ``config.candidates``.
        """
        self.config = config

    def train(self, policy, env, validators):
        """Evaluate each candidate K and keep the best one per split.

        Args:
            policy: Policy exposing ``set_params``, ``get_params`` and a
                ``config`` attribute.
            env: Not used by this method (presumably required by the
                ``Algorithm.train`` interface -- TODO confirm).
            validators: Mapping from split name to a validator providing
                ``evaluate(policy)`` (returning a dict with a
                ``"reward_mean"`` entry) and ``summarizer.write(result)``.
        """
        config = self.config
        self.save_dir = get_global_variable("experiment_dir")

        # Track the winning K separately for every split: different splits
        # can peak at different K, and a single shared variable would make
        # the final log report the wrong K for some of them.
        best_k = {split: None for split in validators}
        best_result = {
            split: {"reward_mean": -float("inf")} for split in validators
        }

        # Loop through possible values of K and evaluate the corresponding policy
        for k in config.candidates:
            logging.info(f"Evaluating k={k}")
            policy.set_params({"k": k})
            for split, validator in validators.items():
                result = validator.evaluate(policy)
                if result["reward_mean"] > best_result[split]["reward_mean"]:
                    best_result[split] = result
                    best_k[split] = k
                    self.save_checkpoint(policy, f"best_{split}")

        for split, validator in validators.items():
            logging.info(f"BEST result for {split} (k={best_k[split]}):")
            validator.summarizer.write(best_result[split])

    def save_checkpoint(self, policy, name):
        """Write the policy's config and current parameters to disk.

        The file is saved as ``<self.save_dir>/<name>.ckpt``; ``self.save_dir``
        is set by ``train``, so this must be called after ``train`` has started.

        Args:
            policy: Policy whose ``config`` and ``get_params()`` are saved.
            name: Checkpoint file stem (without the ``.ckpt`` extension).
        """
        save_path = f"{self.save_dir}/{name}.ckpt"
        torch.save(
            {
                "policy_config": policy.config,
                "model_state_dict": policy.get_params(),
            },
            save_path,
        )
        logging.info(f"Saved checkpoint to {save_path}")
2. Implement the Policy¶
Next, we implement a subclass of duo_ai.core.Policy with a parameter K, which queries the expert every K steps.
from duo_ai.core import Policy
@dataclass
class AskEveryKPolicyConfig:
    """Configuration for the AskEveryK policy."""

    # Identifier of this policy.
    name: str = "ask_every_k"
    # Optional path, unset by default. It is not read by the policy code
    # shown here -- presumably a checkpoint location used by the framework
    # (TODO confirm).
    load_path: Optional[str] = None
class AskEveryKPolicy(Policy):
    """Policy that selects the expert action on every K-th step.

    A separate step counter is kept for each environment in the batch. The
    only parameter is ``k``, which must be provided through ``set_params``
    before ``act`` is called.
    """

    config_cls = AskEveryKPolicyConfig

    def __init__(self, config, env):
        """Set up per-environment step counters.

        Args:
            config: Policy configuration object.
            env: Environment providing ``EXPERT`` (the action value meaning
                "ask the expert") and ``num_envs``.
        """
        self.config = config
        self.EXPERT = env.EXPERT
        self.k = None  # unset until set_params() is called
        self.step = np.array([0] * env.num_envs)
        self.device = get_global_variable("device")

    def reset(self, done):
        """Zero the step counters selected by ``done``.

        Args:
            done: Sequence with one entry per active environment; it is used
                to index the counter array (presumably a boolean mask --
                TODO confirm).
        """
        num_active = len(done)
        self.batch_size = num_active
        # The batch may shrink; drop the counters that are no longer needed.
        if num_active < len(self.step):
            self.step = self.step[:num_active]
        self.step[done] = 0

    def act(self, obs, temperature=None):
        """Return one action per environment and advance the step counters.

        Args:
            obs: Mapping whose ``"base_obs"`` entry has the batch size as its
                first dimension.
            temperature: Accepted but not used by this policy.

        Returns:
            A long tensor of length ``batch_size`` on ``self.device`` holding
            ``EXPERT`` where the counter is a multiple of ``k`` and
            ``1 - EXPERT`` elsewhere.
        """
        n = self.batch_size
        assert obs["base_obs"].shape[0] == n

        action = torch.zeros(n).long().to(self.device)
        for i in range(n):
            ask_expert = self.step[i] % self.k == 0
            action[i] = self.EXPERT if ask_expert else 1 - self.EXPERT
            self.step[i] += 1
        return action

    def set_params(self, params):
        """Set ``k`` from ``params["k"]``."""
        self.k = params["k"]

    def get_params(self):
        """Return the current parameters as ``{"k": k}``."""
        return {"k": self.k}

    def train(self):
        """No-op: this policy has no training mode."""
        pass

    def eval(self):
        """No-op: this policy has no evaluation mode."""
        pass
3. Register the Algorithm and Policy¶
Finally, we register the algorithm and the policy with Duo so that their configuration arguments are included in Duo’s argument list.
# Register the algorithm and the policy classes under the name "ask_every_k".
duo_ai.register_algorithm("ask_every_k", AskEveryKAlgorithm)
duo_ai.register_policy("ask_every_k", AskEveryKPolicy)
That covers all the major steps. The rest of the code follows the standard process for training a coordination policy.