duo_ai.policies.logit
=====================

.. py:module:: duo_ai.policies.logit


Classes
-------

.. autoapisummary::

   duo_ai.policies.logit.LogitPolicyConfig
   duo_ai.policies.logit.LogitPolicy


Module Contents
---------------

.. py:class:: LogitPolicyConfig

   Configuration dataclass for LogitPolicy.

   :param name: Name of the policy class. Default is "logit".
   :type name: str, optional
   :param metric: Confidence metric to use. Default is "max_logit".
   :type metric: str, optional
   :param threshold: Confidence threshold for expert query. Default is None.
   :type threshold: float, optional
   :param temperature: Temperature for scaling logits. Default is None.
   :type temperature: float, optional
   :param load_path: Path to a checkpoint to load. Default is None.
   :type load_path: str, optional

   .. rubric:: Examples

   >>> config = LogitPolicyConfig(metric="max_prob", threshold=0.8)


   .. py:attribute:: name
      :type:  str
      :value: 'logit'


   .. py:attribute:: metric
      :type:  str
      :value: 'max_logit'


   .. py:attribute:: threshold
      :type:  Optional[float]
      :value: None


   .. py:attribute:: temperature
      :type:  Optional[float]
      :value: None


   .. py:attribute:: load_path
      :type:  Optional[str]
      :value: None


.. py:class:: LogitPolicy(config: LogitPolicyConfig, env: gym.Env)

   Bases: :py:obj:`duo_ai.core.policy.Policy`


   Policy that selects actions based on logit confidence metrics and thresholds.

   .. rubric:: Examples

   >>> policy = LogitPolicy(LogitPolicyConfig(), env)
   >>> obs = ...
   >>> action = policy.act(obs)


   .. py:attribute:: config_cls


   .. py:attribute:: config


   .. py:attribute:: params


   .. py:attribute:: device


   .. py:attribute:: EXPERT


   .. py:method:: act(obs: Dict[str, Any], temperature: Optional[float] = None) -> torch.Tensor

      Select actions based on confidence scores and threshold.

      :param obs: Observation dictionary containing 'novice_logits'.
      :type obs: dict
      :param temperature: Unused. Included for API compatibility.
      :type temperature: float, optional

      :returns: Tensor of selected actions (expert or not) for the batch.
      :rtype: torch.Tensor

      .. rubric:: Examples

      >>> action = policy.act(obs)


   .. py:method:: compute_confidence(logits: torch.Tensor) -> torch.Tensor

      Compute confidence scores from logits using the configured metric.

      :param logits: Logits tensor from the policy.
      :type logits: torch.Tensor

      :returns: Confidence scores for each sample in the batch.
      :rtype: torch.Tensor

      :raises NotImplementedError: If the configured metric is not recognized.

      .. rubric:: Examples

      >>> score = policy.compute_confidence(logits)


   .. py:method:: reset(done: numpy.ndarray) -> None

      Reset the policy state at episode boundaries.

      :param done: Boolean array indicating which episodes in a batch require a reset.
      :type done: numpy.ndarray

      :rtype: None

      .. rubric:: Examples

      >>> policy.reset(done)


   .. py:method:: get_params() -> Dict[str, Any]

      Get the current parameters of the policy.

      :returns: Dictionary of policy parameters.
      :rtype: dict

      .. rubric:: Examples

      >>> params = policy.get_params()


   .. py:method:: set_params(params: Dict[str, Any]) -> None

      Set the parameters of the policy.

      :param params: Dictionary of policy parameters to set.
      :type params: dict

      :rtype: None

      :raises KeyError: If a parameter key is not recognized by the policy.

      .. rubric:: Examples

      >>> policy.set_params({'threshold': 0.7})


   .. py:method:: train() -> None

      Set the policy to training mode.

      :rtype: None

      .. rubric:: Examples

      >>> policy.train()


   .. py:method:: eval() -> None

      Set the policy to evaluation mode.

      :rtype: None

      .. rubric:: Examples

      >>> policy.eval()
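
The reference above names the confidence metrics ("max_logit", "max_prob") and the threshold without showing how they might combine into an expert-or-novice decision. The following is a minimal sketch of one plausible reading, not the library's implementation. It uses plain Python lists in place of ``torch.Tensor`` so that it runs without dependencies. The action codes ``EXPERT = 1`` and ``NOVICE = 0``, the standalone function signatures, the way ``temperature`` scales the logits, and the rule "query the expert when confidence falls below the threshold" are all assumptions made for illustration.

```python
import math
from typing import List, Optional

# Hypothetical action codes; the real value of LogitPolicy.EXPERT is not documented here.
EXPERT = 1
NOVICE = 0


def softmax(row: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax over one row of temperature-scaled logits."""
    scaled = [x / temperature for x in row]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def compute_confidence(
    logits: List[List[float]],
    metric: str = "max_logit",
    temperature: Optional[float] = None,
) -> List[float]:
    """Return one confidence score per sample (assumed semantics)."""
    t = 1.0 if temperature is None else temperature
    if metric == "max_logit":
        # Largest temperature-scaled logit in each row.
        return [max(row) / t for row in logits]
    if metric == "max_prob":
        # Largest softmax probability in each row.
        return [max(softmax(row, t)) for row in logits]
    raise NotImplementedError(f"Unrecognized metric: {metric}")


def act(
    logits: List[List[float]],
    threshold: float,
    metric: str = "max_logit",
    temperature: Optional[float] = None,
) -> List[int]:
    """Assumed rule: defer to the expert when confidence is below the threshold."""
    scores = compute_confidence(logits, metric, temperature)
    return [EXPERT if score < threshold else NOVICE for score in scores]


# Stand-in for obs["novice_logits"]: two samples, three actions each.
novice_logits = [[2.0, 0.0, 0.0], [0.1, 0.0, 0.0]]
print(act(novice_logits, threshold=0.5, metric="max_prob"))  # prints [0, 1]
```

In this sketch the first sample has a peaked distribution (maximum probability about 0.79) and stays with the novice, while the second is nearly uniform (about 0.36) and is routed to the expert. Under the same assumption, raising the threshold sends more samples to the expert.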