@dataclass
class LogitPolicyConfig:
    """
    Settings consumed by ``LogitPolicy``.

    Parameters
    ----------
    name : str, optional
        Identifier of the policy class. Default is "logit".
    metric : str, optional
        Name of the confidence metric to compute. Default is "max_logit".
    threshold : float, optional
        Confidence cut-off used when deciding whether to query the expert.
        Default is None.
    temperature : float, optional
        Temperature applied when scaling logits. Default is None.
    load_path : str, optional
        Location of a checkpoint to load. Default is None.

    Examples
    --------
    >>> config = LogitPolicyConfig(metric="max_prob", threshold=0.8)
    """

    name: str = "logit"
    metric: str = "max_logit"
    threshold: Optional[float] = None
    temperature: Optional[float] = None
    load_path: Optional[str] = None
class LogitPolicy(Policy):
    """
    Policy that selects actions based on logit confidence metrics and thresholds.

    A confidence score is computed from the novice logits; whenever the score
    falls below the configured threshold the policy emits ``env.EXPERT``,
    otherwise it emits ``1 - env.EXPERT``.

    Examples
    --------
    >>> policy = LogitPolicy(LogitPolicyConfig(threshold=0.5), env)
    >>> obs = ...
    >>> action = policy.act(obs)
    """

    config_cls = LogitPolicyConfig

    def __init__(self, config: LogitPolicyConfig, env: "gym.Env") -> None:
        """
        Initialize the LogitPolicy.

        Parameters
        ----------
        config : LogitPolicyConfig
            Configuration object for the policy.
        env : gym.Env
            The environment instance; its ``EXPERT`` attribute is stored as
            the action value that means "query the expert".

        Returns
        -------
        None

        Examples
        --------
        >>> policy = LogitPolicy(LogitPolicyConfig(), env)
        """
        self.config = config
        # Tunable parameters live in a dict so that get_params / set_params
        # can read and overwrite them after construction.
        self.params = {
            "threshold": config.threshold,
            "temperature": config.temperature,
        }
        self.device = get_global_variable("device")
        self.EXPERT = env.EXPERT

    def act(
        self, obs: Dict[str, Any], temperature: Optional[float] = None
    ) -> torch.Tensor:
        """
        Select actions based on confidence scores and threshold.

        Parameters
        ----------
        obs : dict
            Observation dictionary containing 'novice_logits' (a tensor, or a
            numpy array that is converted to a float tensor on ``self.device``).
        temperature : float, optional
            Unused. Included for API compatibility.

        Returns
        -------
        torch.Tensor
            Tensor of selected actions (expert or not) for the batch.

        Raises
        ------
        ValueError
            If no threshold has been configured (it is still None).

        Examples
        --------
        >>> action = policy.act(obs)
        """
        threshold = self.params["threshold"]
        if threshold is None:
            # Comparing a tensor against None fails with an opaque TypeError
            # deep inside torch; fail early with an actionable message.
            raise ValueError(
                "LogitPolicy threshold is None; set it via the config or "
                "set_params before calling act()"
            )

        logits = obs["novice_logits"]
        if not torch.is_tensor(logits):
            logits = torch.from_numpy(logits).to(self.device).float()

        score = self.compute_confidence(logits)

        # query expert when confidence score < threshold
        # NOTE(review): `1 - self.EXPERT` presumes EXPERT is 0 or 1 - confirm
        # against the environment definition.
        return torch.where(score < threshold, self.EXPERT, 1 - self.EXPERT)

    def compute_confidence(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Compute confidence scores from logits using the configured metric.

        Logits are divided by the ``temperature`` parameter first when one is
        set; a temperature of None leaves the logits unscaled.

        Parameters
        ----------
        logits : torch.Tensor
            Logits tensor from the policy. The last dimension indexes classes.

        Returns
        -------
        torch.Tensor
            Confidence scores for each sample in the batch.

        Raises
        ------
        NotImplementedError
            If the configured metric is not recognized.

        Examples
        --------
        >>> score = policy.compute_confidence(logits)
        """
        # NOTE: higher = more confident
        metric = self.config.metric

        temperature = self.params["temperature"]
        if temperature is not None:
            # The default temperature is None; dividing by it would raise.
            logits = logits / temperature

        if metric == "max_logit":
            return logits.max(dim=-1).values

        if metric == "max_prob":
            return logits.softmax(dim=-1).max(dim=-1).values

        if metric == "margin":
            if logits.size(-1) > 1:
                # Multi-class case: gap between the two largest probabilities.
                top2 = logits.softmax(dim=-1).topk(2, dim=-1).values
                # Ellipsis indexing keeps this valid for any number of
                # leading (batch) dimensions, not only 2-D input.
                return top2[..., 0] - top2[..., 1]
            # Binary case when logits has shape (B, 1)
            prob = logits.sigmoid().squeeze(-1)
            return torch.abs(2 * prob - 1)

        if metric == "entropy":
            # NOTE: we compute NEGATIVE entropy so that higher = more confident
            return -Categorical(logits=logits).entropy()

        if metric == "energy":
            return logits.logsumexp(dim=-1)

        raise NotImplementedError(f"Unrecognized metric: {metric}")

    def reset(self, done: "numpy.ndarray") -> None:
        """
        Reset the policy state at episode boundaries.

        This policy keeps no per-episode state, so the call is a no-op.

        Parameters
        ----------
        done : numpy.ndarray
            Boolean array indicating which episodes in a batch require a reset.

        Returns
        -------
        None

        Examples
        --------
        >>> policy.reset(done)
        """
        pass

    def get_params(self) -> Dict[str, Any]:
        """
        Get the current parameters of the policy.

        Returns
        -------
        dict
            Copy of the policy parameters; mutating it does not affect the
            policy.

        Examples
        --------
        >>> params = policy.get_params()
        """
        return dc(self.params)

    def set_params(self, params: Dict[str, Any]) -> None:
        """
        Set the parameters of the policy.

        All keys are validated before anything is written, so a rejected call
        leaves the existing parameters untouched.

        Parameters
        ----------
        params : dict
            Dictionary of policy parameters to set.

        Returns
        -------
        None

        Raises
        ------
        KeyError
            If a parameter key is not recognized by the policy.

        Examples
        --------
        >>> policy.set_params({'threshold': 0.7})
        """
        # Validate first: mixing validation with assignment would leave the
        # policy half-updated when a later key turns out to be unknown.
        for k in params:
            if k not in self.params:
                raise KeyError(f"Parameter {k} not recognized in LogitPolicy")

        for k, v in params.items():
            self.params[k] = dc(v)

    def train(self) -> None:
        """
        Set the policy to training mode.

        No-op for this policy.

        Returns
        -------
        None

        Examples
        --------
        >>> policy.train()
        """
        pass

    def eval(self) -> None:
        """
        Set the policy to evaluation mode.

        No-op for this policy.

        Returns
        -------
        None

        Examples
        --------
        >>> policy.eval()
        """
        pass