duo_ai.policies.logit

Classes

LogitPolicyConfig

Configuration dataclass for LogitPolicy.

LogitPolicy

Policy that decides whether to query the expert by comparing a logit-based confidence score against a threshold.

Module Contents

class duo_ai.policies.logit.LogitPolicyConfig

Configuration dataclass for LogitPolicy.

Parameters:
  • name (str, optional) – Name of the policy class. Default is “logit”.

  • metric (str, optional) – Confidence metric to use. Default is “max_logit”.

  • threshold (float, optional) – Confidence threshold used to decide whether to query the expert. Default is None.

  • temperature (float, optional) – Temperature for scaling logits. Default is None.

  • load_path (str, optional) – Path to a checkpoint to load. Default is None.

Examples

>>> config = LogitPolicyConfig(metric="max_prob", threshold=0.8)
name: str = 'logit'
metric: str = 'max_logit'
threshold: float | None = None
temperature: float | None = None
load_path: str | None = None
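For readers who build configurations from plain dictionaries (for example, parsed YAML), the following is a minimal sketch that mirrors the documented fields and defaults. `LogitPolicyConfigSketch` and `config_from_dict` are hypothetical names used for illustration only; they are not part of duo_ai, and the unknown-key check is an assumption about how one might guard against typos.

```python
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict, Optional


@dataclass
class LogitPolicyConfigSketch:
    """Hypothetical stand-in mirroring the documented LogitPolicyConfig fields."""
    name: str = "logit"
    metric: str = "max_logit"
    threshold: Optional[float] = None
    temperature: Optional[float] = None
    load_path: Optional[str] = None


def config_from_dict(raw: Dict[str, Any]) -> LogitPolicyConfigSketch:
    """Build a config from a plain dict, rejecting keys that are not dataclass fields."""
    known = {f.name for f in fields(LogitPolicyConfigSketch)}
    unknown = set(raw) - known
    if unknown:
        raise ValueError(f"unknown config keys: {sorted(unknown)}")
    return LogitPolicyConfigSketch(**raw)


cfg = config_from_dict({"metric": "max_prob", "threshold": 0.8})
print(asdict(cfg))
# {'name': 'logit', 'metric': 'max_prob', 'threshold': 0.8, 'temperature': None, 'load_path': None}
```

Fields not supplied in the dict fall back to the documented defaults.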
class duo_ai.policies.logit.LogitPolicy(config: LogitPolicyConfig, env: gym.Env)

Bases: duo_ai.core.policy.Policy

Policy that decides whether to query the expert by comparing a logit-based confidence score against a threshold.

Examples

>>> policy = LogitPolicy(LogitPolicyConfig(), env)
>>> obs = ...
>>> action = policy.act(obs)
config_cls
config
params
device
EXPERT
act(obs: Dict[str, Any], temperature: float | None = None) → torch.Tensor

Select actions based on confidence scores and threshold.

Parameters:
  • obs (dict) – Observation dictionary containing ‘novice_logits’.

  • temperature (float, optional) – Unused. Included for API compatibility.

Returns:

Tensor of selected actions for the batch; each entry indicates whether the expert is queried.

Return type:

torch.Tensor

Examples

>>> action = policy.act(obs)
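The thresholding step of act can be pictured with the minimal NumPy sketch below. It assumes that a confidence strictly below the threshold triggers an expert query, and it uses hypothetical action codes NOVICE = 0 and EXPERT = 1; the class does expose an EXPERT attribute, but its actual value and the encoding of the non-expert action are not stated here. `select_actions` is an illustrative name, not part of duo_ai.

```python
import numpy as np

# Hypothetical action codes (assumption): the real EXPERT value may differ.
NOVICE, EXPERT = 0, 1


def select_actions(confidence: np.ndarray, threshold: float) -> np.ndarray:
    """Return EXPERT where confidence is below threshold, NOVICE elsewhere.

    Assumes low confidence is what triggers an expert query.
    """
    return np.where(confidence < threshold, EXPERT, NOVICE)


conf = np.array([0.95, 0.40, 0.81, 0.10])
actions = select_actions(conf, threshold=0.8)
print(actions)  # [0 1 0 1]
```

Because the comparison is vectorized, the whole batch is decided in one call, which matches act returning one entry per sample.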
compute_confidence(logits: torch.Tensor) → torch.Tensor

Compute confidence scores from logits using the configured metric.

Parameters:

logits (torch.Tensor) – Logits tensor from the policy.

Returns:

Confidence scores for each sample in the batch.

Return type:

torch.Tensor

Raises:

NotImplementedError – If the configured metric is not recognized.

Examples

>>> score = policy.compute_confidence(logits)
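The sketch below shows what the two metric names documented in this page, "max_logit" and "max_prob", plausibly compute for a (batch, num_actions) array. It is written in NumPy rather than torch so it runs standalone, and applying the configured temperature by dividing the logits before the metric is an assumption based on the config's "Temperature for scaling logits" description. The real method may support other metrics; any name outside the two shown raises NotImplementedError, mirroring the documented behavior.

```python
from typing import Optional

import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def compute_confidence(
    logits: np.ndarray, metric: str = "max_logit", temperature: Optional[float] = None
) -> np.ndarray:
    """Per-sample confidence for a (batch, num_actions) logits array."""
    if temperature is not None:
        # Assumption: temperature divides the logits before the metric is applied.
        logits = logits / temperature
    if metric == "max_logit":
        return logits.max(axis=-1)
    if metric == "max_prob":
        return softmax(logits).max(axis=-1)
    raise NotImplementedError(f"unrecognized metric: {metric}")


logits = np.array([[2.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
print(compute_confidence(logits, "max_logit"))  # [2. 1.]
print(compute_confidence(logits, "max_prob"))
```

Note that "max_logit" is unbounded while "max_prob" lies in (0, 1], so a threshold tuned for one metric does not transfer to the other.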
reset(done: numpy.ndarray) → None

Reset the policy state at episode boundaries.

Parameters:

done (numpy.ndarray) – Boolean array indicating which episodes in a batch require a reset.

Return type:

None

Examples

>>> policy.reset(done)
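Whether LogitPolicy keeps any per-episode state is not stated here, so the sketch below only illustrates the done-mask convention: a boolean array with one entry per environment, used to reset state selectively. `StepCounterSketch` and its step counter are hypothetical and exist purely to make the masking visible.

```python
import numpy as np


class StepCounterSketch:
    """Hypothetical per-environment state, used only to show how a done mask is applied."""

    def __init__(self, num_envs: int) -> None:
        self.steps = np.zeros(num_envs, dtype=np.int64)

    def step(self) -> None:
        self.steps += 1

    def reset(self, done: np.ndarray) -> None:
        # Boolean indexing zeroes the state only for environments whose episode ended.
        self.steps[done] = 0


state = StepCounterSketch(num_envs=4)
for _ in range(3):
    state.step()
state.reset(np.array([False, True, False, True]))
print(state.steps)  # [3 0 3 0]
```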
get_params() → Dict[str, Any]

Get the current parameters of the policy.

Returns:

Dictionary of policy parameters.

Return type:

dict

Examples

>>> params = policy.get_params()
set_params(params: Dict[str, Any]) → None

Set the parameters of the policy.

Parameters:

params (dict) – Dictionary of policy parameters to set.

Return type:

None

Raises:

KeyError – If a parameter key is not recognized by the policy.

Examples

>>> policy.set_params({'threshold': 0.7})
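A get_params/set_params round trip, for example saving a threshold before trying a new one, might look like the minimal sketch below. `ParamsSketch` is hypothetical; only the "threshold" key is taken from the example above, and returning a copy from get_params is a design assumption. Raising KeyError for unrecognized keys mirrors the documented behavior; validating every key before updating avoids leaving the object half-updated.

```python
from typing import Any, Dict, Optional


class ParamsSketch:
    """Hypothetical holder illustrating a get_params/set_params round trip."""

    def __init__(self, threshold: Optional[float] = None) -> None:
        self.params: Dict[str, Any] = {"threshold": threshold}

    def get_params(self) -> Dict[str, Any]:
        # Return a copy so callers cannot mutate internal state by accident.
        return dict(self.params)

    def set_params(self, params: Dict[str, Any]) -> None:
        # Check every key first so an invalid dict leaves the state untouched.
        for key in params:
            if key not in self.params:
                raise KeyError(f"unrecognized parameter: {key}")
        self.params.update(params)


p = ParamsSketch(threshold=0.8)
saved = p.get_params()
p.set_params({"threshold": 0.7})
print(p.get_params()["threshold"], saved["threshold"])  # 0.7 0.8
```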
train() → None

Set the policy to training mode.

Return type:

None

Examples

>>> policy.train()
eval() → None

Set the policy to evaluation mode.

Return type:

None

Examples

>>> policy.eval()
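What train and eval change inside LogitPolicy is not described here. A common pattern, similar in spirit to the training flag on torch.nn.Module, is a boolean mode flag that other methods can consult; the minimal sketch below assumes that pattern, and `ModeSketch` is a hypothetical name.

```python
class ModeSketch:
    """Hypothetical mode flag toggled by train() and eval()."""

    def __init__(self) -> None:
        self.training = True

    def train(self) -> None:
        self.training = True

    def eval(self) -> None:
        self.training = False


policy = ModeSketch()
policy.eval()
print(policy.training)  # False
```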