• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from typing import List, Optional, Tuple
2
3from torch._inductor.autoheuristic.autoheuristic_utils import (
4    AHContext,
5    AHMetadata,
6    Choice,
7)
8
9
class LearnedHeuristic:
    """
    LearnedHeuristic is a base class for all learned heuristics.

    Each hook below returns a neutral default and is meant to be overridden:
    the precondition always holds, no decision or ranking is produced, the
    confidence threshold is 1.0 and the name is the empty string.
    """

    def __init__(self) -> None:
        """Create the heuristic; the base class keeps no state."""

    def check_precondition(
        self,
        metadata: AHMetadata,
        context: AHContext,
    ) -> bool:
        """Return whether this heuristic applies; the base accepts everything."""
        return True

    def get_decision(
        self, context: AHContext, choices: List[Choice]
    ) -> Optional[Choice]:
        """
        Return the selected choice, or None when no confident decision exists.

        The base implementation never decides and always returns None.
        """
        return None

    def get_confidence_threshold(self) -> float:
        """Return the confidence threshold; the base uses 1.0."""
        return 1.0

    def get_name(self) -> str:
        """Return the heuristic's name; the base returns an empty string."""
        return ""

    def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
        """
        Return choice names in ranked order, or None if no ranking is available.

        The base implementation provides no ranking and returns None.
        """
        return None
38
39
class LearnedHeuristicRegression(LearnedHeuristic):
    """
    Learned heuristic that scores every choice with a predicted feedback value.

    Subclasses override ``get_feedback``. ``get_decision`` picks the choice with
    the highest predicted feedback, but only when it beats the runner-up by a
    ratio strictly greater than ``get_confidence_threshold()``.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_feedback(self, context: AHContext, choice: Choice) -> float:
        """
        Return the predicted feedback for ``choice`` in ``context``.

        ``get_decision`` treats higher values as better. The base implementation
        returns 1.0 for every choice.
        """
        return 1.0

    def get_decision(
        self, context: AHContext, choices: List[Choice]
    ) -> Optional[Choice]:
        """
        Return the choice with the highest predicted feedback, or None.

        The top choice is returned only if
        ``highest_feedback / second_highest_feedback`` is greater than
        ``get_confidence_threshold()``.

        None is returned when:
        - the model is not confident enough;
        - fewer than two distinct choices are given (no ratio can be formed,
          so the caller should fall back to its default).

        A zero runner-up no longer raises ZeroDivisionError. A positive best
        choice then counts as an unbounded ratio, and a zero best counts as a
        tie (ratio 1.0).

        NOTE(review): the ratio test presumably assumes predicted feedback is
        positive; confirm against the subclass models.
        """
        # Duplicate choices collapse into one entry (the last prediction wins).
        choice2feedback = {
            choice: self.get_feedback(context, choice) for choice in choices
        }
        if len(choice2feedback) < 2:
            # Nothing to compare against, so no confident decision is possible.
            return None
        # Ascending by feedback. sorted() is stable, so among equal feedback
        # values the later-listed choice ends up last.
        sorted_choices_feedback = sorted(choice2feedback.items(), key=lambda t: t[1])
        best_choice, highest_feedback = sorted_choices_feedback[-1]
        second_highest_feedback = sorted_choices_feedback[-2][1]
        if second_highest_feedback == 0:
            # The list is sorted, so highest_feedback >= 0 here. A positive best
            # beats a zero runner-up by an unbounded ratio; 0 vs 0 is a tie.
            ratio = float("inf") if highest_feedback > 0 else 1.0
        else:
            ratio = highest_feedback / second_highest_feedback
        if ratio > self.get_confidence_threshold():
            return best_choice
        # We are not sure which choice is the best one
        return None
61
62
class LearnedHeuristicDecision(LearnedHeuristic):
    """
    Learned heuristic that picks among choices identified by integer index.

    Subclasses provide:
    - ``get_best_choices``: a list of ``(probability, index)`` pairs, whose
      first entry is treated as the best;
    - ``get_choice``: maps an index back to a choice name.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_choice(self, idx: int) -> Optional[str]:
        """Map ``idx`` to a choice name; the base knows no choices (None)."""
        return None

    def get_decision(
        self, context: AHContext, choices: List[Choice]
    ) -> Optional[Choice]:
        """
        Return the choice named by the first entry of ``get_best_choices``.

        Returns None when ``get_best_choices`` yields nothing (None or empty),
        or when the first entry's probability is not greater than
        ``get_confidence_threshold()``. The ``choices`` argument is not
        consulted here.
        """
        ranked = self.get_best_choices(context)
        if not ranked:
            return None
        top_proba, top_idx = ranked[0]
        # Keep the "<=" form so a NaN probability behaves exactly as before.
        return (
            None
            if top_proba <= self.get_confidence_threshold()
            else self.get_choice(top_idx)
        )

    def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
        """
        Return choice names in the order given by ``get_best_choices``.

        Indices that ``get_choice`` cannot map (it returns None) are dropped.
        Returns None only when ``get_best_choices`` itself returns None.
        """
        ranked = self.get_best_choices(context)
        if ranked is None:
            return None
        names = (self.get_choice(entry[1]) for entry in ranked)
        return [name for name in names if name is not None]

    def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
        """
        Return ``(probability, index)`` pairs for the candidate choices.

        The base implementation returns an empty list.
        """
        return []
93