duo_ai.models.ppo
Classes
PPOModelOutput | Output container for PPO model forward pass.
ImpalaPPOModelConfig | Configuration dataclass for ImpalaPPOModel.
ImpalaPPOModel | PPO model using an IMPALA encoder for feature extraction.
ImpalaCoordPPOModelConfig | Configuration dataclass for ImpalaCoordPPOModel.
ImpalaCoordPPOModel | PPO model for coordination environments, supporting multiple feature types.
Module Contents
- class duo_ai.models.ppo.PPOModelOutput
Output container for PPO model forward pass.
- Parameters:
logits (torch.Tensor) – The raw action logits output by the policy head.
value (torch.Tensor) – The value function prediction output by the value head.
hidden (torch.Tensor) – The hidden feature representation from the model.
Examples
>>> output = PPOModelOutput(logits, value, hidden)
- logits: torch.Tensor
- value: torch.Tensor
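A policy head's raw logits are usually normalized with a softmax before an action is chosen. The sketch below illustrates that step under some assumptions: `OutputSketch` is a hypothetical stand-in for PPOModelOutput, plain Python lists replace tensors so the snippet runs without PyTorch, and a single observation is used instead of a batch. It is not the library's own sampling code.

```python
import math
from typing import List, NamedTuple


class OutputSketch(NamedTuple):
    """Hypothetical stand-in for PPOModelOutput; lists replace tensors."""
    logits: List[float]
    value: float
    hidden: List[float]


def action_probs(logits: List[float]) -> List[float]:
    """Softmax over unnormalized logits, shifted by the max for stability."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Example output for one observation with three discrete actions.
out = OutputSketch(logits=[2.0, 0.0, -1.0], value=0.5, hidden=[0.1, 0.2])
probs = action_probs(out.logits)

# Greedy action: the index with the highest probability.
greedy_action = max(range(len(probs)), key=probs.__getitem__)
```

With real tensors, the same normalization would typically be done with a batched softmax over the last dimension.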
- class duo_ai.models.ppo.ImpalaPPOModelConfig
Configuration dataclass for ImpalaPPOModel.
- Parameters:
name (str, optional) – Name of the model class. Default is “impala_ppo”.
Examples
>>> config = ImpalaPPOModelConfig()
- name: str = 'impala_ppo'
- class duo_ai.models.ppo.ImpalaPPOModel(config: ImpalaPPOModelConfig, env: gym.Env)
Bases: torch.nn.Module
PPO model using an IMPALA encoder for feature extraction.
Examples
>>> model = ImpalaPPOModel(ImpalaPPOModelConfig(), env)
>>> obs = torch.randn(8, 3, 64, 64)
>>> out = model(obs)
>>> print(out.logits.shape, out.value.shape)
- config_cls
- device
- embedder
- fc_policy
- fc_value
- logit_dim
- forward(obs: Any) → PPOModelOutput
Forward pass of the ImpalaPPOModel.
- Parameters:
obs (torch.Tensor or np.ndarray) – Observation input to the model.
- Returns:
Output container with logits, value, and hidden features.
- Return type:
PPOModelOutput
Examples
>>> out = model(obs)
>>> print(out.logits.shape, out.value.shape)
- class duo_ai.models.ppo.ImpalaCoordPPOModelConfig
Configuration dataclass for ImpalaCoordPPOModel.
- Parameters:
name (str, optional) – Name of the model class. Default is “impala_coord_ppo”.
feature_type (str, optional) – Type of feature representation to use. Options include: “obs”, “hidden”, “hidden_obs”, “dist”, “hidden_dist”, “obs_dist”, “obs_hidden_dist”. Default is “obs”.
Examples
>>> config = ImpalaCoordPPOModelConfig(feature_type="obs_hidden_dist")
- name: str = 'impala_coord_ppo'
- feature_type: str = 'obs'
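The documented feature_type options read like underscore-joined combinations of three components: obs, hidden, and dist. The sketch below splits an option into those components and maps each one to an observation key. The mapping to base_obs, novice_hidden, and novice_logits is an assumption based on the keys named in the ImpalaCoordPPOModel examples, and `required_keys` is a hypothetical helper, not part of duo_ai.

```python
from typing import Dict, List

# Documented feature_type options, copied from the parameter description.
FEATURE_TYPES = (
    "obs", "hidden", "hidden_obs", "dist",
    "hidden_dist", "obs_dist", "obs_hidden_dist",
)

# Assumed mapping from name component to observation key (not confirmed
# by the source documentation).
COMPONENT_TO_KEY: Dict[str, str] = {
    "obs": "base_obs",
    "hidden": "novice_hidden",
    "dist": "novice_logits",
}


def required_keys(feature_type: str) -> List[str]:
    """Return the observation keys a feature_type would draw on."""
    if feature_type not in FEATURE_TYPES:
        raise ValueError(f"unknown feature_type: {feature_type!r}")
    # Each underscore-separated part names one component.
    return [COMPONENT_TO_KEY[part] for part in feature_type.split("_")]


keys_all = required_keys("obs_hidden_dist")
keys_default = required_keys("obs")
```

Rejecting unknown strings early keeps a typo in the config from surfacing later as a confusing shape error.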
- class duo_ai.models.ppo.ImpalaCoordPPOModel(config: ImpalaCoordPPOModelConfig, env: gym.Env)
Bases: torch.nn.Module
PPO model for coordination environments, supporting multiple feature types.
Examples
>>> model = ImpalaCoordPPOModel(ImpalaCoordPPOModelConfig(), env)
>>> obs = {"base_obs": ..., "novice_hidden": ..., "novice_logits": ...}
>>> out = model(obs)
>>> print(out.logits.shape, out.value.shape)
- config_cls
- device
- embedder
- feature_type
- fc_policy
- fc_value
- logit_dim
- forward(obs: Dict[str, Any]) → PPOModelOutput
Forward pass of the ImpalaCoordPPOModel.
- Parameters:
obs (dict) – Dictionary containing observation components (base_obs, novice_hidden, novice_logits).
- Returns:
Output container with logits, value, and hidden features.
- Return type:
PPOModelOutput
Examples
>>> out = model(obs)
>>> print(out.logits.shape, out.value.shape)
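Because forward takes a dictionary, a missing component is easy to overlook until the call fails. The sketch below checks an observation dictionary for the three documented keys before it is passed to the model. `missing_keys` is a hypothetical helper, and whether the model needs all three keys for every feature_type is not stated in the documentation, so treat the full check as an assumption.

```python
from typing import Any, Dict, Iterable, List

# Keys named in the forward() parameter description.
EXPECTED_KEYS = ("base_obs", "novice_hidden", "novice_logits")


def missing_keys(
    obs: Dict[str, Any], expected: Iterable[str] = EXPECTED_KEYS
) -> List[str]:
    """List the expected keys that are absent from the observation dict."""
    return [key for key in expected if key not in obs]


# Placeholder values; real entries would be tensors or arrays.
obs = {"base_obs": [[0.0]], "novice_hidden": [0.0]}
problems = missing_keys(obs)
```

A caller could raise a KeyError when `problems` is non-empty, which surfaces the gap before the forward pass runs.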