# Source code for duo_ai.algorithms.logit

import logging
import pprint
from dataclasses import dataclass, field
from typing import Dict, List

import numpy as np
import torch

from duo_ai.core import Algorithm
from duo_ai.utils.global_variables import get_global_variable


def _default_percentiles() -> List[float]:
    """Return a fresh list holding the percentiles 0, 10, ..., 100."""
    return list(range(0, 101, 10))


def _unit_temperature() -> List[float]:
    """Return a fresh single-element list holding the temperature 1.0."""
    return [1.0]


@dataclass
class LogitAlgorithmConfig:
    """
    Settings consumed by ``LogitAlgorithm``.

    The algorithm sweeps over exploration temperatures, scoring temperatures
    and percentile-derived thresholds; this dataclass holds the grids for
    that sweep. Every list-valued field is built by a factory, so each
    instance owns its own list objects.

    Parameters
    ----------
    name : str, optional
        Name of the algorithm class. Default is "logit".
    num_rollouts : int, optional
        Number of rollouts to use for score generation. Default is 128.
    percentiles : list of float, optional
        Percentiles of the collected scores that become candidate
        thresholds. Default is 0, 10, ..., 100.
    explore_temps : list of float, optional
        Temperatures used during exploration rollouts. Default is [1.0].
    score_temps : list of float, optional
        Temperatures used when scoring. Default is [1.0].

    Examples
    --------
    >>> config = LogitAlgorithmConfig()
    >>> config.num_rollouts
    128
    """

    name: str = "logit"
    num_rollouts: int = 128
    percentiles: List[float] = field(default_factory=_default_percentiles)
    explore_temps: List[float] = field(default_factory=_unit_temperature)
    score_temps: List[float] = field(default_factory=_unit_temperature)
class LogitAlgorithm(Algorithm):
    """
    Algorithm for tuning confidence-based policies using logit thresholds and temperatures.

    ``train`` runs a grid search: for every (exploration temperature, scoring
    temperature) pair it collects confidence scores from rollouts, converts
    the configured percentiles of those scores into candidate thresholds, and
    evaluates the policy with each threshold on every validation split. The
    best-performing parameters per split are checkpointed.

    Examples
    --------
    >>> algo = LogitAlgorithm(LogitAlgorithmConfig())
    """

    config_cls = LogitAlgorithmConfig

    def __init__(self, config: LogitAlgorithmConfig) -> None:
        """
        Initialize the LogitAlgorithm.

        Parameters
        ----------
        config : LogitAlgorithmConfig
            Configuration object for the LogitAlgorithm.

        Returns
        -------
        None

        Examples
        --------
        >>> algo = LogitAlgorithm(LogitAlgorithmConfig())
        """
        self.config = config

    def train(
        self,
        policy: "duo.core.Policy",
        env: "gym.Env",
        validators: Dict[str, "duo.core.Evaluator"],
    ) -> None:
        """
        Search for the best threshold and temperature parameters based on
        rollout scores and evaluation results.

        Side effects: sets ``self.save_dir`` (from the ``experiment_dir``
        global variable) and ``self.score_fn`` (``policy.compute_confidence``),
        mutates the policy's parameters via ``policy.set_params``, and writes
        a ``best_<split>`` checkpoint whenever a split's ``reward_mean``
        strictly improves.

        Parameters
        ----------
        policy : duo.core.Policy
            The policy to be trained and evaluated.
        env : gym.Env
            The environment used for training and rollouts. Its ``base_env``
            and ``novice`` attributes are used to generate scores.
        validators : dict of str to duo.core.Evaluator
            Dictionary mapping split names to evaluator instances for evaluation.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If the rollouts produce no scores (e.g. ``num_rollouts`` is 0), or
            if ``num_rollouts`` is not divisible by the number of environments.

        Examples
        --------
        >>> algorithm = LogitAlgorithm(LogitAlgorithmConfig())
        >>> algorithm.train(policy, env, validators)
        """
        # Function-scope import keeps this block self-contained.
        import copy

        config = self.config
        self.save_dir = get_global_variable("experiment_dir")

        # Seed every split up front so the "best so far" report below never
        # hits a missing key, even if no evaluation beats -inf (for example
        # when reward_mean is NaN).
        best_params = {split: None for split in validators}
        best_result = {split: {"reward_mean": -float("inf")} for split in validators}

        self.score_fn = policy.compute_confidence

        for explore_temp in config.explore_temps:
            logging.info(f"Exploration temperature: {explore_temp}")
            for score_temp in config.score_temps:
                policy.set_params({"temperature": score_temp})

                # Generate scores by rolling out (simulated) novice in training environment
                scores = self._generate_scores(
                    env.base_env,
                    env.novice,
                    explore_temp,
                    config.num_rollouts,
                )
                if not scores:
                    # Percentiles of an empty sample are meaningless; fail
                    # loudly instead of searching over NaN thresholds.
                    raise ValueError(
                        "LogitAlgorithm collected no scores; num_rollouts must be positive"
                    )

                thresholds = [np.percentile(scores, pct) for pct in config.percentiles]
                logging.info("Thresholds: " + pprint.pformat(thresholds, indent=2))

                for threshold in thresholds:
                    policy.set_params({"threshold": threshold})
                    # Snapshot the parameters: if get_params() hands back live
                    # state, later set_params() calls must not rewrite what was
                    # recorded as "best".
                    cur_params = copy.deepcopy(policy.get_params())
                    logging.info("Parameters: " + pprint.pformat(cur_params, indent=2))

                    # Evaluate policy on all splits
                    for split, validator in validators.items():
                        logging.info(f"Evaluating on {split} split")
                        result = validator.evaluate(policy)
                        if result["reward_mean"] > best_result[split]["reward_mean"]:
                            best_params[split] = cur_params
                            best_result[split] = result
                            self.save_checkpoint(policy, f"best_{split}")

                    # Log best result so far
                    for split, validator in validators.items():
                        logging.info(f"BEST {split} so far")
                        logging.info(
                            "Parameters: " + pprint.pformat(best_params[split], indent=2)
                        )
                        validator.summarizer.write(best_result[split])

    def save_checkpoint(self, policy: "duo.core.Policy", name: str) -> None:
        """
        Save the current policy configuration and parameters to a checkpoint file.

        The file is written to ``<self.save_dir>/<name>.ckpt``. ``self.save_dir``
        is assigned in ``train``, so it must be set before this method is called.

        Parameters
        ----------
        policy : duo.core.Policy
            The policy whose parameters are to be saved.
        name : str
            Name for the checkpoint file.

        Returns
        -------
        None

        Examples
        --------
        >>> self.save_checkpoint(policy, "best_test")
        """
        save_path = f"{self.save_dir}/{name}.ckpt"
        torch.save(
            {
                "policy_config": policy.config,
                "model_state_dict": policy.get_params(),
            },
            save_path,
        )
        logging.info(f"Saved checkpoint to {save_path}")

    def _generate_scores(
        self,
        env: "gym.Env",
        policy: "duo.core.Policy",
        temperature: float,
        num_rollouts: int,
    ) -> list:
        """
        Generate confidence scores by rolling out the policy in the environment.

        Each pass steps all ``env.num_envs`` environments together, so
        ``num_rollouts // env.num_envs`` passes are run.

        Parameters
        ----------
        env : gym.Env
            The environment used for rollouts.
        policy : duo.core.Policy
            The policy to be evaluated.
        temperature : float
            Temperature parameter for action selection.
        num_rollouts : int
            Total number of rollout episodes to generate.

        Returns
        -------
        scores : list of float
            List of confidence scores collected from rollouts.

        Raises
        ------
        ValueError
            If ``num_rollouts`` is not divisible by ``env.num_envs``.

        Examples
        --------
        >>> scores = self._generate_scores(env, policy, 1.0, 128)
        """
        # Raised explicitly rather than asserted: asserts are stripped under
        # ``python -O``, which would silently run fewer rollouts than asked.
        if num_rollouts % env.num_envs != 0:
            raise ValueError(
                "LogitAlgorithm requires num_rollouts to be divisible by num_envs"
            )

        scores = []
        for _ in range(num_rollouts // env.num_envs):
            scores.extend(self._rollout_once(env, policy, temperature))
        return scores

    @torch.no_grad()
    def _rollout_once(
        self,
        env: "gym.Env",
        policy: "duo.core.Policy",
        temperature: float,
    ) -> List[float]:
        """
        Run one pass over all ``env.num_envs`` environments and collect scores.

        At every step, ``self.score_fn`` is applied to the logits of the
        policy's model output, and the score of each environment that has not
        finished yet is recorded. The pass ends once every environment has
        reported done.

        Parameters
        ----------
        env : gym.Env
            The environment used for rollouts.
        policy : duo.core.Policy
            The policy acting in the environment.
        temperature : float
            Temperature passed to ``policy.act``.

        Returns
        -------
        list of float
            Scores gathered during this pass.
        """
        policy.eval()
        obs = env.reset()
        has_done = np.zeros(env.num_envs, dtype=bool)
        policy.reset(np.ones_like(has_done))

        collected = []
        while not has_done.all():
            action, model_output = policy.act(
                obs, temperature=temperature, return_model_output=True
            )
            score = self.score_fn(model_output.logits)
            # Only environments still running contribute a score.
            for i in range(env.num_envs):
                if not has_done[i]:
                    collected.append(score[i].item())
            obs, _, done, _ = env.step(action.cpu().numpy())
            # Coerce to bool so the in-place OR also works when the done flags
            # arrive as a list or an integer array.
            has_done |= np.asarray(done, dtype=bool)
        return collected