duo_ai.policies.ppo
Classes
PPOPolicyConfig | Configuration dataclass for PPOPolicy.
PPOPolicy | Policy class for PPO, wrapping a model and providing action selection and parameter management.
Module Contents
- class duo_ai.policies.ppo.PPOPolicyConfig
Configuration dataclass for PPOPolicy.
- Parameters:
name (str, optional) – Name of the policy class. Default is “ppo”.
model (Any, optional) – Model configuration or class name. Default is “impala_coord_ppo”.
load_path (Optional[str], optional) – Path to a checkpoint to load the policy weights from. Default is None.
Examples
>>> config = PPOPolicyConfig(model="impala_coord_ppo")
- name: str = 'ppo'
- model: Any = 'impala_coord_ppo'
- load_path: str | None = None
- __post_init__() -> None
Post-initialization logic for PPOPolicyConfig.
Converts string or dictionary model fields into their respective configuration objects.
- Raises:
IndexError – If required keys are missing in configuration dictionaries.
ValueError – If model is not a string or a dictionary.
Examples
>>> config = PPOPolicyConfig(model="impala_coord_ppo")
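The conversion performed in __post_init__ can be pictured with a stand-alone sketch. SketchModelConfig, SketchPolicyConfig, and the "name" key are hypothetical stand-ins chosen for illustration; the real model configuration classes and dictionary keys are not documented here and may differ.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class SketchModelConfig:
    """Hypothetical model config; not part of duo_ai."""
    name: str
    kwargs: dict = field(default_factory=dict)


@dataclass
class SketchPolicyConfig:
    """Hypothetical stand-in that mirrors the documented fields."""
    name: str = "ppo"
    model: Any = "impala_coord_ppo"
    load_path: Optional[str] = None

    def __post_init__(self) -> None:
        if isinstance(self.model, str):
            # A bare string is treated as the model name.
            self.model = SketchModelConfig(name=self.model)
        elif isinstance(self.model, dict):
            # Assumed layout: a required "name" key plus optional settings.
            if "name" not in self.model:
                raise IndexError("model dict is missing the 'name' key")
            settings = {k: v for k, v in self.model.items() if k != "name"}
            self.model = SketchModelConfig(name=self.model["name"], kwargs=settings)
        else:
            raise ValueError("model must be a string or a dictionary")
```

In this sketch, SketchPolicyConfig() and SketchPolicyConfig(model={"name": "impala_coord_ppo"}) produce the same model object, while SketchPolicyConfig(model=3) raises ValueError.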
- class duo_ai.policies.ppo.PPOPolicy(config: PPOPolicyConfig, env: gym.Env)
Bases: duo_ai.core.policy.Policy
Policy class for PPO, wrapping a model and providing action selection and parameter management.
Examples
>>> policy = PPOPolicy(PPOPolicyConfig(), env)
>>> obs = ...
>>> action = policy.act(obs)
- config_cls
- model
- config
- reset(done: numpy.ndarray) -> None
Reset the policy state at episode boundaries.
- Parameters:
done (numpy.ndarray) – Boolean array indicating which episodes in a batch require a reset.
- Return type:
None
Examples
>>> policy.reset(done)
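A minimal sketch of a masked reset, assuming the policy keeps one state vector per batched environment. The documentation does not say what state, if any, PPOPolicy holds, so SketchRecurrentState is purely hypothetical. It uses plain Python lists to stay dependency-free; with a numpy.ndarray mask the same idea would normally be vectorized.

```python
from typing import List, Sequence


class SketchRecurrentState:
    """Hypothetical per-environment state; not duo_ai's implementation."""

    def __init__(self, batch_size: int, hidden_size: int) -> None:
        # One state vector per environment in the batch.
        self.hidden: List[List[float]] = [
            [0.0] * hidden_size for _ in range(batch_size)
        ]

    def reset(self, done: Sequence[bool]) -> None:
        # Zero the state only for environments whose episode just ended;
        # environments still mid-episode keep their state.
        for i, finished in enumerate(done):
            if finished:
                self.hidden[i] = [0.0] * len(self.hidden[i])
```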
- act(obs: Any, temperature: float = 1.0, return_model_output: bool = False) -> Any
Select an action based on the observation and temperature.
- Parameters:
obs (Any) – Observation input to the policy.
temperature (float, optional) – Sampling temperature. If 0, selects the argmax action. Default is 1.0.
return_model_output (bool, optional) – If True, also return the model output. Default is False.
- Returns:
action – Selected action, or (action, model_output) if return_model_output is True.
- Return type:
torch.Tensor or tuple
Examples
>>> action = policy.act(obs)
>>> action, model_output = policy.act(obs, return_model_output=True)
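The effect of temperature can be illustrated with a stand-alone sketch. It assumes a discrete action space where the model emits one unnormalized logit per action, which this page does not confirm; select_action is a hypothetical helper, not a duo_ai function.

```python
import math
import random
from typing import Optional, Sequence


def select_action(
    logits: Sequence[float],
    temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Hypothetical helper: pick an action index from unnormalized logits."""
    if temperature == 0:
        # Greedy choice: index of the largest logit.
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    # Dividing by the temperature sharpens (<1) or flattens (>1) the distribution.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - peak) for x in scaled]
    # random.choices normalizes the weights internally.
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]
```

With temperature=0 the helper is deterministic. As the temperature grows, the sampling distribution approaches uniform over the actions.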
- set_params(params: Dict[str, Any]) -> None
Set the model parameters from a state dictionary.
- Parameters:
params (dict) – State dictionary of model parameters.
- Return type:
None
Examples
>>> policy.set_params(params)
- get_params() -> Dict[str, Any]
Get the current model parameters as a state dictionary.
- Returns:
State dictionary of model parameters.
- Return type:
dict
Examples
>>> params = policy.get_params()
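A common use of the get_params/set_params pair is copying weights from one policy instance to another. The sketch below shows that round trip with a hypothetical SketchPolicy whose weights are plain lists. Copying on read and write is an assumption made for the example; this page does not state whether PPOPolicy returns copies or references.

```python
import copy
from typing import Any, Dict


class SketchPolicy:
    """Hypothetical stand-in exposing the same get_params/set_params pair."""

    def __init__(self, weights: Dict[str, Any]) -> None:
        self._weights = weights

    def get_params(self) -> Dict[str, Any]:
        # Assumption: return a copy so callers cannot mutate live weights.
        return copy.deepcopy(self._weights)

    def set_params(self, params: Dict[str, Any]) -> None:
        # Assumption: store a copy so later edits to `params` have no effect.
        self._weights = copy.deepcopy(params)


learner = SketchPolicy({"layer.weight": [0.5, -0.2], "layer.bias": [0.1]})
actor = SketchPolicy({"layer.weight": [0.0, 0.0], "layer.bias": [0.0]})
actor.set_params(learner.get_params())  # sync the actor to the learner
```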