1from typing import List, Optional, Tuple 2 3from torch._inductor.autoheuristic.autoheuristic_utils import ( 4 AHContext, 5 AHMetadata, 6 Choice, 7) 8 9 10class LearnedHeuristic: 11 """ 12 LearnedHeuristic is a base class for all learned heuristics. 13 """ 14 15 def __init__(self) -> None: 16 pass 17 18 def check_precondition( 19 self, 20 metadata: AHMetadata, 21 context: AHContext, 22 ) -> bool: 23 return True 24 25 def get_decision( 26 self, context: AHContext, choices: List[Choice] 27 ) -> Optional[Choice]: 28 return None 29 30 def get_confidence_threshold(self) -> float: 31 return 1.0 32 33 def get_name(self) -> str: 34 return "" 35 36 def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]: 37 return None 38 39 40class LearnedHeuristicRegression(LearnedHeuristic): 41 def __init__(self) -> None: 42 super().__init__() 43 44 def get_feedback(self, context: AHContext, choice: Choice) -> float: 45 return 1.0 46 47 def get_decision( 48 self, context: AHContext, choices: List[Choice] 49 ) -> Optional[Choice]: 50 choice2feedback = {} 51 for choice in choices: 52 predicted_feedback = self.get_feedback(context, choice) 53 choice2feedback[choice] = predicted_feedback 54 sorted_choices_feedback = sorted(choice2feedback.items(), key=lambda t: t[1]) 55 highest_feedback = sorted_choices_feedback[-1][1] 56 second_highest_feedback = sorted_choices_feedback[-2][1] 57 if highest_feedback / second_highest_feedback > self.get_confidence_threshold(): 58 return sorted_choices_feedback[-1][0] 59 # We are not sure which choice is the best one 60 return None 61 62 63class LearnedHeuristicDecision(LearnedHeuristic): 64 def __init__(self) -> None: 65 super().__init__() 66 67 def get_choice(self, idx: int) -> Optional[str]: 68 return None 69 70 def get_decision( 71 self, context: AHContext, choices: List[Choice] 72 ) -> Optional[Choice]: 73 best_choices = self.get_best_choices(context) 74 if not best_choices: 75 return None 76 (best_choice_proba, best_choice_idx) = best_choices[0] 77 if best_choice_proba <= self.get_confidence_threshold(): 78 return None 79 return self.get_choice(best_choice_idx) 80 81 def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]: 82 feedback_idx_list = self.get_best_choices(context) 83 if feedback_idx_list is None: 84 return None 85 choices = [ 86 self.get_choice(feedback_idx[1]) for feedback_idx in feedback_idx_list 87 ] 88 choices = [choice for choice in choices if choice is not None] 89 return choices 90 91 def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]: 92 return [] 93