duo_ai.policies.ppo
===================

.. py:module:: duo_ai.policies.ppo


Classes
-------

.. autoapisummary::

   duo_ai.policies.ppo.PPOPolicyConfig
   duo_ai.policies.ppo.PPOPolicy


Module Contents
---------------

.. py:class:: PPOPolicyConfig

   Configuration dataclass for PPOPolicy.

   :param name: Name of the policy class. Default is "ppo".
   :type name: str, optional
   :param model: Model configuration or class name. Default is "impala_coord_ppo".
   :type model: Any, optional
   :param load_path: Path to a checkpoint to load the policy weights from. Default is None.
   :type load_path: Optional[str], optional

   .. rubric:: Examples

   >>> config = PPOPolicyConfig(model="impala_coord_ppo")


   .. py:attribute:: name
      :type: str
      :value: 'ppo'


   .. py:attribute:: model
      :type: Any
      :value: 'impala_coord_ppo'


   .. py:attribute:: load_path
      :type: Optional[str]
      :value: None


   .. py:method:: __post_init__() -> None

      Post-initialization logic for PPOPolicyConfig.

      Converts string or dictionary model fields into their respective configuration objects.

      :raises IndexError: If required keys are missing in configuration dictionaries.
      :raises ValueError: If model is not a string or a dictionary.

      .. rubric:: Examples

      >>> config = PPOPolicyConfig(model="impala_coord_ppo")


.. py:class:: PPOPolicy(config: PPOPolicyConfig, env: gym.Env)

   Bases: :py:obj:`duo_ai.core.policy.Policy`

   Policy class for PPO, wrapping a model and providing action selection and parameter management.

   .. rubric:: Examples

   >>> policy = PPOPolicy(PPOPolicyConfig(), env)
   >>> obs = ...
   >>> action = policy.act(obs)


   .. py:attribute:: config_cls


   .. py:attribute:: model


   .. py:attribute:: config


   .. py:method:: reset(done: numpy.ndarray) -> None

      Reset the policy state at episode boundaries.

      :param done: Boolean array indicating which episodes in a batch require a reset.
      :type done: numpy.ndarray
      :rtype: None

      .. rubric:: Examples

      >>> policy.reset(done)


   .. py:method:: act(obs: Any, temperature: float = 1.0, return_model_output: bool = False) -> Any

      Select an action based on the observation and temperature.

      :param obs: Observation input to the policy.
      :type obs: Any
      :param temperature: Sampling temperature. If 0, selects the argmax action. Default is 1.0.
      :type temperature: float, optional
      :param return_model_output: If True, also return the model output. Default is False.
      :type return_model_output: bool, optional

      :returns: **action** -- Selected action, or ``(action, model_output)`` if ``return_model_output`` is True.
      :rtype: torch.Tensor or tuple

      .. rubric:: Examples

      >>> action = policy.act(obs)
      >>> action, model_output = policy.act(obs, return_model_output=True)


   .. py:method:: set_params(params: Dict[str, Any]) -> None

      Set the model parameters from a state dictionary.

      :param params: State dictionary of model parameters.
      :type params: dict
      :rtype: None

      .. rubric:: Examples

      >>> policy.set_params(params)


   .. py:method:: get_params() -> Dict[str, Any]

      Get the current model parameters as a state dictionary.

      :returns: State dictionary of model parameters.
      :rtype: dict

      .. rubric:: Examples

      >>> params = policy.get_params()


   .. py:method:: train() -> None

      Set the policy/model to training mode.

      :rtype: None

      .. rubric:: Examples

      >>> policy.train()


   .. py:method:: eval() -> None

      Set the policy/model to evaluation mode.

      :rtype: None

      .. rubric:: Examples

      >>> policy.eval()
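The ``temperature`` argument of ``PPOPolicy.act`` is documented only as "if 0, selects the argmax action". The sketch below illustrates one common way such a rule is implemented for a discrete action space, assuming the model produces unnormalized logits and that a positive temperature divides the logits before a softmax. That convention is an assumption, not something this page confirms, and ``select_action`` and ``action_probs`` are hypothetical helpers, not part of ``duo_ai``. The sketch uses only the Python standard library so it runs without PyTorch.

```python
import math
import random
from typing import List, Optional, Sequence


def action_probs(logits: Sequence[float], temperature: float) -> List[float]:
    """Hypothetical helper: softmax over logits / temperature (temperature > 0)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive for a softmax")
    scaled = [x / temperature for x in logits]
    # Subtract the largest value before exponentiating to avoid overflow.
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def select_action(
    logits: Sequence[float],
    temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Hypothetical helper mirroring the documented temperature rule of act()."""
    if temperature < 0:
        raise ValueError("temperature must be non-negative")
    if temperature == 0:
        # Greedy: index of the largest logit (the first one wins ties).
        return max(range(len(logits)), key=lambda i: logits[i])
    if rng is None:
        rng = random.Random()
    # Stochastic: draw one index with probability given by the tempered softmax.
    probs = action_probs(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


if __name__ == "__main__":
    logits = [1.0, 3.0, 0.5]
    print(select_action(logits, temperature=0))  # prints 1
    print(select_action(logits, temperature=1.0, rng=random.Random(0)))
```

Under this assumption, temperatures below 1 sharpen the distribution toward the argmax action and temperatures above 1 flatten it toward uniform, so ``temperature=0`` acts as the deterministic limit, which is typically what one wants for evaluation.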