Source code for duo_ai.models.ppo

from dataclasses import dataclass
from typing import Any, Dict, Tuple

import torch
import torch.nn as nn

from duo_ai.models.impala import Impala
from duo_ai.utils.global_variables import get_global_variable
from duo_ai.utils.model import orthogonal_init


@dataclass
class PPOModelOutput:
    """
    Bundle of tensors produced by a PPO model's forward pass.

    Attributes
    ----------
    logits : torch.Tensor
        Unnormalised action scores from the policy head.
    value : torch.Tensor
        Value estimates from the value head.
    hidden : torch.Tensor
        Feature representation that the heads were applied to.

    Examples
    --------
    >>> output = PPOModelOutput(logits, value, hidden)
    """

    # Field order is part of the positional constructor signature.
    logits: torch.Tensor
    value: torch.Tensor
    hidden: torch.Tensor
@dataclass
class ImpalaPPOModelConfig:
    """
    Configuration dataclass for ImpalaPPOModel.

    Parameters
    ----------
    name : str, optional
        Model name carried by the configuration. Default is ``"impala_ppo"``.

    Examples
    --------
    >>> config = ImpalaPPOModelConfig()
    >>> config.name
    'impala_ppo'
    """

    name: str = "impala_ppo"
class ImpalaPPOModel(nn.Module):
    """
    Actor-critic PPO network built on an IMPALA encoder.

    An ``Impala`` embedder turns each observation into a feature vector that
    is shared by two linear heads: a policy head with one logit per discrete
    action and a value head with a single output.

    Examples
    --------
    >>> model = ImpalaPPOModel(ImpalaPPOModelConfig(), env)
    >>> obs = torch.randn(8, 3, 64, 64)
    >>> out = model(obs)
    >>> print(out.logits.shape, out.value.shape)
    """

    config_cls = ImpalaPPOModelConfig

    def __init__(self, config: ImpalaPPOModelConfig, env: "gym.Env") -> None:
        """
        Build the encoder and the policy/value heads.

        Parameters
        ----------
        config : ImpalaPPOModelConfig
            Model configuration (not read by this constructor).
        env : gym.Env
            Environment whose ``observation_space.shape`` sizes the encoder
            and whose ``action_space.n`` sizes the policy head.

        Returns
        -------
        None

        Examples
        --------
        >>> model = ImpalaPPOModel(ImpalaPPOModelConfig(), env)
        """
        super().__init__()
        self.device = get_global_variable("device")

        # Shared feature extractor; its output width sizes both heads.
        self.embedder = Impala(env.observation_space.shape)
        self.hidden_dim = self.embedder.output_dim

        # Heads are created in the same order as before (policy, then value)
        # so parameter registration order is unchanged.
        n_actions = env.action_space.n
        self.fc_policy = orthogonal_init(
            nn.Linear(self.hidden_dim, n_actions), gain=0.01
        )
        self.fc_value = orthogonal_init(nn.Linear(self.hidden_dim, 1), gain=1.0)
        self.logit_dim = n_actions

    def forward(self, obs: Any) -> PPOModelOutput:
        """
        Run the encoder and both heads on a batch of observations.

        Parameters
        ----------
        obs : torch.Tensor or np.ndarray
            Observations. Tensors are used exactly as given; anything else is
            converted with ``torch.FloatTensor`` and moved to ``self.device``.

        Returns
        -------
        PPOModelOutput
            Policy logits, value predictions reshaped to 1-D, and the encoder
            features.

        Examples
        --------
        >>> out = model(obs)
        >>> print(out.logits.shape, out.value.shape)
        """
        if torch.is_tensor(obs):
            inputs = obs
        else:
            inputs = torch.FloatTensor(obs).to(device=self.device)

        features = self.embedder(inputs)
        return PPOModelOutput(
            logits=self.fc_policy(features),
            value=self.fc_value(features).reshape(-1),
            hidden=features,
        )
@dataclass
class ImpalaCoordPPOModelConfig:
    """
    Configuration dataclass for ImpalaCoordPPOModel.

    Parameters
    ----------
    name : str, optional
        Model name carried by the configuration. Default is
        ``"impala_coord_ppo"``.
    feature_type : str, optional
        Which inputs form the feature vector fed to the policy/value heads.
        Supported values are ``"obs"``, ``"hidden"``, ``"hidden_obs"``,
        ``"dist"``, ``"hidden_dist"``, ``"obs_dist"`` and
        ``"obs_hidden_dist"``. Here ``obs`` is the IMPALA embedding of the base
        observation, ``hidden`` is the novice's hidden features, and ``dist``
        is the softmax of the novice's logits. Default is ``"obs"``.

    Examples
    --------
    >>> config = ImpalaCoordPPOModelConfig(feature_type="obs_hidden_dist")
    >>> config.name
    'impala_coord_ppo'
    """

    name: str = "impala_coord_ppo"
    feature_type: str = "obs"
class ImpalaCoordPPOModel(nn.Module):
    """
    PPO model for coordination environments, supporting multiple feature types.

    The policy and value heads read a feature vector assembled from up to
    three components of the coordination observation, selected by
    ``config.feature_type``:

    - ``obs``: IMPALA embedding of ``obs["base_obs"]``.
    - ``hidden``: ``obs["novice_hidden"]`` used as-is.
    - ``dist``: softmax of ``obs["novice_logits"]`` over the last dimension.

    Compound feature types concatenate their components along the last
    dimension, in the order listed in ``_FEATURE_COMPONENTS``. Only the
    observation entries needed by the configured feature type are read.

    Examples
    --------
    >>> model = ImpalaCoordPPOModel(ImpalaCoordPPOModelConfig(), env)
    >>> obs = {"base_obs": ..., "novice_hidden": ..., "novice_logits": ...}
    >>> out = model(obs)
    >>> print(out.logits.shape, out.value.shape)
    """

    config_cls = ImpalaCoordPPOModelConfig

    # Ordered components concatenated (along the last dimension) to build the
    # head input for each supported feature_type. Reordering would change the
    # feature layout, and with it the meaning of trained head weights.
    _FEATURE_COMPONENTS: Dict[str, Tuple[str, ...]] = {
        "obs": ("obs",),
        "hidden": ("hidden",),
        "hidden_obs": ("obs", "hidden"),  # embedding first, despite the name
        "dist": ("dist",),
        "hidden_dist": ("hidden", "dist"),
        "obs_dist": ("obs", "dist"),
        "obs_hidden_dist": ("obs", "hidden", "dist"),
    }

    def __init__(self, config: ImpalaCoordPPOModelConfig, env: "gym.Env") -> None:
        """
        Initialize the ImpalaCoordPPOModel.

        Parameters
        ----------
        config : ImpalaCoordPPOModelConfig
            Configuration object; ``config.feature_type`` selects the inputs.
        env : gym.Env
            The coordination environment instance. Reads
            ``env.base_env.observation_space.shape``,
            ``env.base_env.action_space.n``, ``env.action_space.n`` and (for
            feature types using ``hidden``) ``env.novice.hidden_dim``.

        Returns
        -------
        None

        Raises
        ------
        NotImplementedError
            If ``config.feature_type`` is not a supported feature type.

        Examples
        --------
        >>> model = ImpalaCoordPPOModel(ImpalaCoordPPOModelConfig(), env)
        """
        super().__init__()
        self.device = get_global_variable("device")
        # The embedder is constructed for every feature type, including ones
        # that never call it. NOTE(review): presumably this keeps the
        # parameter/state_dict layout independent of the config; confirm
        # before making it conditional.
        self.embedder = Impala(env.base_env.observation_space.shape)
        self.feature_type = config.feature_type

        components = self._feature_components()
        self.hidden_dim = sum(self._component_dim(c, env) for c in components)

        n_actions = env.action_space.n
        self.fc_policy = orthogonal_init(
            nn.Linear(self.hidden_dim, n_actions), gain=0.01
        )
        self.fc_value = orthogonal_init(nn.Linear(self.hidden_dim, 1), gain=1.0)
        self.logit_dim = n_actions

    def _feature_components(self) -> Tuple[str, ...]:
        """
        Return the ordered input components for ``self.feature_type``.

        Raises
        ------
        NotImplementedError
            If ``self.feature_type`` is not a supported feature type.
        """
        try:
            return self._FEATURE_COMPONENTS[self.feature_type]
        except KeyError:
            raise NotImplementedError(
                f"Unsupported feature_type {self.feature_type!r}; expected one "
                f"of {sorted(self._FEATURE_COMPONENTS)}"
            ) from None

    def _component_dim(self, component: str, env: "gym.Env") -> int:
        """Return the feature width contributed by one input component."""
        if component == "obs":
            return self.embedder.output_dim
        if component == "hidden":
            return env.novice.hidden_dim
        # "dist": one probability per base-environment action.
        return env.base_env.action_space.n

    def _to_tensor(self, value: Any) -> torch.Tensor:
        """Return ``value`` as a tensor: tensors pass through unchanged, numpy
        arrays become float32 tensors on ``self.device``."""
        if torch.is_tensor(value):
            return value
        return torch.from_numpy(value).float().to(self.device)

    def forward(self, obs: Dict[str, Any]) -> PPOModelOutput:
        """
        Forward pass of the ImpalaCoordPPOModel.

        Parameters
        ----------
        obs : dict
            Observation components keyed by ``"base_obs"``,
            ``"novice_hidden"`` and ``"novice_logits"``. Only the keys used by
            the configured feature type are accessed. Values may be tensors
            (used as given) or numpy arrays (converted to float32 on
            ``self.device``).

        Returns
        -------
        PPOModelOutput
            Output container with logits, value (reshaped to 1-D), and the
            assembled hidden features.

        Raises
        ------
        NotImplementedError
            If ``self.feature_type`` is not a supported feature type.

        Examples
        --------
        >>> out = model(obs)
        >>> print(out.logits.shape, out.value.shape)
        """
        parts = []
        for component in self._feature_components():
            if component == "obs":
                parts.append(self.embedder(self._to_tensor(obs["base_obs"])))
            elif component == "hidden":
                parts.append(self._to_tensor(obs["novice_hidden"]))
            else:  # "dist"
                parts.append(self._to_tensor(obs["novice_logits"]).softmax(dim=-1))

        # A single component is used directly rather than copied by torch.cat.
        hidden = parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1)
        logits = self.fc_policy(hidden)
        value = self.fc_value(hidden).reshape(-1)
        return PPOModelOutput(logits, value, hidden)