duo_ai.models.ppo
=================

.. py:module:: duo_ai.models.ppo


Classes
-------

.. autoapisummary::

   duo_ai.models.ppo.PPOModelOutput
   duo_ai.models.ppo.ImpalaPPOModelConfig
   duo_ai.models.ppo.ImpalaPPOModel
   duo_ai.models.ppo.ImpalaCoordPPOModelConfig
   duo_ai.models.ppo.ImpalaCoordPPOModel


Module Contents
---------------

.. py:class:: PPOModelOutput

   Output container for the forward pass of a PPO model.

   :param logits: Raw action logits produced by the policy head.
   :type logits: torch.Tensor
   :param value: Value-function prediction produced by the value head.
   :type value: torch.Tensor
   :param hidden: Hidden feature representation computed by the model.
   :type hidden: torch.Tensor

   .. rubric:: Examples

   >>> output = PPOModelOutput(logits, value, hidden)


   .. py:attribute:: logits
      :type: torch.Tensor


   .. py:attribute:: value
      :type: torch.Tensor


   .. py:attribute:: hidden
      :type: torch.Tensor


.. py:class:: ImpalaPPOModelConfig

   Configuration dataclass for ``ImpalaPPOModel``.

   :param name: Identifier of the model. Default is ``"impala_ppo"``.
   :type name: str, optional

   .. rubric:: Examples

   >>> config = ImpalaPPOModelConfig()


   .. py:attribute:: name
      :type: str
      :value: 'impala_ppo'


.. py:class:: ImpalaPPOModel(config: ImpalaPPOModelConfig, env: gym.Env)

   Bases: :py:obj:`torch.nn.Module`

   PPO model that uses an IMPALA encoder for feature extraction.

   .. rubric:: Examples

   >>> model = ImpalaPPOModel(ImpalaPPOModelConfig(), env)
   >>> obs = torch.randn(8, 3, 64, 64)
   >>> out = model(obs)
   >>> print(out.logits.shape, out.value.shape)


   .. py:attribute:: config_cls


   .. py:attribute:: device


   .. py:attribute:: embedder


   .. py:attribute:: hidden_dim
      :value: 256


   .. py:attribute:: fc_policy


   .. py:attribute:: fc_value


   .. py:attribute:: logit_dim


   .. py:method:: forward(obs: Any) -> PPOModelOutput

      Forward pass of ``ImpalaPPOModel``.

      :param obs: Observation passed to the model.
      :type obs: torch.Tensor or np.ndarray

      :returns: Output container with logits, value, and hidden features.
      :rtype: PPOModelOutput

      .. rubric:: Examples

      >>> out = model(obs)
      >>> print(out.logits.shape, out.value.shape)


.. py:class:: ImpalaCoordPPOModelConfig

   Configuration dataclass for ``ImpalaCoordPPOModel``.

   :param name: Identifier of the model. Default is ``"impala_coord_ppo"``.
   :type name: str, optional
   :param feature_type: Feature representation to use. Options are
      ``"obs"``, ``"hidden"``, ``"hidden_obs"``, ``"dist"``, ``"hidden_dist"``,
      ``"obs_dist"``, and ``"obs_hidden_dist"``. Default is ``"obs"``.
   :type feature_type: str, optional

   .. rubric:: Examples

   >>> config = ImpalaCoordPPOModelConfig(feature_type="obs_hidden_dist")


   .. py:attribute:: name
      :type: str
      :value: 'impala_coord_ppo'


   .. py:attribute:: feature_type
      :type: str
      :value: 'obs'


.. py:class:: ImpalaCoordPPOModel(config: ImpalaCoordPPOModelConfig, env: gym.Env)

   Bases: :py:obj:`torch.nn.Module`

   PPO model for coordination environments; the input features it uses are selected by ``config.feature_type``.

   .. rubric:: Examples

   >>> model = ImpalaCoordPPOModel(ImpalaCoordPPOModelConfig(), env)
   >>> obs = {"base_obs": ..., "novice_hidden": ..., "novice_logits": ...}
   >>> out = model(obs)
   >>> print(out.logits.shape, out.value.shape)


   .. py:attribute:: config_cls


   .. py:attribute:: device


   .. py:attribute:: embedder


   .. py:attribute:: feature_type


   .. py:attribute:: fc_policy


   .. py:attribute:: fc_value


   .. py:attribute:: logit_dim


   .. py:method:: forward(obs: Dict[str, Any]) -> PPOModelOutput

      Forward pass of ``ImpalaCoordPPOModel``.

      :param obs: Dictionary of observation components (``base_obs``, ``novice_hidden``, ``novice_logits``).
      :type obs: dict

      :returns: Output container with logits, value, and hidden features.
      :rtype: PPOModelOutput

      .. rubric:: Examples

      >>> out = model(obs)
      >>> print(out.logits.shape, out.value.shape)
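
The ``feature_type`` options read as underscore-joined combinations of three inputs. The following is a minimal, hypothetical sketch of how such a string could select observation components. It assumes that ``obs`` draws on ``base_obs``, ``hidden`` on ``novice_hidden``, and ``dist`` on a softmax over ``novice_logits``, and that the selected parts are concatenated in the order they appear in the name. This mapping, the helper names (``feature_components``, ``softmax``, ``assemble_features``), and the flat-list inputs are illustrative assumptions, not the actual implementation of ``ImpalaCoordPPOModel``, which embeds ``base_obs`` with its IMPALA encoder rather than flattening it.

```python
import math
from typing import Dict, List, Sequence

# Feature types listed in the ImpalaCoordPPOModelConfig docstring.
VALID_FEATURE_TYPES = (
    "obs", "hidden", "hidden_obs", "dist",
    "hidden_dist", "obs_dist", "obs_hidden_dist",
)

# ASSUMPTION (hypothetical): mapping from feature-type tokens to observation keys.
TOKEN_TO_KEY = {
    "obs": "base_obs",
    "hidden": "novice_hidden",
    "dist": "novice_logits",
}


def feature_components(feature_type: str) -> List[str]:
    """Return the observation keys a feature type draws on, in name order."""
    if feature_type not in VALID_FEATURE_TYPES:
        raise ValueError(f"unknown feature_type: {feature_type!r}")
    return [TOKEN_TO_KEY[token] for token in feature_type.split("_")]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: shift by the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def assemble_features(obs: Dict[str, Sequence[float]], feature_type: str) -> List[float]:
    """Concatenate flat feature vectors for the selected components (hypothetical)."""
    parts: List[float] = []
    for key in feature_components(feature_type):
        values = list(obs[key])
        if key == "novice_logits":
            # "dist" is assumed to mean the novice's action distribution.
            values = softmax(values)
        parts.extend(values)
    return parts


if __name__ == "__main__":
    obs = {"base_obs": [1.0, 2.0], "novice_hidden": [3.0], "novice_logits": [0.0, 0.0]}
    print(feature_components("obs_hidden_dist"))
    print(assemble_features(obs, "obs_dist"))
```

Under these assumptions, ``feature_type="obs_dist"`` yields the flattened ``base_obs`` values followed by a probability vector over the novice's actions. Validating the string up front means a typo in the config fails immediately rather than at the first forward pass.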