import json
import os
from functools import partial
from typing import Any, Callable, Dict, List, Optional

import torch
from torch._inductor.autoheuristic.autoheuristic_utils import (
    AHContext,
    AHMetadata,
    AHOperation,
    Choice,
    CHOICE_COL,
    Feedback,
    FEEDBACK_COL,
    get_metadata_str_from_log,
)
from torch._inductor.autoheuristic.learned_heuristic_controller import (
    LearnedHeuristicController,
)
from torch._inductor.ir import ChoiceCaller
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.utils import get_gpu_shared_memory


class LocalFeedback:
    """
    To be able to collect data for a choice, a function providing feedback given a choice has to be provided.
    LocalFeedback can be used when AutoHeuristic should immediately run the function to collect feedback for each choice
    (see pad_mm.py, where the autotuning happens locally, for an example).
    """

    def __init__(self, feedback_fn: Callable[[Choice], Feedback]) -> None:
        # Stored as-is; invoked once per choice via __call__.
        self.feedback_fn = feedback_fn

    def __call__(self, choice: Choice) -> Feedback:
        return self.feedback_fn(choice)


class InconsistentMetadata(Exception):
    """
    Exception that is thrown when AutoHeuristic tries to log data to a file where the metadata stored in the file does
    not match the metadata it would store if the file didn't exist.
    """


class AutoHeuristic:
    """
    AutoHeuristic is a framework that allows one to collect data, learn a heuristic (i.e. a regression tree) and
    generate the heuristic to code. This class allows one to collect data. The collected data can then be used to train
    a heuristic (see torchgen/autoheuristic/).
    """

    # Maps each choice to the feedback recorded for it by save_data().
    collected_feedback: Dict[Choice, Feedback]

    def __init__(
        self,
        fallback: Callable[[], Choice],
        choices: List[Choice],
        feedback: Optional[LocalFeedback],
        context: AHContext,
        name: str,
        augment_context: Optional[List[AHOperation]] = None,
        precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
    ) -> None:
        """
        Initializes an instance of the AutoHeuristic class.

        Args:
            fallback: A callable that returns a Choice when the heuristic is unsure which choice to make, or
                AutoHeuristic is in data collection mode.
            choices: A list of possible choices the heuristic can make.
            feedback: An instance of LocalFeedback that provides feedback for a given choice.
            context: Context to store with each choice and feedback.
            name: A string that identifies the heuristic.
            augment_context: An optional list of AHOperation instances that augment the context.
            precondition: A callable that returns a boolean indicating whether AutoHeuristic should run.
        """
        self.fallback = fallback
        self.choices = choices
        self.feedback = feedback
        self.context = context
        self.name = name
        self.collected_feedback = {}
        self.augment_context = augment_context
        # Metadata identifies the hardware + heuristic so that logs/heuristics
        # are only reused on a matching setup.
        self.metadata = AHMetadata(
            get_gpu_shared_memory(),
            torch.cuda.get_device_capability(),
            self.choices,
            self.name,
        )
        self.precondition = precondition

        # NOTE: when the precondition fails we return early, so self.log_path
        # is only defined when the precondition holds; all code paths that use
        # it re-check the precondition first.
        if not self.satisfies_precondition():
            return

        if torch._inductor.config.autoheuristic_log_path == "DEFAULT":
            self.log_path = self.get_default_log_path()
        else:
            self.log_path = torch._inductor.config.autoheuristic_log_path

        # In data-collection mode with local feedback, eagerly evaluate every
        # choice and append the result to the log file.
        if torch._inductor.config.collect_autoheuristic(self.name):
            if self.feedback is not None:
                for choice in self.choices:
                    feedback_val = self.feedback(choice)
                    self.save_data(choice, feedback_val)

    def satisfies_precondition(self) -> bool:
        # No precondition means AutoHeuristic is always allowed to run.
        return self.precondition is None or self.precondition(
            self.metadata, self.context
        )

    def get_choice(self) -> Choice:
        """
        Returns the chosen option based on the value of autoheuristic_use.
        If self.name is one of the comma separated strings in autoheuristic_use,
        it queries a learned heuristic to make a decision. Otherwise, it returns the fallback option.
        """

        if not self.satisfies_precondition():
            return self.fallback()

        if torch._inductor.config.use_autoheuristic(self.name):
            if self.augment_context is not None:
                self.context.apply_operations(self.augment_context)
            controller = LearnedHeuristicController(
                self.metadata,
                self.context,
            )
            decision = controller.get_decision()
            # Only trust the learned heuristic if it picked one of the known
            # choices; anything else (including None) falls back.
            if decision not in self.choices:
                # TODO(AlnisM): We might want to allow this in the future
                return self.fallback()
            if decision is not None:
                return decision
        return self.fallback()

    def get_top_k_choices(
        self, top_k: int, always_included: Optional[List[str]] = None
    ) -> Optional[List[Choice]]:
        """
        Returns the top_k choices ranked by the learned heuristic, or None when
        the precondition fails, the heuristic is not enabled for self.name, or
        the controller cannot produce a ranking. Choices listed in
        always_included are appended if the ranking did not already contain them.
        """
        if not self.satisfies_precondition():
            return None
        if torch._inductor.config.use_autoheuristic(self.name):
            if self.augment_context is not None:
                self.context.apply_operations(self.augment_context)
            controller = LearnedHeuristicController(
                self.metadata,
                self.context,
            )
            choices = controller.get_decisions_ranked(top_k)
            if choices is None:
                return None
            if always_included is not None:
                for choice in always_included:
                    if choice not in choices:
                        choices.append(choice)
            return choices
        return None

    def get_collected_feedback(self, choice: Choice) -> Optional[Feedback]:
        # Returns None if no feedback has been collected for this choice.
        return self.collected_feedback.get(choice, None)

    @staticmethod
    def get_device_identifier() -> str:
        # a heuristic might work well for one GPU, but not for another
        # we store the collected data per GPU model and learn a heuristic per GPU model
        # TODO(AlnisM): just using the device name for now, but the same GPU model can have different names
        device_name = torch.cuda.get_device_name().replace(" ", "_")
        return device_name

    def get_default_log_path(self) -> str:
        """
        Returns the per-GPU default log file path
        (<cache_dir>/autoheuristic/<device>/<name>.txt), creating the directory
        if necessary.
        """
        device_name = self.get_device_identifier()
        path = f"{cache_dir()}/autoheuristic/{device_name}/"
        os.makedirs(path, exist_ok=True)
        path += f"{self.name}.txt"
        return path

    def serialize_metadata(self) -> str:
        """
        Serializes the metadata (plus the context's feature names) to a JSON
        string; stored as the first line of the log file.
        """
        metadata_dict = self.metadata.to_dict()
        (
            num_features,
            cat_features,
        ) = self.context.get_numerical_and_categorical_features()
        metadata_dict["numerical_features"] = num_features
        metadata_dict["categorical_features"] = cat_features
        return json.dumps(metadata_dict)

    def save_data(self, choice: Choice, feedback_val: Feedback) -> None:
        """
        Records feedback for a choice in memory and appends one CSV row
        (features, choice, feedback) to the log file. On first write, a JSON
        metadata line and a CSV header line are written first. Raises
        InconsistentMetadata if the file's stored metadata does not match what
        this instance would write.
        """
        self.collected_feedback[choice] = feedback_val
        log_path = self.log_path

        lines = []
        log_exists = os.path.exists(log_path)
        if log_exists:
            # if log already exists, make sure it is consistent
            metadata = self.serialize_metadata()
            existing_metadata = get_metadata_str_from_log(self.log_path)
            if existing_metadata != metadata:
                raise InconsistentMetadata(
                    "Given metadata does not match existing metadata"
                )
        else:
            # New log file: write metadata and CSV header before the first row.
            lines.append(self.serialize_metadata())
            feature_header = self.context.get_feature_names_csv()
            header = feature_header + "," + CHOICE_COL + "," + FEEDBACK_COL
            lines.append(header)

        line = ""
        feature_values = self.context.get_feature_values_csv()
        line += feature_values + "," + choice + "," + str(feedback_val)
        lines.append(line)

        # Append mode: each call adds one row (plus header lines on first use).
        with open(log_path, "a") as f:
            f.write("\n".join(lines) + "\n")


class AutoHeuristicSelectAlgorithm(AutoHeuristic):
    """
    AutoHeuristicSelectAlgorithm is a subclass of AutoHeuristic that allows one to collect data and learn a heuristic
    when one wants to use AutoHeuristic for kernel choice selection.
    """

    def __init__(
        self,
        fallback: Callable[[], Optional[ChoiceCaller]],
        choices: List[ChoiceCaller],
        input_nodes: List[Any],
        context: AHContext,
        name: str,
        augment_context: Optional[List[AHOperation]] = None,
        precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
    ) -> None:
        """
        The arguments choices, input_nodes and name have to match the ones used in the call to
        autotune_select_algorithm(), e.g. if the following call is made
        autotune_select_algorithm(name, choices, input_nodes, layout), the same name, choices and input_nodes
        have to be used here.
        """
        self.input_nodes = input_nodes
        # Map from a choice's string id back to its ChoiceCaller, so the base
        # class can operate on plain strings.
        self.choicestr2choice: Dict[str, ChoiceCaller] = {}
        for choice in choices:
            self.choicestr2choice[choice.autoheuristic_id()] = choice
        choices_str = list(self.choicestr2choice.keys())

        def fallback_str() -> str:
            # Adapts the ChoiceCaller-returning fallback to the string-based
            # interface expected by the base class.
            fallback_choice = fallback()
            if fallback_choice is None:
                # TODO: Find a nicer way to handle this
                return "unsure"
            return fallback_choice.autoheuristic_id()

        super().__init__(
            fallback_str,
            choices_str,
            None,
            context,
            name,
            augment_context,
            precondition,
        )

        # Feedback comes from the autotuner (globally), not from a
        # LocalFeedback instance, hence feedback=None above.
        if (
            torch._inductor.config.collect_autoheuristic(self.name)
            and self.satisfies_precondition()
        ):
            self.register_global_feedback(input_nodes, choices)

    def register_global_feedback(
        self, input_nodes: List[Any], choices: List[ChoiceCaller]
    ) -> None:
        """
        Registers a callback in select_algorithm, which is called with the timing of each choice.
        """

        from torch._inductor.select_algorithm import (
            add_feedback_saver,
            create_inputs_key,
            create_precompile_key,
        )

        def store_global_feedback(
            ah_inputs_key: str,
            ah_precompile_key: str,
            timings: Dict[ChoiceCaller, float],
            name: str,
            input_nodes: List[Any],
            choices: List[ChoiceCaller],
        ) -> None:
            # The feedback saver is invoked for every autotuning run; the key
            # comparisons ensure we only record timings for the run this
            # instance was created for.
            current_inputs_key = create_inputs_key(input_nodes)
            if current_inputs_key != ah_inputs_key:
                return
            current_precompile_key = create_precompile_key(
                name, current_inputs_key, choices
            )
            if current_precompile_key != ah_precompile_key:
                return
            for choice, time in timings.items():
                self.save_data(choice.autoheuristic_id(), time)

        # Bind this instance's keys now; the remaining arguments are supplied
        # by select_algorithm when the callback fires.
        inputs_key = create_inputs_key(input_nodes)
        precompile_key = create_precompile_key(self.name, inputs_key, choices)
        feedback_saver = partial(store_global_feedback, inputs_key, precompile_key)
        add_feedback_saver(feedback_saver)

    def get_choice_caller(self) -> Optional[ChoiceCaller]:
        # Translates the base class's string decision back into a ChoiceCaller;
        # None when the decision (e.g. "unsure") has no matching caller.
        choice = self.get_choice()
        return self.choicestr2choice.get(choice, None)

    def get_top_k_choices_caller(
        self, top_k: int, always_included: Optional[List[str]] = None
    ) -> Optional[List[ChoiceCaller]]:
        # Same as get_top_k_choices, but mapped back to ChoiceCaller objects.
        choices = self.get_top_k_choices(top_k, always_included)
        if choices is None:
            return None
        return [self.choicestr2choice[choice] for choice in choices]