duo_ai.policies.logit
Classes
LogitPolicyConfig | Configuration dataclass for LogitPolicy.
LogitPolicy | Policy that selects actions based on logit confidence metrics and thresholds.
Module Contents
- class duo_ai.policies.logit.LogitPolicyConfig
Configuration dataclass for LogitPolicy.
- Parameters:
name (str, optional) – Name of the policy class. Default is “logit”.
metric (str, optional) – Confidence metric to use. Default is “max_logit”.
threshold (float, optional) – Confidence threshold that determines when the expert is queried. Default is None.
temperature (float, optional) – Temperature for scaling logits. Default is None.
load_path (str, optional) – Path to a checkpoint to load. Default is None.
Examples
>>> config = LogitPolicyConfig(metric="max_prob", threshold=0.8)
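A stand-in for the configuration dataclass, useful when the package is not installed. It is a sketch that copies only the documented field names and defaults; `LogitPolicyConfigSketch` is a hypothetical name, and the real class may carry behavior not listed here.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass
class LogitPolicyConfigSketch:
    """Hypothetical stand-in mirroring the documented fields of LogitPolicyConfig."""

    name: str = "logit"
    metric: str = "max_logit"
    threshold: Optional[float] = None
    temperature: Optional[float] = None
    load_path: Optional[str] = None


# Same call shape as the documented example: override only metric and threshold.
config = LogitPolicyConfigSketch(metric="max_prob", threshold=0.8)

# dataclasses.replace returns a modified copy and leaves the original untouched.
tuned = replace(config, temperature=2.0)
```

Fields that are not passed keep their defaults, so `config.name` stays "logit" and `config.temperature` stays None.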
- name: str = 'logit'
- metric: str = 'max_logit'
- threshold: float | None = None
- temperature: float | None = None
- load_path: str | None = None
- class duo_ai.policies.logit.LogitPolicy(config: LogitPolicyConfig, env: gym.Env)
Bases: duo_ai.core.policy.Policy
Policy that selects actions based on logit confidence metrics and thresholds.
Examples
>>> policy = LogitPolicy(LogitPolicyConfig(), env)
>>> obs = ...
>>> action = policy.act(obs)
- config_cls
- config
- params
- device
- EXPERT
- act(obs: Dict[str, Any], temperature: float | None = None) → torch.Tensor
Select actions by comparing confidence scores against the configured threshold.
- Parameters:
obs (dict) – Observation dictionary containing ‘novice_logits’.
temperature (float, optional) – Unused. Included for API compatibility.
- Returns:
Tensor of selected actions for the batch, indicating for each sample whether the expert is queried.
- Return type:
torch.Tensor
Examples
>>> action = policy.act(obs)
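The thresholding step in `act` can be sketched as below. This is an illustration under stated assumptions, not the library's code: it uses numpy instead of torch, scores confidence with the maximum softmax probability, assumes the expert is queried when confidence falls below the threshold, and encodes actions with hypothetical `EXPERT = 1` / `NOVICE = 0` values (the real `EXPERT` constant may differ).

```python
import numpy as np

# Hypothetical action encoding; the real value of LogitPolicy.EXPERT may differ.
EXPERT = 1
NOVICE = 0


def act_sketch(obs: dict, threshold: float) -> np.ndarray:
    """Return EXPERT for samples whose confidence falls below `threshold`."""
    logits = np.asarray(obs["novice_logits"], dtype=np.float64)
    # Max softmax probability as the confidence score (assumed metric).
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    confidence = probs.max(axis=-1)
    # Low confidence hands control to the expert (assumed direction).
    return np.where(confidence < threshold, EXPERT, NOVICE)


# First sample is confident (p ~ 0.96); second is near-uniform (p ~ 0.36).
obs = {"novice_logits": [[4.0, 0.0, 0.0], [0.1, 0.0, 0.0]]}
actions = act_sketch(obs, threshold=0.8)
```

Here only the second, near-uniform sample falls below 0.8, so it is the one routed to the expert.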
- compute_confidence(logits: torch.Tensor) → torch.Tensor
Compute confidence scores from logits using the configured metric.
- Parameters:
logits (torch.Tensor) – Logits tensor from the policy.
- Returns:
Confidence scores for each sample in the batch.
- Return type:
torch.Tensor
- Raises:
NotImplementedError – If the configured metric is not recognized.
Examples
>>> score = policy.compute_confidence(logits)
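A sketch of how the two metric names that appear on this page ("max_logit", the default, and "max_prob", from the config example) might be computed. The exact definitions, the rule of dividing logits by `temperature` before scoring, and the use of numpy in place of torch are assumptions; `compute_confidence_sketch` is a hypothetical name.

```python
from typing import Optional

import numpy as np


def compute_confidence_sketch(
    logits: np.ndarray, metric: str = "max_logit", temperature: Optional[float] = None
) -> np.ndarray:
    """Per-sample confidence scores for a (batch, num_actions) array of logits."""
    logits = np.asarray(logits, dtype=np.float64)
    if temperature is not None:
        # Assumed temperature scaling: divide logits before scoring.
        logits = logits / temperature
    if metric == "max_logit":
        # Largest raw logit per sample.
        return logits.max(axis=-1)
    if metric == "max_prob":
        # Numerically stable softmax, then the largest probability per sample.
        shifted = logits - logits.max(axis=-1, keepdims=True)
        probs = np.exp(shifted)
        probs /= probs.sum(axis=-1, keepdims=True)
        return probs.max(axis=-1)
    # Mirrors the documented behavior for unrecognized metrics.
    raise NotImplementedError(f"Unrecognized confidence metric: {metric!r}")


logits = np.array([[2.0, 1.0, 0.5], [0.0, 0.0, 0.0]])
scores = compute_confidence_sketch(logits, metric="max_logit")
```

"max_logit" is unbounded and depends on the logit scale, while "max_prob" lies in (0, 1], which makes a fixed threshold easier to interpret across models.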
- reset(done: numpy.ndarray) → None
Reset the policy state at episode boundaries.
- Parameters:
done (numpy.ndarray) – Boolean array indicating which episodes in a batch require a reset.
- Return type:
None
Examples
>>> policy.reset(done)
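This page does not say what per-episode state LogitPolicy keeps; it may keep none, in which case `reset` can be a no-op. The sketch below uses a hypothetical per-environment counter (`EpisodeCounterSketch`, not part of the library) only to show how a boolean `done` array selects which batch entries to reset.

```python
import numpy as np


class EpisodeCounterSketch:
    """Hypothetical per-environment state, used only to illustrate a done mask."""

    def __init__(self, num_envs: int) -> None:
        self.expert_queries = np.zeros(num_envs, dtype=np.int64)

    def record(self, queried: np.ndarray) -> None:
        # Add one count for each environment that queried the expert this step.
        self.expert_queries += np.asarray(queried, dtype=np.int64)

    def reset(self, done: np.ndarray) -> None:
        # Clear state only for environments whose episode just ended.
        self.expert_queries[np.asarray(done, dtype=bool)] = 0


state = EpisodeCounterSketch(num_envs=3)
state.record(np.array([1, 1, 0]))
state.record(np.array([1, 0, 1]))
state.reset(np.array([True, False, False]))
```

Only environment 0 is cleared; the other two keep their counts because their episodes are still running.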
- get_params() → Dict[str, Any]
Get the current parameters of the policy.
- Returns:
Dictionary of policy parameters.
- Return type:
dict
Examples
>>> params = policy.get_params()
- set_params(params: Dict[str, Any]) → None
Set the parameters of the policy.
- Parameters:
params (dict) – Dictionary of policy parameters to set.
- Return type:
None
- Raises:
KeyError – If a parameter key is not recognized by the policy.
Examples
>>> policy.set_params({'threshold': 0.7})
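The `get_params` / `set_params` contract documented above (a dict round-trip, with `KeyError` for unrecognized keys) can be sketched as follows. `ParamsSketch` is hypothetical, and the parameter names "threshold" and "temperature" are assumed from the config fields; returning a copy and validating all keys before updating are design choices of this sketch, not confirmed library behavior.

```python
from typing import Any, Dict, Optional


class ParamsSketch:
    """Hypothetical holder illustrating the documented get/set contract."""

    def __init__(
        self, threshold: Optional[float] = None, temperature: Optional[float] = None
    ) -> None:
        # Assumed parameter names, taken from the config fields.
        self.params: Dict[str, Any] = {"threshold": threshold, "temperature": temperature}

    def get_params(self) -> Dict[str, Any]:
        # Return a copy so callers cannot mutate internal state by accident.
        return dict(self.params)

    def set_params(self, params: Dict[str, Any]) -> None:
        # Validate every key first so a bad key leaves the state unchanged.
        unknown = set(params) - set(self.params)
        if unknown:
            raise KeyError(f"Unrecognized parameter(s): {sorted(unknown)}")
        self.params.update(params)


policy = ParamsSketch(threshold=0.8)
policy.set_params({"threshold": 0.7})
snapshot = policy.get_params()
```

Checking all keys before calling `update` means a partially valid dict is rejected as a whole rather than applied halfway.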