duo_ai.models
=============

.. py:module:: duo_ai.models


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/duo_ai/models/impala/index
   /autoapi/duo_ai/models/ppo/index


Attributes
----------

.. autoapisummary::

   duo_ai.models.registry


Classes
-------

.. autoapisummary::

   duo_ai.models.ImpalaCoordPPOModel
   duo_ai.models.ImpalaPPOModel


Package Contents
----------------

.. py:class:: ImpalaCoordPPOModel(config: ImpalaCoordPPOModelConfig, env: gym.Env)

   Bases: :py:obj:`torch.nn.Module`

   PPO model for coordination environments, supporting multiple feature types.

   .. rubric:: Examples

   >>> model = ImpalaCoordPPOModel(ImpalaCoordPPOModelConfig(), env)
   >>> obs = {"base_obs": ..., "novice_hidden": ..., "novice_logits": ...}
   >>> out = model(obs)
   >>> print(out.logits.shape, out.value.shape)


   .. py:attribute:: config_cls


   .. py:attribute:: device


   .. py:attribute:: embedder


   .. py:attribute:: feature_type


   .. py:attribute:: fc_policy


   .. py:attribute:: fc_value


   .. py:attribute:: logit_dim


   .. py:method:: forward(obs: Dict[str, Any]) -> PPOModelOutput

      Forward pass of the ImpalaCoordPPOModel.

      :param obs: Dictionary containing observation components (base_obs, novice_hidden, novice_logits).
      :type obs: dict

      :returns: Output container with logits, value, and hidden features.
      :rtype: PPOModelOutput

      .. rubric:: Examples

      >>> out = model(obs)
      >>> print(out.logits.shape, out.value.shape)


.. py:class:: ImpalaPPOModel(config: ImpalaPPOModelConfig, env: gym.Env)

   Bases: :py:obj:`torch.nn.Module`

   PPO model using an IMPALA encoder for feature extraction.

   .. rubric:: Examples

   >>> model = ImpalaPPOModel(ImpalaPPOModelConfig(), env)
   >>> obs = torch.randn(8, 3, 64, 64)
   >>> out = model(obs)
   >>> print(out.logits.shape, out.value.shape)


   .. py:attribute:: config_cls


   .. py:attribute:: device


   .. py:attribute:: embedder


   .. py:attribute:: hidden_dim
      :value: 256


   .. py:attribute:: fc_policy


   .. py:attribute:: fc_value


   .. py:attribute:: logit_dim


   .. py:method:: forward(obs: Any) -> PPOModelOutput

      Forward pass of the ImpalaPPOModel.

      :param obs: Observation input to the model.
      :type obs: torch.Tensor or np.ndarray

      :returns: Output container with logits, value, and hidden features.
      :rtype: PPOModelOutput

      .. rubric:: Examples

      >>> out = model(obs)
      >>> print(out.logits.shape, out.value.shape)


.. py:data:: registry
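
The ``forward`` method of ``ImpalaCoordPPOModel`` takes a dictionary whose documented components are ``base_obs``, ``novice_hidden``, and ``novice_logits``. The sketch below shows one way a caller might check such a dictionary before passing it to the model. The helper ``missing_obs_keys`` and the constant ``COORD_OBS_KEYS`` are hypothetical names and not part of ``duo_ai``. It also assumes all three keys are required, which may not hold for every ``feature_type``, so the required keys are a parameter.

```python
from typing import Any, Dict, Iterable, List

# Keys named in the documented description of the `obs` parameter.
# Assumption: all three are needed; a given feature_type may use fewer.
COORD_OBS_KEYS = ("base_obs", "novice_hidden", "novice_logits")


def missing_obs_keys(
    obs: Dict[str, Any], required: Iterable[str] = COORD_OBS_KEYS
) -> List[str]:
    """Return the required keys absent from ``obs`` (hypothetical helper)."""
    return [key for key in required if key not in obs]


# Placeholder values stand in for real tensors; only the keys are checked.
complete = {"base_obs": 0, "novice_hidden": 0, "novice_logits": 0}
partial = {"base_obs": 0}

print(missing_obs_keys(complete))  # []
print(missing_obs_keys(partial))   # ['novice_hidden', 'novice_logits']
```

Returning the list of missing keys, instead of raising immediately, leaves the caller free to decide whether an incomplete observation is an error.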