duo_ai.policies.always
======================

.. py:module:: duo_ai.policies.always


Classes
-------

.. autoapisummary::

   duo_ai.policies.always.AlwaysPolicyConfig
   duo_ai.policies.always.AlwaysPolicy


Module Contents
---------------

.. py:class:: AlwaysPolicyConfig

   Configuration dataclass for AlwaysPolicy.

   :param name: Name of the policy class. Default is "always".
   :type name: str, optional
   :param agent: The agent type to always select. Options are "novice" or "expert". Default is "novice".
   :type agent: str, optional
   :param load_path: Path to a checkpoint to load. Default is None.
   :type load_path: str, optional

   .. rubric:: Examples

   >>> config = AlwaysPolicyConfig(agent="expert")


   .. py:attribute:: name
      :type: str
      :value: 'always'


   .. py:attribute:: agent
      :type: str
      :value: 'novice'


   .. py:attribute:: load_path
      :type: Optional[str]
      :value: None


.. py:class:: AlwaysPolicy(config: AlwaysPolicyConfig, env: gym.Env)

   Bases: :py:obj:`duo_ai.core.policy.Policy`

   Policy that always selects the same agent (novice or expert) for every action.

   .. rubric:: Examples

   >>> policy = AlwaysPolicy(AlwaysPolicyConfig(agent="novice"), env)
   >>> obs = ...
   >>> action = policy.act(obs)


   .. py:attribute:: config_cls


   .. py:attribute:: choice


   .. py:attribute:: device


   .. py:attribute:: config


   .. py:method:: act(obs: Any, temperature: Optional[float] = None) -> torch.Tensor

      Select the constant action for a batch of observations.

      :param obs: Batch of observations. If dict, must contain 'base_obs'.
      :type obs: dict or np.ndarray
      :param temperature: Unused. Included for API compatibility.
      :type temperature: float, optional

      :returns: Tensor of constant actions (agent indices) for the batch.
      :rtype: torch.Tensor

      :raises ValueError: If obs is not a dict or numpy array.

      .. rubric:: Examples

      >>> action = policy.act(obs)


   .. py:method:: reset(done: numpy.ndarray) -> None

      Reset the policy state at episode boundaries.

      :param done: Boolean array indicating which episodes in a batch require a reset.
      :type done: np.ndarray

      :rtype: None

      .. rubric:: Examples

      >>> policy.reset(done)


   .. py:method:: get_params() -> Dict[str, Any]

      Get the current parameters of the policy.

      :returns: Dictionary of policy parameters.
      :rtype: dict

      .. rubric:: Examples

      >>> params = policy.get_params()


   .. py:method:: set_params(params: Dict[str, Any]) -> None

      Set the parameters of the policy.

      :param params: Dictionary of policy parameters to set.
      :type params: dict

      :rtype: None

      .. rubric:: Examples

      >>> policy.set_params(params)


   .. py:method:: train() -> None

      Set the policy to training mode.

      :rtype: None

      .. rubric:: Examples

      >>> policy.train()


   .. py:method:: eval() -> None

      Set the policy to evaluation mode.

      :rtype: None

      .. rubric:: Examples

      >>> policy.eval()
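
The constant-selection behaviour documented for ``AlwaysPolicy.act`` can be illustrated with a small stand-in that uses only the standard library. This is a sketch under stated assumptions, not the ``duo_ai`` implementation: ``SketchAlwaysConfig``, ``SketchAlwaysPolicy``, and the ``AGENT_INDEX`` mapping (novice as 0, expert as 1) are hypothetical names and values, and a plain list stands in for the ``torch.Tensor`` that the real method returns.

```python
from dataclasses import dataclass
from typing import Any, List, Optional

# Assumed mapping from agent name to action index; the actual
# duo_ai mapping is not shown in this reference page.
AGENT_INDEX = {"novice": 0, "expert": 1}


@dataclass
class SketchAlwaysConfig:
    """Hypothetical stand-in mirroring the documented config fields."""

    name: str = "always"
    agent: str = "novice"
    load_path: Optional[str] = None


class SketchAlwaysPolicy:
    """Hypothetical stand-in: one fixed agent index for every observation."""

    def __init__(self, config: SketchAlwaysConfig) -> None:
        if config.agent not in AGENT_INDEX:
            raise ValueError(f"unknown agent: {config.agent!r}")
        self.config = config
        self.choice = AGENT_INDEX[config.agent]

    def act(self, obs: Any, temperature: Optional[float] = None) -> List[int]:
        # Dict observations carry the batch under 'base_obs', as documented.
        if isinstance(obs, dict):
            batch = obs["base_obs"]
        elif isinstance(obs, (list, tuple)):
            batch = obs
        else:
            raise ValueError("obs must be a dict or a sequence of observations")
        # temperature is accepted but ignored: the action never depends on it.
        return [self.choice] * len(batch)

    def reset(self, done: Any) -> None:
        # Assumption: a constant policy keeps no per-episode state to clear.
        return None


policy = SketchAlwaysPolicy(SketchAlwaysConfig(agent="expert"))
actions = policy.act({"base_obs": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]})
print(actions)  # [1, 1, 1]
```

The batch size is taken from the observations and the returned value is the same index repeated once per observation, which is why ``temperature`` has no effect on the result.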