# mypy: allow-untyped-defs
import threading
from functools import lru_cache
from itertools import chain
from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch._ops import OpOverload
from torch._subclasses import FakeTensorMode
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import (
    OpInfo,
    OpSchema,
    OpStrategy,
    OutputSharding,
    OutputSpecType,
    PlacementStrategy,
    RuntimeSchemaInfo,
    StrategyType,
    TupleStrategy,
)
from torch.distributed.tensor._utils import (
    compute_local_shape,
    compute_local_stride,
    try_find_mesh_from_args,
)


aten = torch.ops.aten


def _length(obj) -> int:
    if obj is None:
        return 0
    if not isinstance(obj, Sequence):
        return 1
    return len(obj)


class LocalLRUCache(threading.local):
    def __init__(self, user_function: Callable) -> None:
        self.cache = lru_cache(None)(user_function)

    def __call__(self, *args, **kwargs) -> object:
        return self.cache(*args, **kwargs)

    def cache_info(self):
        return self.cache.cache_info()


class ShardingPropagator:
    def __init__(self) -> None:
        self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
        self.op_strategy_funcs: Dict[
            OpOverload,
            Callable[[DeviceMesh, OpSchema], StrategyType],
        ] = {}
        # op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop
        self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {}
        self.propagate_op_sharding = LocalLRUCache(
            self.propagate_op_sharding_non_cached
        )
        # op map to save indices of shape (and stride) args which may need to be modified in sharding prop
        self.op_to_shape_and_stride_idx: Dict[
            OpOverload, Union[int, Tuple[int, int]]
        ] = {
            # new factory ops
            aten.new_empty.default: 1,
            aten.new_full.default: 1,
            aten.new_ones.default: 1,
            aten.new_zeros.default: 1,
            aten.new_empty_strided.default: (1, 2),
            # view ops
            aten.expand.default: 1,
            aten.reshape.default: 1,
            aten.view.default: 1,
            aten._unsafe_view.default: 1,
        }

    def register_sharding_prop_rule(
        self,
        op_overload: OpOverload,
        rule_func: Callable[[OpSchema], OutputSharding],
        schema_info: Optional[RuntimeSchemaInfo] = None,
    ):
        """
        Register a sharding propagation rule for an operator.
        """
        self.op_to_rules[op_overload] = rule_func
        if schema_info is not None:
            self.op_to_schema_info[op_overload] = schema_info

    def register_op_strategy(
        self,
        op_overload: OpOverload,
        strategy_func: Callable[[DeviceMesh, OpSchema], StrategyType],
        schema_info: Optional[RuntimeSchemaInfo] = None,
    ):
        """
        Register a sharding strategy generator for an operator.
        """
        self.op_strategy_funcs[op_overload] = strategy_func
        if schema_info is not None:
            self.op_to_schema_info[op_overload] = schema_info
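
    # Illustrative sketch only (the strategy function and `propagator` names below
    # are hypothetical, not part of this module): callers typically register one
    # strategy generator per ATen overload, e.g.
    #
    #     def pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
    #         ...  # build an OpStrategy out of the input DTensorSpecs
    #
    #     propagator.register_op_strategy(aten.relu.default, pointwise_strategy)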

    @lru_cache  # noqa: B019
    def _propagate_tensor_meta(
        self, op_schema: OpSchema
    ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
        """
        Propagate the tensor metadata. It could either return a TensorMeta
        or a list/tuple of TensorMetas.
        """
        if op_schema.op == aten.equal.default:
            # data dependent ops can't be used for fake propagation
            return None

        # NOTE: We must call the tracing in fake tensor mode so that it
        # avoids materializing memory
        with FakeTensorMode():
            fake_args = op_schema.gen_fake_args()
            fake_kwargs = op_schema.gen_fake_kwargs()
            fake_out = op_schema.op(*fake_args, **fake_kwargs)

        if isinstance(fake_out, torch.Tensor):
            return TensorMeta(
                shape=fake_out.shape, stride=fake_out.stride(), dtype=fake_out.dtype
            )

        elif isinstance(fake_out, (tuple, list)):
            tensor_meta_list: List[Optional[TensorMeta]] = []
            for fake_out_item in fake_out:
                if isinstance(fake_out_item, torch.Tensor):
                    tensor_meta_list.append(
                        TensorMeta(
                            shape=fake_out_item.shape,
                            stride=fake_out_item.stride(),
                            dtype=fake_out_item.dtype,
                        )
                    )
                else:
                    tensor_meta_list.append(None)
            return (
                tuple(tensor_meta_list)
                if isinstance(fake_out, tuple)
                else tensor_meta_list
            )
        else:
            # if the fake output is not a tensor or a tuple/list of tensors, return None
            return None

    def _wrap_output_spec_tensor_meta(
        self,
        op: OpOverload,
        output_specs: OutputSpecType,
        output_tensor_meta: Union[None, TensorMeta, Sequence[Optional[TensorMeta]]],
    ) -> None:
        """
        Wrap the output_specs with the tensor metadata from the output.
        """

        if isinstance(output_specs, DTensorSpec):
            if not isinstance(output_tensor_meta, TensorMeta):
                # Either error due to ShardingPropagator or due to incorrect OutputSpec
                if not isinstance(output_tensor_meta, (tuple, list)):
                    raise ValueError(
                        "ShardingPropagator error: output does not have an associated TensorMeta"
                    )
                raise ValueError(
                    f"For the op {op.name()}, `output_specs` has 1 output which does not equal the "
                    f"number of op outputs: {len(output_tensor_meta)}."
                )
            output_specs.tensor_meta = output_tensor_meta
        elif isinstance(output_specs, (tuple, list)):
            if not isinstance(output_tensor_meta, (tuple, list)) or len(
                output_specs
            ) != len(output_tensor_meta):
                raise ValueError(
                    f"For the op {op.name()}, `output_specs` has {len(output_specs)} outputs which does not equal the "
                    f"number of op outputs {_length(output_tensor_meta)}."
                )
            for i, spec in enumerate(output_specs):
                if isinstance(spec, DTensorSpec):
                    output_tensor_meta_i = output_tensor_meta[i]
                    if not isinstance(output_tensor_meta_i, TensorMeta):
                        raise ValueError(
                            f"ShardingPropagator error: output {i} does not have an associated TensorMeta"
                        )
                    spec.tensor_meta = output_tensor_meta_i

    def propagate(self, op_info: OpInfo) -> None:
        # We cannot use an lru cache if we know that inputs will have dynamic shapes,
        # because SymInts are not hashable.
        # This is generally ok because this only happens during tracing in torch.compile,
        # and tracing does not need to be as fast as eager-mode DTensor usage.
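        # (Under torch.compile with dynamic shapes, sizes in the schema show up
        # as torch.SymInt, e.g. a shape like (s0, 4), so such schemas take the
        # non-cached path below.)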
        if op_info.schema.has_symints:
            output_sharding = self.propagate_op_sharding_non_cached(op_info.schema)
        else:
            output_sharding = cast(
                OutputSharding, self.propagate_op_sharding(op_info.schema)
            )
        op_info.output_sharding = output_sharding

    def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding:
        """
        Propagate the sharding for an operator given the op_schema.
        """
        # special case op: we don't need to propagate for local
        # scalar. TODO: figure out a better way to handle this
        if op_schema.op is aten._local_scalar_dense.default:
            return OutputSharding(None, op_schema)

        out_tensor_meta = self._propagate_tensor_meta(op_schema)

        def spec_to_strategy(spec: object) -> object:
            if isinstance(spec, DTensorSpec):
                return OpStrategy([PlacementStrategy(spec)])
            elif (
                isinstance(spec, (list, tuple))
                and len(spec) > 0
                and isinstance(spec[0], DTensorSpec)
            ):
                # a tensor list arg creates a TupleStrategy
                tuple_strategy = [spec_to_strategy(s) for s in spec]
                tuple_strategy = cast(Sequence[StrategyType], tuple_strategy)
                return TupleStrategy(
                    tuple(tuple_strategy) if isinstance(spec, tuple) else tuple_strategy
                )
            else:
                return spec

        if op_schema.op in self.op_strategy_funcs:
            # generate op strategy for the op.
            mesh = try_find_mesh_from_args(op_schema.op, op_schema.args_schema)
            # swap the args spec with args strategies
            args_op_strategy = [spec_to_strategy(i) for i in op_schema.args_schema]

            kwargs_op_strategy = {
                k: spec_to_strategy(v) for k, v in op_schema.kwargs_schema.items()
            }

            # construct a new OpSchema on args for strategy-based propagation
            strategy_schema: OpSchema = OpSchema(
                op=op_schema.op,
                args_schema=tuple(args_op_strategy),
                kwargs_schema=kwargs_op_strategy,
            )

            op_strategy = self.op_strategy_funcs[op_schema.op](mesh, strategy_schema)

            if isinstance(op_strategy, OpStrategy):
                # single Op strategy
                output_strategy = self._select_strategy(op_strategy)

                # check if we need to redistribute the input
                needs_redistribute = False
                expected_input_specs: List[DTensorSpec] = []

                # in the case where the op does not specify input_specs and output_specs
                # is a DTensorSpec, we use output_specs as the spec for each DTensor
                # input arg.
                if output_strategy.input_specs is None:
                    assert isinstance(output_strategy.output_specs, DTensorSpec)

                for idx, input_spec in enumerate(op_schema.args_spec):
                    desired_spec = (
                        output_strategy.output_spec
                        if output_strategy.input_specs is None
                        else output_strategy.input_specs[idx]
                    )
                    expected_input_specs.append(
                        desired_spec.shallow_copy_with_tensor_meta(
                            input_spec.tensor_meta
                        )
                    )
                    if input_spec.placements != desired_spec.placements:
                        needs_redistribute = True

                suggestion_schema = None
                if needs_redistribute:
                    suggestion_schema = OpSchema(
                        op_schema.op, tuple(expected_input_specs), {}
                    )
                    suggestion_schema._inplace_rewrap_schema_suggestion(op_schema)

                # shape and stride args need to be modified for
                # view ops and new factory ops, potentially
                if op_schema.op in self.op_to_shape_and_stride_idx:
                    assert isinstance(output_strategy.output_spec, DTensorSpec)
                    # This happens when the output has the same shape as the input
                    # and the input placements are not all Replicate().
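                    # Illustration (assumed setup): aten.view.default on a DTensor
                    # sharded Shard(0) over a 2-rank mesh with a global shape arg of
                    # [8, 4] must call the local view with the local shape [4, 4],
                    # so the shape (and stride) args are rewritten below.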
                    if output_strategy.output_spec.is_sharded():
                        schema = suggestion_schema or op_schema
                        assert isinstance(out_tensor_meta, TensorMeta)
                        suggestion_schema = self._adjust_shape_and_stride_args(
                            out_tensor_meta, schema, output_strategy.output_spec, mesh
                        )
                        needs_redistribute = True

                # construct output spec for the op
                if op_schema.return_type_tuple_tensor_like():
                    # for ops that return multiple tensors and the output_specs is not
                    # a tuple, we use a tuple of that single output spec as the new
                    # output_specs
                    output_specs: OutputSpecType = output_strategy.output_specs
                    if isinstance(output_specs, DTensorSpec):
                        output_specs = tuple(
                            [
                                # create a new DTensorSpec with the same placement as the
                                # output_specs in output_strategy
                                DTensorSpec(
                                    mesh=output_specs.mesh,
                                    placements=output_specs.placements,
                                    tensor_meta=output_specs.tensor_meta,
                                )
                                for _ in range(len(op_schema.op._schema.returns))
                            ]
                        )
                elif op_schema.return_type_tensor():
                    output_specs = output_strategy.output_specs
                else:
                    output_specs = None

                output_sharding = OutputSharding(
                    output_specs,
                    suggestion_schema,
                    needs_redistribute=needs_redistribute,
                )
            elif isinstance(op_strategy, TupleStrategy):
                # tuple strategy output sharding processing
                # runtime selected placement strategy for each TupleStrategy input arg
                selected_strategies: List[PlacementStrategy] = []
                out_spec_list: List[DTensorSpec] = []
                for strategy in op_strategy.childs:
                    assert isinstance(strategy, OpStrategy)
                    selected_strategy = self._select_strategy(strategy)
                    selected_strategies.append(selected_strategy)
                    out_spec_list.append(selected_strategy.output_spec)

                needs_redistribute = False
                suggestion_args: List[object] = []
                tensor_or_list_tensor_arg_idx = 0

                for arg in op_schema.args_schema:
                    if (
                        arg
                        and isinstance(arg, (list, tuple))
                        and isinstance(arg[0], DTensorSpec)
                    ):
                        expected_input_spec_list: List[DTensorSpec] = []
                        for idx, arg_spec in enumerate(arg):
                            expected_input_spec = selected_strategies[idx].input_spec(
                                tensor_or_list_tensor_arg_idx
                            )
                            expected_input_spec = (
                                expected_input_spec.shallow_copy_with_tensor_meta(
                                    arg_spec.tensor_meta
                                )
                            )
                            if arg_spec.placements != expected_input_spec.placements:
                                needs_redistribute = True
                            expected_input_spec_list.append(expected_input_spec)
                        suggestion_args.append(
                            tuple(expected_input_spec_list)
                            if isinstance(arg, tuple)
                            else expected_input_spec_list
                        )
                        tensor_or_list_tensor_arg_idx += 1

                    elif isinstance(arg, DTensorSpec):
                        expected_input_spec = selected_strategies[0].input_spec(
                            tensor_or_list_tensor_arg_idx
                        )
                        expected_input_spec = (
                            expected_input_spec.shallow_copy_with_tensor_meta(
                                arg.tensor_meta
                            )
                        )
                        if arg.placements != expected_input_spec.placements:
                            needs_redistribute = True
                        suggestion_args.append(expected_input_spec)
                        tensor_or_list_tensor_arg_idx += 1
                    else:
                        suggestion_args.append(arg)

                suggestion_schema = None
                if needs_redistribute:
                    suggestion_schema = OpSchema(
                        op_schema.op, tuple(suggestion_args), op_schema.kwargs_schema
                    )

                output_sharding = OutputSharding(
                    tuple(out_spec_list) if out_tensor_meta is not None else None,
                    suggestion_schema,
                    needs_redistribute=needs_redistribute,
                )
            else:
                raise ValueError("Unsupported op strategy type")

            # associate the output sharding with the output tensor metadata
            self._wrap_output_spec_tensor_meta(
                op_schema.op, output_sharding.output_spec, out_tensor_meta
            )
            return output_sharding
        elif op_schema.op in self.op_to_rules:
            # propagate the sharding with rule
            sharding_prop_func = self.op_to_rules[op_schema.op]

            # step 1. there's a sharding propagation rule; run
            # sharding propagation to get the output sharding
            try:
                output_sharding = sharding_prop_func(op_schema)
            except NotImplementedError as e:
                raise e
            except Exception as e:
                raise RuntimeError(
                    f"Sharding propagation failed on op {op_schema}.\n" f"Error: {e}"
                ) from e

            # step 2. if we can't get an output_spec from sharding
            # propagation (i.e. no rules apply for the input
            # placements), we return the output sharding
            # with schema suggestions, which can be used to
            # decide how to redistribute the inputs
            if output_sharding.output_spec is None:
                if output_sharding.redistribute_schema is None:
                    raise RuntimeError(
                        f"Sharding propagation failed on op {op_schema}!"
                    )
                else:
                    # we do auto redistribute on inputs if necessary
                    # run sharding propagation again with the suggested schema
                    propagation_res = sharding_prop_func(
                        output_sharding.redistribute_schema
                    )
                    # we set the output sharding with the new propagation result
                    # so that dispatching knows both output_spec and redistribute_schema
                    # exist, which indicates a reshard is needed
                    output_sharding.output_spec = propagation_res.output_spec
                    output_sharding.needs_redistribute = True

            # associate the output sharding with the output tensor metadata
            self._wrap_output_spec_tensor_meta(
                op_schema.op, output_sharding.output_spec, out_tensor_meta
            )

            return output_sharding
        else:
            raise NotImplementedError(
                f"Operator {op_schema.op} does not have a sharding strategy registered."
            )

    def _select_strategy(self, strategy: OpStrategy) -> PlacementStrategy:
        if len(strategy.strategies) == 1:
            # short cut with only one possible strategy
            return strategy.strategies[0]

        strategy_costs: List[float] = []
        for strtg in strategy.strategies:
            assert (
                strtg.redistribute_cost is not None
            ), "must set redistribute cost for each strategy!"
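            # total cost of this candidate strategy: sum every entry of its
            # per-input redistribute_cost lists, e.g. (illustrative numbers
            # only) [[0.0, 1.0], [2.0, 0.0]] -> 3.0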
            redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost))
            strategy_costs.append(redistribute_cost)

        # for eager execution, we just select the one with the minimal redistribute cost
        return strategy.strategies[strategy_costs.index(min(strategy_costs))]

    def _adjust_shape_and_stride_args(
        self,
        out_tensor_meta: TensorMeta,
        schema: OpSchema,
        spec: DTensorSpec,
        mesh: DeviceMesh,
    ) -> OpSchema:
        shape_stride_idx = self.op_to_shape_and_stride_idx[schema.op]
        if isinstance(shape_stride_idx, tuple):
            shape_idx, stride_idx = shape_stride_idx
        else:
            shape_idx = shape_stride_idx
            stride_idx = None

        expected_input_schema = list(schema.args_schema)
        # adjust shape to be the same as that of the _local_tensor
        # of the DTensor input arg at index 0, which is inferred
        expected_input_schema[shape_idx] = compute_local_shape(
            out_tensor_meta.shape, mesh, spec.placements
        )

        # adjust the stride arg for aten.new_empty_strided.default
        if stride_idx:
            expected_input_schema[stride_idx] = compute_local_stride(
                out_tensor_meta.stride, mesh, spec.placements
            )

        return OpSchema(schema.op, tuple(expected_input_schema), schema.kwargs_schema)
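
    # Illustration (assumed setup) for _adjust_shape_and_stride_args above: for
    # aten.new_zeros.default with the output sharded Shard(0) on a 4-rank mesh,
    # a global size arg of (16, 8) is rewritten by compute_local_shape to the
    # per-rank size (4, 8); for aten.new_empty_strided.default the stride arg
    # is rewritten by compute_local_stride as well.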