duo_ai.models
Submodules
Attributes
Classes
ImpalaCoordPPOModel | PPO model for coordination environments, supporting multiple feature types.
ImpalaPPOModel | PPO model using an IMPALA encoder for feature extraction.
Package Contents
- class duo_ai.models.ImpalaCoordPPOModel(config: ImpalaCoordPPOModelConfig, env: gym.Env)
Bases: torch.nn.Module
PPO model for coordination environments, supporting multiple feature types.
Examples
>>> model = ImpalaCoordPPOModel(ImpalaCoordPPOModelConfig(), env)
>>> obs = {"base_obs": ..., "novice_hidden": ..., "novice_logits": ...}
>>> out = model(obs)
>>> print(out.logits.shape, out.value.shape)
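The class is described as supporting multiple feature types, selected through its `feature_type` attribute. The sketch below illustrates one plausible way such a switch could choose which observation components feed the policy and value heads. The feature-type names (`"obs"`, `"hidden"`, `"hidden_obs"`, `"all"`) and the helper `select_features` are hypothetical and not part of duo_ai, and each component is assumed to already be a flat list of floats (for example, `base_obs` after it has passed through the embedder).

```python
from typing import Dict, List

# Hypothetical mapping from feature_type to observation keys; the actual
# feature types supported by duo_ai are not listed on this page.
FEATURE_KEYS: Dict[str, List[str]] = {
    "obs": ["base_obs"],
    "hidden": ["novice_hidden"],
    "hidden_obs": ["base_obs", "novice_hidden"],
    "all": ["base_obs", "novice_hidden", "novice_logits"],
}


def select_features(obs: Dict[str, List[float]], feature_type: str) -> List[float]:
    """Concatenate the flat observation components named by feature_type."""
    if feature_type not in FEATURE_KEYS:
        raise ValueError(f"unknown feature_type: {feature_type!r}")
    features: List[float] = []
    for key in FEATURE_KEYS[feature_type]:
        features.extend(obs[key])  # append this component's values in order
    return features


obs = {"base_obs": [0.1, 0.2], "novice_hidden": [0.3], "novice_logits": [1.0, -1.0]}
print(select_features(obs, "hidden_obs"))  # [0.1, 0.2, 0.3]
```

Keeping the mapping in a single table makes it easy to see which inputs each hypothetical feature type consumes and to reject unknown names early.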
- config_cls
- device
- embedder
- feature_type
- fc_policy
- fc_value
- logit_dim
- forward(obs: Dict[str, Any]) -> PPOModelOutput
Forward pass of the ImpalaCoordPPOModel.
- Parameters:
obs (dict) – Dictionary containing the observation components base_obs, novice_hidden, and novice_logits.
- Returns:
Output container with logits, value, and hidden features.
- Return type:
PPOModelOutput
Examples
>>> out = model(obs)
>>> print(out.logits.shape, out.value.shape)
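`forward` returns a `PPOModelOutput` holding logits, value, and hidden features. Assuming a discrete action space (suggested by the `logit_dim` attribute), the logits parameterize a categorical policy through a softmax. The minimal stand-in below is illustrative only: `OutputSketch` and its `hidden` field name are assumptions (this page confirms only the `logits` and `value` attributes), and it holds plain Python floats for a single observation rather than batched tensors.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class OutputSketch:
    """Hypothetical stand-in for PPOModelOutput (single observation, no batch)."""
    logits: List[float]  # one unnormalized score per discrete action
    value: float         # state-value estimate
    hidden: List[float]  # features shared by the policy and value heads


def action_probs(out: OutputSketch) -> List[float]:
    """Softmax over logits, shifted by the max for numerical stability."""
    m = max(out.logits)
    exps = [math.exp(x - m) for x in out.logits]
    total = sum(exps)
    return [e / total for e in exps]


out = OutputSketch(logits=[2.0, 0.0, 0.0], value=0.5, hidden=[0.1, 0.2])
probs = action_probs(out)
print(round(sum(probs), 6))  # 1.0
```

Subtracting the maximum logit before exponentiating leaves the probabilities unchanged but avoids overflow for large logits.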
- class duo_ai.models.ImpalaPPOModel(config: ImpalaPPOModelConfig, env: gym.Env)
Bases: torch.nn.Module
PPO model using an IMPALA encoder for feature extraction.
Examples
>>> model = ImpalaPPOModel(ImpalaPPOModelConfig(), env)
>>> obs = torch.randn(8, 3, 64, 64)
>>> out = model(obs)
>>> print(out.logits.shape, out.value.shape)
- config_cls
- device
- embedder
- fc_policy
- fc_value
- logit_dim
- forward(obs: Any) -> PPOModelOutput
Forward pass of the ImpalaPPOModel.
- Parameters:
obs (torch.Tensor or np.ndarray) – Observation input to the model, e.g. a batch of images of shape (8, 3, 64, 64) as in the class example.
- Returns:
Output container with logits, value, and hidden features.
- Return type:
PPOModelOutput
Examples
>>> out = model(obs)
>>> print(out.logits.shape, out.value.shape)
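The attribute list suggests a shared-trunk design: `embedder` (the IMPALA encoder) produces a hidden feature vector, `fc_policy` maps it to `logit_dim` action logits, and `fc_value` maps it to a scalar value. This reading is inferred from the attribute names and is not confirmed by this page. The pure-Python sketch below shows only the two linear heads; `linear` and `two_head_forward` are hypothetical helpers, and the weights are nested lists standing in for `torch.nn.Linear` modules.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def linear(x: List[float], weight: Matrix, bias: List[float]) -> List[float]:
    """Compute weight @ x + bias, where each row of weight is one output unit."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def two_head_forward(
    hidden: List[float],
    policy_w: Matrix, policy_b: List[float],
    value_w: Matrix, value_b: List[float],
) -> Tuple[List[float], float]:
    """Apply a policy head (logit_dim outputs) and a value head (1 output) to shared features."""
    logits = linear(hidden, policy_w, policy_b)
    value = linear(hidden, value_w, value_b)[0]  # value head has exactly one row
    return logits, value


hidden = [1.0, 2.0]  # stand-in for the embedder's output
logits, value = two_head_forward(
    hidden,
    policy_w=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], policy_b=[0.0, 0.0, 0.5],
    value_w=[[0.5, 0.5]], value_b=[0.0],
)
print(logits, value)  # [1.0, 2.0, 3.5] 1.5
```

Sharing one feature vector between both heads is a common actor-critic choice: the encoder is computed once per observation, and the policy and value outputs differ only in their final projection.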
- duo_ai.models.registry