# Source code for duo_ai.core.environment

import os
from copy import deepcopy as dc
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

# Import gym or gymnasium based on environment variable
if os.environ.get("GYM_BACKEND", "gym") == "gymnasium":
    import gymnasium as gym
else:
    import gym

import numpy as np
import torch


@dataclass
class CoordinationConfig:
    """
    Hyper-parameters of a coordination environment.

    The two ``*_cost_weight`` fields are multipliers: the coordination
    environment scales a base penalty by them to obtain the per-action cost
    of querying the expert and of switching agents.

    Parameters
    ----------
    expert_query_cost_weight : float, optional
        Multiplier applied to the base penalty each time the expert is
        queried. Default is 0.4.
    switch_agent_cost_weight : float, optional
        Multiplier applied to the base penalty each time control switches
        from one agent to the other. Default is 0.0 (switching is free).
    temperature : float, optional
        Sampling temperature handed to the policies when they pick an
        action. Default is 1.0.

    Examples
    --------
    >>> config = CoordinationConfig()
    >>> config.expert_query_cost_weight
    0.4
    """

    # Weight of the expert-query cost relative to the base penalty.
    expert_query_cost_weight: float = 0.4
    # Weight of the agent-switch cost relative to the base penalty.
    switch_agent_cost_weight: float = 0.0
    # Temperature used for action sampling.
    temperature: float = 1.0
class CoordEnv(gym.Env):
    """
    Environment for coordinating between novice and expert policies.

    This class wraps a (batched) base environment. The coordination action
    space is ``Discrete(2)``: for every parallel environment the caller picks
    whether the novice (``NOVICE``) or the expert (``EXPERT``) chooses the
    action that is forwarded to the base environment. The base reward is
    reduced by a cost for every expert query and a cost for every switch
    between agents.

    ``set_costs`` must be called before ``step``: until then both per-action
    costs are ``None`` and the reward computation cannot run.

    Examples
    --------
    >>> config = CoordinationConfig()
    >>> base_env = gym.make(...)
    >>> novice = ...
    >>> expert = ...
    >>> env = CoordEnv(config, base_env, novice, expert)
    """

    config_cls = CoordinationConfig

    # Values of the coordination action.
    NOVICE = 0
    EXPERT = 1

    def __init__(
        self,
        config: CoordinationConfig,
        base_env: gym.Env,
        novice: "duo.core.Policy",
        expert: "duo.core.Policy",
        open_novice: bool = True,
        open_expert: bool = False,
    ) -> None:
        """
        Initialize the coordination environment.

        Parameters
        ----------
        config : CoordinationConfig
            Configuration object specifying coordination parameters.
        base_env : gym.Env
            The base Gym environment to be wrapped. It must expose
            ``num_envs`` and ``observation_space``.
        novice : duo.core.Policy
            The novice policy. Its ``model`` must expose ``hidden_dim`` and
            ``logit_dim``.
        expert : duo.core.Policy
            The expert policy.
        open_novice : bool, optional
            Whether to expose novice outputs in observations. Default is True.
        open_expert : bool, optional
            Whether to expose expert outputs in observations. Default is False.

        Returns
        -------
        None

        Examples
        --------
        >>> config = CoordinationConfig(...)
        >>> base_env = gym.make(...)
        >>> novice = ...
        >>> expert = ...
        >>> env = CoordEnv(config, base_env, novice, expert)
        """
        self.config = config
        self.base_env = base_env
        self.novice = novice
        self.expert = expert
        self.open_novice = open_novice
        self.open_expert = open_expert

        # One binary decision per environment: NOVICE or EXPERT.
        self.action_space = gym.spaces.Discrete(2)

        # NOTE(review): this space always lists the novice keys and never the
        # expert keys, whereas `_get_obs` emits them according to
        # `open_novice` / `open_expert`. Left unchanged because downstream
        # code may rely on these exact keys -- confirm before aligning.
        self.observation_space = gym.spaces.Dict(
            {
                "base_obs": base_env.observation_space,
                "novice_hidden": gym.spaces.Box(
                    -100, 100, shape=(novice.model.hidden_dim,)
                ),
                "novice_logits": gym.spaces.Box(
                    -100, 100, shape=(novice.model.logit_dim,)
                ),
            }
        )

        # Filled in by `set_costs`.
        self.expert_query_cost_per_action = None
        self.switch_agent_cost_per_action = None

    @property
    def num_envs(self) -> int:
        """
        Number of parallel environments.

        Returns
        -------
        int
            The ``num_envs`` attribute of the base environment.

        Examples
        --------
        >>> n = env.num_envs
        """
        return self.base_env.num_envs

    def set_costs(self, base_penalty: float) -> None:
        """
        Set the cost per action for expert queries and agent switching.

        Each cost is ``base_penalty`` multiplied by the matching weight in
        the configuration.

        Parameters
        ----------
        base_penalty : float
            The reward value per action.

        Returns
        -------
        None

        Examples
        --------
        >>> env.set_costs(0.05)
        """
        # NOTE: paper results were generated with rounding but here we don't
        self.expert_query_cost_per_action = (
            base_penalty * self.config.expert_query_cost_weight
        )
        self.switch_agent_cost_per_action = (
            base_penalty * self.config.switch_agent_cost_weight
        )

    def reset(self) -> Dict[str, Any]:
        """
        Reset the coordination environment to an initial state.

        Resets the base environment, puts both policy models in eval mode,
        resets both agents for every environment and clears the remembered
        previous coordination action.

        Returns
        -------
        dict
            The initial observation, as built by ``_get_obs``:
            - "base_obs": The initial observation from the base environment.
            - "novice_hidden" / "novice_logits": novice outputs
              (if open_novice).
            - "expert_hidden" / "expert_logits": expert outputs
              (if open_expert).

        Examples
        --------
        >>> obs = env.reset()
        """
        # No previous action yet, so the first step carries no switch cost.
        self.prev_action = None
        # NOTE(review): assumes the base env's reset() returns only the
        # observation (classic gym API) -- confirm when
        # GYM_BACKEND=gymnasium.
        self.base_obs = self.base_env.reset()
        self.novice.model.eval()
        self.expert.model.eval()
        # Every environment starts a fresh episode.
        self._reset_agents(done=np.array([True] * self.num_envs))
        return self._get_obs()

    def _reset_agents(self, done: np.ndarray) -> None:
        """
        Reset the internal state of the novice and expert agents.

        Parameters
        ----------
        done : numpy.ndarray
            Boolean array indicating which episodes in a batch require a
            reset. It is forwarded unchanged to both policies' ``reset``.

        Returns
        -------
        None

        Examples
        --------
        >>> env._reset_agents(np.array([True, False]))
        """
        self.novice.reset(done)
        self.expert.reset(done)

    def step(
        self, action: np.ndarray
    ) -> Tuple[Dict[str, Any], np.ndarray, np.ndarray, List[Dict[str, Any]]]:
        """
        Advance the environment by one step using the provided action.

        Parameters
        ----------
        action : numpy.ndarray
            Coordination action for each environment: ``NOVICE`` or
            ``EXPERT``.

        Returns
        -------
        obs : dict
            The next observation, as built by ``_get_obs``.
        reward : numpy.ndarray
            The base reward minus expert-query and switching costs.
        done : numpy.ndarray
            Episode-termination flags returned by the base environment.
        info : list of dict
            A deep copy of the base environment's info, where each entry
            additionally carries "base_reward" (only if not already present)
            and "base_action".

        Raises
        ------
        Exception
            Propagates any exceptions raised by the underlying environment's
            `step` method.

        Examples
        --------
        >>> obs, reward, done, info = env.step(action)
        """
        base_action = self._compute_base_action(action)
        # NOTE(review): assumes the base env's step() returns a 4-tuple
        # (classic gym API) -- confirm when GYM_BACKEND=gymnasium.
        self.base_obs, base_reward, done, base_info = self.base_env.step(base_action)

        # Copy so the base environment's own info objects are not mutated.
        info = dc(base_info)
        for i, item in enumerate(info):
            # Keep an existing "base_reward" entry rather than overwrite it.
            if "base_reward" not in item:
                item["base_reward"] = base_reward[i]
            item["base_action"] = base_action[i]

        reward = self._get_reward(base_reward, action, done)
        # Clear agent state for the environments whose episode just ended.
        self._reset_agents(done)
        self.prev_action = action

        return self._get_obs(), reward, done, info

    @torch.no_grad()
    def _compute_base_action(self, action: np.ndarray) -> np.ndarray:
        """
        Compute the environment-specific action for each agent.

        Each policy is only called on the observations of the environments
        it controls at this step.

        Parameters
        ----------
        action : numpy.ndarray
            Array indicating which agent (novice or expert) acts for each
            environment.

        Returns
        -------
        numpy.ndarray
            Array of actions to be passed to the base environment. It has
            the same shape and dtype as ``action``.

        Examples
        --------
        >>> base_action = env._compute_base_action(action)
        """
        is_novice = action == self.NOVICE
        # Anything that is not NOVICE is handled by the expert.
        is_expert = np.logical_not(is_novice)

        base_action = np.zeros_like(action)
        if is_novice.any():
            base_action[is_novice] = (
                self.novice.act(
                    self.base_obs[is_novice], temperature=self.config.temperature
                )
                .cpu()
                .numpy()
            )
        if is_expert.any():
            base_action[is_expert] = (
                self.expert.act(
                    self.base_obs[is_expert], temperature=self.config.temperature
                )
                .cpu()
                .numpy()
            )
        return base_action

    @torch.no_grad()
    def _get_obs(self) -> Dict[str, Any]:
        """
        Return the current observation for the coordination environment.

        Returns
        -------
        dict
            A dictionary containing:
            - "base_obs": The current observation from the base environment.
            - "novice_hidden": Numpy array of hidden features from the novice
              policy (if open_novice).
            - "novice_logits": Numpy array of output logits from the novice
              policy (if open_novice).
            - "expert_hidden": Numpy array of hidden features from the expert
              policy (if open_expert).
            - "expert_logits": Numpy array of output logits from the expert
              policy (if open_expert).

        Examples
        --------
        >>> obs = env._get_obs()
        """
        # NOTE: models must be state-less. Models with a recurrent state
        # should not be used here.
        obs = {"base_obs": self.base_obs}
        if self.open_novice:
            novice_output = self.novice.model(self.base_obs)
            obs["novice_hidden"] = novice_output.hidden.cpu().numpy()
            obs["novice_logits"] = novice_output.logits.cpu().numpy()
        if self.open_expert:
            expert_output = self.expert.model(self.base_obs)
            obs["expert_hidden"] = expert_output.hidden.cpu().numpy()
            obs["expert_logits"] = expert_output.logits.cpu().numpy()
        return obs

    def _get_reward(
        self, base_reward: np.ndarray, action: np.ndarray, done: np.ndarray
    ) -> np.ndarray:
        """
        Compute the reward for the current step, including costs for expert
        queries and agent switching.

        ``set_costs`` must have been called beforehand.

        Parameters
        ----------
        base_reward : numpy.ndarray
            The base reward from the environment.
        action : numpy.ndarray
            The action(s) taken (novice or expert).
        done : numpy.ndarray
            Boolean flag(s) indicating whether the episode has ended for each
            environment.

        Returns
        -------
        numpy.ndarray
            A new array holding the reward(s) after applying costs;
            ``base_reward`` is not modified.

        Examples
        --------
        >>> reward = env._get_reward(base_reward, action, done)
        """
        # cost of querying expert agent
        reward = np.where(
            action == self.EXPERT,
            base_reward - self.expert_query_cost_per_action,
            base_reward,
        )

        # cost of switching: charged where the chosen agent differs from the
        # previous step's choice and the episode did not end on this step.
        if self.prev_action is not None:
            switch_indices = ((action != self.prev_action) & (~done)).nonzero()[0]
            # Charge as soon as at least one environment switched. The former
            # `> 1` check dropped the cost when exactly one environment
            # switched, so it was never applied with a single environment.
            if switch_indices.size > 0:
                reward[switch_indices] -= self.switch_agent_cost_per_action

        return reward

    def close(self) -> None:
        """
        Close the coordination environment and release any resources held.

        Delegates to the base environment's ``close``.

        Returns
        -------
        None

        Examples
        --------
        >>> env.close()
        """
        return self.base_env.close()
class GeneralCoordEnv(CoordEnv):
    """
    Coordination environment supporting recurrent policies.

    Every time an observation is built, each policy's ``act`` is called once
    on the full batch and the resulting actions and model outputs are cached.
    The next ``step`` then picks the base action from those cached actions
    instead of calling the policies again. This supports policies that
    maintain a hidden state across steps, but can be less efficient for
    stateless policies than `CoordEnv`.

    Examples
    --------
    >>> config = CoordinationConfig()
    >>> base_env = gym.make(...)
    >>> novice = ...
    >>> expert = ...
    >>> env = GeneralCoordEnv(config, base_env, novice, expert)
    """

    @torch.no_grad()
    def _compute_agents_action(self) -> None:
        """
        Run both agents on the current base observation and cache the results.

        The results are stored on the instance:
        ``novice_action`` / ``expert_action`` (numpy arrays) and
        ``novice_output`` / ``expert_output`` (model outputs as returned by
        ``act(..., return_model_output=True)``).

        Returns
        -------
        None
            Nothing is returned; callers read the cached attributes.

        Examples
        --------
        >>> env._compute_agents_action()
        >>> env.novice_action
        """
        novice_action, self.novice_output = self.novice.act(
            self.base_obs,
            temperature=self.config.temperature,
            return_model_output=True,
        )
        expert_action, self.expert_output = self.expert.act(
            self.base_obs,
            temperature=self.config.temperature,
            return_model_output=True,
        )
        # Store actions as numpy arrays so they can be indexed with the
        # boolean masks built in `_compute_base_action`.
        self.novice_action = novice_action.cpu().numpy()
        self.expert_action = expert_action.cpu().numpy()

    @torch.no_grad()
    def _compute_base_action(self, action: np.ndarray) -> np.ndarray:
        """
        Compute the environment-specific action for each agent, supporting
        recurrent policies.

        No policy is called here: the actions cached by the latest
        ``_compute_agents_action`` call are selected per environment.

        Parameters
        ----------
        action : numpy.ndarray
            Array indicating which agent (novice or expert) acts for each
            environment.

        Returns
        -------
        numpy.ndarray
            Array of actions to be passed to the base environment. It has
            the same shape and dtype as ``action``.

        Examples
        --------
        >>> base_action = env._compute_base_action(action)
        """
        is_novice = action == self.NOVICE
        is_expert = np.logical_not(is_novice)

        base_action = np.zeros_like(action)
        base_action[is_novice] = self.novice_action[is_novice]
        base_action[is_expert] = self.expert_action[is_expert]
        return base_action

    def _get_obs(self) -> Dict[str, Any]:
        """
        Return the current observation for the coordination environment,
        supporting recurrent policies.

        Building the observation also refreshes the cached agent actions and
        outputs through ``_compute_agents_action``.

        Returns
        -------
        dict
            A dictionary containing:
            - "base_obs": The current observation from the base environment.
            - "novice_hidden": Numpy array of hidden features from the novice
              policy (if open_novice).
            - "novice_logits": Numpy array of output logits from the novice
              policy (if open_novice).
            - "expert_hidden": Numpy array of hidden features from the expert
              policy (if open_expert).
            - "expert_logits": Numpy array of output logits from the expert
              policy (if open_expert).

        Examples
        --------
        >>> obs = env._get_obs()
        """
        # Both agents always act, even when their outputs are not exposed,
        # because `_compute_base_action` needs the cached actions.
        self._compute_agents_action()

        obs = {"base_obs": self.base_obs}
        if self.open_novice:
            obs["novice_hidden"] = self.novice_output.hidden.cpu().numpy()
            obs["novice_logits"] = self.novice_output.logits.cpu().numpy()
        if self.open_expert:
            obs["expert_hidden"] = self.expert_output.hidden.cpu().numpy()
            obs["expert_logits"] = self.expert_output.logits.cpu().numpy()
        return obs