import json
import os
from functools import partial
from typing import Any, Callable, Dict, List, Optional

import torch
from torch._inductor.autoheuristic.autoheuristic_utils import (
    AHContext,
    AHMetadata,
    AHOperation,
    Choice,
    CHOICE_COL,
    Feedback,
    FEEDBACK_COL,
    get_metadata_str_from_log,
)
from torch._inductor.autoheuristic.learned_heuristic_controller import (
    LearnedHeuristicController,
)
from torch._inductor.ir import ChoiceCaller
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.utils import get_gpu_shared_memory
24
25class LocalFeedback:
26    """
27    To be able to collect data for a choice, a function providing feedback given a choice has to be provided.
28    LocalFeedback can be used when AutoHeuristic should immediately run the function to collect feedback for each choice
29    (see pad_mm.py, where the autotuning happens locally, for an example).
30    """
31
32    def __init__(self, feedback_fn: Callable[[Choice], Feedback]) -> None:
33        self.feedback_fn = feedback_fn
34
35    def __call__(self, choice: Choice) -> Feedback:
36        return self.feedback_fn(choice)
37
38
class InconsistentMetadata(Exception):
    """
    Raised when AutoHeuristic is asked to append to an existing log file whose
    stored metadata differs from the metadata it would write if the file were
    created from scratch.
    """
45
46class AutoHeuristic:
47    """
48    AutoHeuristic is a framework that allows one to collect data, learn a heuristic (i.e. a regression tree) and
49    generate the heuristic to code. This class allows one to collect data. The collected data can then be used to train
50    a heuristic (see torchgen/autoheuristic/).
51    """
52
53    collected_feedback: Dict[Choice, Feedback]
54
55    def __init__(
56        self,
57        fallback: Callable[[], Choice],
58        choices: List[Choice],
59        feedback: Optional[LocalFeedback],
60        context: AHContext,
61        name: str,
62        augment_context: Optional[List[AHOperation]] = None,
63        precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
64    ) -> None:
65        """
66        Initializes an instance of the AutoHeuristic class.
67
68        Args:
69            fallback: A callable that returns a Choice when the heuristic is unsure which choice to make, or
70            AutoHeuristic is in data collection mode.
71            choices: A list of possible choices the heuristic can make.
72            feedback: An instance of LocalFeedback that provides feedback for a given choice.
73            context: Context to store with each choice and feedback.
74            name: A string that identifies the heuristic.
75            augment_context: An optional list of AHOperation instances that augment the context.
76            precondition: A callable that returns a boolean indicating whether AutoHeuristic should run.
77        """
78        self.fallback = fallback
79        self.choices = choices
80        self.feedback = feedback
81        self.context = context
82        self.name = name
83        self.collected_feedback = {}
84        self.augment_context = augment_context
85        self.metadata = AHMetadata(
86            get_gpu_shared_memory(),
87            torch.cuda.get_device_capability(),
88            self.choices,
89            self.name,
90        )
91        self.precondition = precondition
92
93        if not self.satisfies_precondition():
94            return
95
96        if torch._inductor.config.autoheuristic_log_path == "DEFAULT":
97            self.log_path = self.get_default_log_path()
98        else:
99            self.log_path = torch._inductor.config.autoheuristic_log_path
100
101        if torch._inductor.config.collect_autoheuristic(self.name):
102            if self.feedback is not None:
103                for choice in self.choices:
104                    feedback_val = self.feedback(choice)
105                    self.save_data(choice, feedback_val)
106
107    def satisfies_precondition(self) -> bool:
108        return self.precondition is None or self.precondition(
109            self.metadata, self.context
110        )
111
112    def get_choice(self) -> Choice:
113        """
114        Returns the chosen option based on the value of autoheuristic_use.
115        If self.name is one of the comma separated strings in autoheuristic_use,
116        it queries a learned heuristic to make a decision. Otherwise, it returns the fallback option.
117        """
118
119        if not self.satisfies_precondition():
120            return self.fallback()
121
122        if torch._inductor.config.use_autoheuristic(self.name):
123            if self.augment_context is not None:
124                self.context.apply_operations(self.augment_context)
125            controller = LearnedHeuristicController(
126                self.metadata,
127                self.context,
128            )
129            decision = controller.get_decision()
130            if decision not in self.choices:
131                # TODO(AlnisM): We might want to allow this in the future
132                return self.fallback()
133            if decision is not None:
134                return decision
135        return self.fallback()
136
137    def get_top_k_choices(
138        self, top_k: int, always_included: Optional[List[str]] = None
139    ) -> Optional[List[Choice]]:
140        if not self.satisfies_precondition():
141            return None
142        if torch._inductor.config.use_autoheuristic(self.name):
143            if self.augment_context is not None:
144                self.context.apply_operations(self.augment_context)
145            controller = LearnedHeuristicController(
146                self.metadata,
147                self.context,
148            )
149            choices = controller.get_decisions_ranked(top_k)
150            if choices is None:
151                return None
152            if always_included is not None:
153                for choice in always_included:
154                    if choice not in choices:
155                        choices.append(choice)
156            return choices
157        return None
158
159    def get_collected_feedback(self, choice: Choice) -> Any:
160        return self.collected_feedback.get(choice, None)
161
162    @staticmethod
163    def get_device_identifier() -> str:
164        # a heuristic might work well for one GPU, but not for another
165        # we store the collected data per GPU model and learn a heuristic per GPU model
166
167        # TODO(AlnisM): just using the device name for now, but the same GPU model can have different names
168        device_name = torch.cuda.get_device_name().replace(" ", "_")
169        return device_name
170
171    def get_default_log_path(self) -> str:
172        device_name = self.get_device_identifier()
173        path = f"{cache_dir()}/autoheuristic/{device_name}/"
174        os.makedirs(path, exist_ok=True)
175        path += f"{self.name}.txt"
176        return path
177
178    def serialize_metadata(self) -> str:
179        metadata_dict = self.metadata.to_dict()
180        (
181            num_features,
182            cat_features,
183        ) = self.context.get_numerical_and_categorical_features()
184        metadata_dict["numerical_features"] = num_features
185        metadata_dict["categorical_features"] = cat_features
186        return json.dumps(metadata_dict)
187
188    def save_data(self, choice: Choice, feedback_val: Feedback) -> None:
189        self.collected_feedback[choice] = feedback_val
190        log_path = self.log_path
191
192        lines = []
193        log_exists = os.path.exists(log_path)
194        if log_exists:
195            # if log already exists, make sure it is consistent
196            metadata = self.serialize_metadata()
197            existing_metadata = get_metadata_str_from_log(self.log_path)
198            if existing_metadata != metadata:
199                raise InconsistentMetadata(
200                    "Given metadata does not match existing metadata"
201                )
202        else:
203            lines.append(self.serialize_metadata())
204            feature_header = self.context.get_feature_names_csv()
205            header = feature_header + "," + CHOICE_COL + "," + FEEDBACK_COL
206            lines.append(header)
207
208        line = ""
209        feature_values = self.context.get_feature_values_csv()
210        line += feature_values + "," + choice + "," + str(feedback_val)
211        lines.append(line)
212
213        with open(log_path, "a") as f:
214            f.write("\n".join(lines) + "\n")
215
216
217class AutoHeuristicSelectAlgorithm(AutoHeuristic):
218    """
219    AutoHeuristicSelectAlgorithm is a subclass of AutoHeuristic that allows one to collect data and learn a heuristic
220    when one wants to use AutoHeuristic for kernel choice selection.
221    """
222
223    def __init__(
224        self,
225        fallback: Callable[[], Optional[ChoiceCaller]],
226        choices: List[ChoiceCaller],
227        input_nodes: List[Any],
228        context: AHContext,
229        name: str,
230        augment_context: Optional[List[AHOperation]] = None,
231        precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
232    ) -> None:
233        """
234        The arguments choices, input_nodes and name have to match the ones used in the call to
235        autotune_select_algorithm(), e.g. if the following call is made
236        autotune_select_algorithm(name, choices, input_nodes, layout), the same name, choices and input_nodes
237        have to be used here.
238        """
239        self.input_nodes = input_nodes
240        self.choicestr2choice: Dict[str, ChoiceCaller] = {}
241        for choice in choices:
242            self.choicestr2choice[choice.autoheuristic_id()] = choice
243        choices_str = list(self.choicestr2choice.keys())
244
245        def fallback_str() -> str:
246            fallback_choice = fallback()
247            if fallback_choice is None:
248                # TODO: Find a nicer way to handle this
249                return "unsure"
250            return fallback_choice.autoheuristic_id()
251
252        super().__init__(
253            fallback_str,
254            choices_str,
255            None,
256            context,
257            name,
258            augment_context,
259            precondition,
260        )
261
262        if (
263            torch._inductor.config.collect_autoheuristic(self.name)
264            and self.satisfies_precondition()
265        ):
266            self.register_global_feedback(input_nodes, choices)
267
268    def register_global_feedback(
269        self, input_nodes: List[Any], choices: List[ChoiceCaller]
270    ) -> None:
271        """
272        Registers a callback in select_algorithm, which is called with the timing of each choice.
273        """
274
275        from torch._inductor.select_algorithm import (
276            add_feedback_saver,
277            create_inputs_key,
278            create_precompile_key,
279        )
280
281        def store_global_feedback(
282            ah_inputs_key: str,
283            ah_precompile_key: str,
284            timings: Dict[ChoiceCaller, float],
285            name: str,
286            input_nodes: List[Any],
287            choices: List[ChoiceCaller],
288        ) -> None:
289            current_inputs_key = create_inputs_key(input_nodes)
290            if current_inputs_key != ah_inputs_key:
291                return
292            current_precompile_key = create_precompile_key(
293                name, current_inputs_key, choices
294            )
295            if current_precompile_key != ah_precompile_key:
296                return
297            for choice, time in timings.items():
298                self.save_data(choice.autoheuristic_id(), time)
299
300        inputs_key = create_inputs_key(input_nodes)
301        precompile_key = create_precompile_key(self.name, inputs_key, choices)
302        feedback_saver = partial(store_global_feedback, inputs_key, precompile_key)
303        add_feedback_saver(feedback_saver)
304
305    def get_choice_caller(self) -> Optional[ChoiceCaller]:
306        choice = self.get_choice()
307        return self.choicestr2choice.get(choice, None)
308
309    def get_top_k_choices_caller(
310        self, top_k: int, always_included: Optional[List[str]] = None
311    ) -> Optional[List[ChoiceCaller]]:
312        choices = self.get_top_k_choices(top_k, always_included)
313        if choices is None:
314            return None
315        return [self.choicestr2choice[choice] for choice in choices]
316