# mypy: allow-untyped-defs
import threading
from functools import lru_cache
from itertools import chain
from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch._ops import OpOverload
from torch._subclasses import FakeTensorMode
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import (
    OpInfo,
    OpSchema,
    OpStrategy,
    OutputSharding,
    OutputSpecType,
    PlacementStrategy,
    RuntimeSchemaInfo,
    StrategyType,
    TupleStrategy,
)
from torch.distributed.tensor._utils import (
    compute_local_shape,
    compute_local_stride,
    try_find_mesh_from_args,
)


aten = torch.ops.aten


def _length(obj) -> int:
    if obj is None:
        return 0
    if not isinstance(obj, Sequence):
        return 1
    return len(obj)


class LocalLRUCache(threading.local):
    def __init__(self, user_function: Callable) -> None:
        self.cache = lru_cache(None)(user_function)

    def __call__(self, *args, **kwargs) -> object:
        return self.cache(*args, **kwargs)

    def cache_info(self):
        return self.cache.cache_info()


class ShardingPropagator:
    def __init__(self) -> None:
        self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
        self.op_strategy_funcs: Dict[
            OpOverload,
            Callable[[DeviceMesh, OpSchema], StrategyType],
        ] = {}
        # op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop
        self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {}
        self.propagate_op_sharding = LocalLRUCache(
            self.propagate_op_sharding_non_cached
        )
        # op map to save indices of shape (and stride) args which may need to be modified in sharding prop
        self.op_to_shape_and_stride_idx: Dict[
            OpOverload, Union[int, Tuple[int, int]]
        ] = {
            # new factory ops
            aten.new_empty.default: 1,
            aten.new_full.default: 1,
            aten.new_ones.default: 1,
            aten.new_zeros.default: 1,
            aten.new_empty_strided.default: (1, 2),
            # view ops
            aten.expand.default: 1,
            aten.reshape.default: 1,
            aten.view.default: 1,
            aten._unsafe_view.default: 1,
        }

    def register_sharding_prop_rule(
        self,
        op_overload: OpOverload,
        rule_func: Callable[[OpSchema], OutputSharding],
        schema_info: Optional[RuntimeSchemaInfo] = None,
    ):
        """
        Register a sharding propagation rule for an operator.
        """
        self.op_to_rules[op_overload] = rule_func
        if schema_info is not None:
            self.op_to_schema_info[op_overload] = schema_info

    def register_op_strategy(
        self,
        op_overload: OpOverload,
        strategy_func: Callable[[DeviceMesh, OpSchema], StrategyType],
        schema_info: Optional[RuntimeSchemaInfo] = None,
    ):
        """
        Register a sharding strategy generator for an operator.
""" self.op_strategy_funcs[op_overload] = strategy_func if schema_info is not None: self.op_to_schema_info[op_overload] = schema_info @lru_cache # noqa: B019 def _propagate_tensor_meta( self, op_schema: OpSchema ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]: """ Propagate the tensor metadata, it could either return a TensorMeta or a list/tuple of TensorMetas """ if op_schema.op == aten.equal.default: # data dependent ops can't be used for fake propagation return None # NOTE: We must call the tracing in fake tensor mode so that it # avoids materializing memory with FakeTensorMode(): fake_args = op_schema.gen_fake_args() fake_kwargs = op_schema.gen_fake_kwargs() fake_out = op_schema.op(*fake_args, **fake_kwargs) if isinstance(fake_out, torch.Tensor): return TensorMeta( shape=fake_out.shape, stride=fake_out.stride(), dtype=fake_out.dtype ) elif isinstance(fake_out, (tuple, list)): tensor_meta_list: List[Optional[TensorMeta]] = [] for fake_out_item in fake_out: if isinstance(fake_out_item, torch.Tensor): tensor_meta_list.append( TensorMeta( shape=fake_out_item.shape, stride=fake_out_item.stride(), dtype=fake_out_item.dtype, ) ) else: tensor_meta_list.append(None) return ( tuple(tensor_meta_list) if isinstance(fake_out, tuple) else tensor_meta_list ) else: # if fake is not a tensor or tuple of tensor, return as none return None def _wrap_output_spec_tensor_meta( self, op: OpOverload, output_specs: OutputSpecType, output_tensor_meta: Union[None, TensorMeta, Sequence[Optional[TensorMeta]]], ) -> None: """ Wrap the output_specs with the tensor metadata from the output. """ if isinstance(output_specs, DTensorSpec): if not isinstance(output_tensor_meta, TensorMeta): # Either error due to ShardingPropagator or due to incorrect OutputSpec if not isinstance(output_tensor_meta, (tuple, list)): raise ValueError( "ShardingPropagator error: output does not have an associated TensorMeta" ) raise ValueError( f"For the op {op.name()}, `output_specs` has 1 output which does not equal the " f"number of op outputs: {len(output_tensor_meta)}." ) output_specs.tensor_meta = output_tensor_meta elif isinstance(output_specs, (tuple, list)): if not isinstance(output_tensor_meta, (tuple, list)) or len( output_specs ) != len(output_tensor_meta): raise ValueError( f"For the op {op.name()}, `output_specs` has {len(output_specs)} outputs which does not equal the " f"number of op outputs {_length(output_tensor_meta)}." ) for i, spec in enumerate(output_specs): if isinstance(spec, DTensorSpec): output_tensor_meta_i = output_tensor_meta[i] if not isinstance(output_tensor_meta_i, TensorMeta): raise ValueError( f"ShardingPropagator error: output {i} does not have an associated TensorMeta" ) spec.tensor_meta = output_tensor_meta_i def propagate(self, op_info: OpInfo) -> None: # We cannot use an lru cache if we know that inputs will have dynamic shapes, # because SymInts are not hashable. # This is generally ok because this only happens during tracing in torch.compile, # and tracing does not need to be as fast as eagermode DTensor usages. if op_info.schema.has_symints: output_sharding = self.propagate_op_sharding_non_cached(op_info.schema) else: output_sharding = cast( OutputSharding, self.propagate_op_sharding(op_info.schema) ) op_info.output_sharding = output_sharding def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding: """ Propagate the sharding for an operator given the op_schema. """ # special case op, we don't need to propagate for local # scalar. 
        if op_schema.op is aten._local_scalar_dense.default:
            return OutputSharding(None, op_schema)

        out_tensor_meta = self._propagate_tensor_meta(op_schema)

        def spec_to_strategy(spec: object) -> object:
            if isinstance(spec, DTensorSpec):
                return OpStrategy([PlacementStrategy(spec)])
            elif (
                isinstance(spec, (list, tuple))
                and len(spec) > 0
                and isinstance(spec[0], DTensorSpec)
            ):
                # a tensor list creates a tuple strategy
                tuple_strategy = [spec_to_strategy(s) for s in spec]
                tuple_strategy = cast(Sequence[StrategyType], tuple_strategy)
                return TupleStrategy(
                    tuple(tuple_strategy) if isinstance(spec, tuple) else tuple_strategy
                )
            else:
                return spec

        if op_schema.op in self.op_strategy_funcs:
            # generate op strategy for the op.
            mesh = try_find_mesh_from_args(op_schema.op, op_schema.args_schema)
            # swap the args spec with args strategies
            args_op_strategy = [spec_to_strategy(i) for i in op_schema.args_schema]

            kwargs_op_strategy = {
                k: spec_to_strategy(v) for k, v in op_schema.kwargs_schema.items()
            }

            # construct a new OpSchema on args for strategy-based propagation
            strategy_schema: OpSchema = OpSchema(
                op=op_schema.op,
                args_schema=tuple(args_op_strategy),
                kwargs_schema=kwargs_op_strategy,
            )

            op_strategy = self.op_strategy_funcs[op_schema.op](mesh, strategy_schema)

            if isinstance(op_strategy, OpStrategy):
                # single Op strategy
                output_strategy = self._select_strategy(op_strategy)

                # check if we need to redistribute the input
                needs_redistribute = False
                expected_input_specs: List[DTensorSpec] = []

                # in the case where the op does not specify input_specs and output_specs
                # is a DTensorSpec, we use output_specs as the spec for each DTensor
                # input arg.
                if output_strategy.input_specs is None:
                    assert isinstance(output_strategy.output_specs, DTensorSpec)

                for idx, input_spec in enumerate(op_schema.args_spec):
                    desired_spec = (
                        output_strategy.output_spec
                        if output_strategy.input_specs is None
                        else output_strategy.input_specs[idx]
                    )
                    expected_input_specs.append(
                        desired_spec.shallow_copy_with_tensor_meta(
                            input_spec.tensor_meta
                        )
                    )
                    if input_spec.placements != desired_spec.placements:
                        needs_redistribute = True

                suggestion_schema = None
                if needs_redistribute:
                    suggestion_schema = OpSchema(
                        op_schema.op, tuple(expected_input_specs), {}
                    )
                    suggestion_schema._inplace_rewrap_schema_suggestion(op_schema)

                # shape and stride args may need to be modified for
                # view ops and new factory ops
                if op_schema.op in self.op_to_shape_and_stride_idx:
                    assert isinstance(output_strategy.output_spec, DTensorSpec)
                    # It happens when the output has the same shape as the input
                    # and the input placements are not all Replicate().
                    if output_strategy.output_spec.is_sharded():
                        schema = suggestion_schema or op_schema
                        assert isinstance(out_tensor_meta, TensorMeta)
                        suggestion_schema = self._adjust_shape_and_stride_args(
                            out_tensor_meta, schema, output_strategy.output_spec, mesh
                        )
                        needs_redistribute = True

                # construct output spec for the op
                if op_schema.return_type_tuple_tensor_like():
                    # for ops that return multiple tensors and the output_specs is not
                    # a tuple, we use a tuple of that single output spec as the new
                    # output_specs
                    output_specs: OutputSpecType = output_strategy.output_specs
                    if isinstance(output_specs, DTensorSpec):
                        output_specs = tuple(
                            [
                                # create a new DTensorSpec with the same placement as the
                                # output_specs in output_strategy
                                DTensorSpec(
                                    mesh=output_specs.mesh,
                                    placements=output_specs.placements,
                                    tensor_meta=output_specs.tensor_meta,
                                )
                                for _ in range(len(op_schema.op._schema.returns))
                            ]
                        )
                elif op_schema.return_type_tensor():
                    output_specs = output_strategy.output_specs
                else:
                    output_specs = None

                output_sharding = OutputSharding(
                    output_specs,
                    suggestion_schema,
                    needs_redistribute=needs_redistribute,
                )
            elif isinstance(op_strategy, TupleStrategy):
                # tuple strategy output sharding processing
                # runtime selected placement strategy for each TupleStrategy input arg
                selected_strategies: List[PlacementStrategy] = []
                out_spec_list: List[DTensorSpec] = []
                for strategy in op_strategy.childs:
                    assert isinstance(strategy, OpStrategy)
                    selected_strategy = self._select_strategy(strategy)
                    selected_strategies.append(selected_strategy)
                    out_spec_list.append(selected_strategy.output_spec)

                needs_redistribute = False
                suggestion_args: List[object] = []
                tensor_or_list_tensor_arg_idx = 0

                for arg in op_schema.args_schema:
                    if (
                        arg
                        and isinstance(arg, (list, tuple))
                        and isinstance(arg[0], DTensorSpec)
                    ):
                        expected_input_spec_list: List[DTensorSpec] = []
                        for idx, arg_spec in enumerate(arg):
                            expected_input_spec = selected_strategies[idx].input_spec(
                                tensor_or_list_tensor_arg_idx
                            )
                            expected_input_spec = (
                                expected_input_spec.shallow_copy_with_tensor_meta(
                                    arg_spec.tensor_meta
                                )
                            )
                            if arg_spec.placements != expected_input_spec.placements:
                                needs_redistribute = True
                            expected_input_spec_list.append(expected_input_spec)
                        suggestion_args.append(
                            tuple(expected_input_spec_list)
                            if isinstance(arg, tuple)
                            else expected_input_spec_list
                        )
                        tensor_or_list_tensor_arg_idx += 1
                    elif isinstance(arg, DTensorSpec):
                        expected_input_spec = selected_strategies[0].input_spec(
                            tensor_or_list_tensor_arg_idx
                        )
                        expected_input_spec = (
                            expected_input_spec.shallow_copy_with_tensor_meta(
                                arg.tensor_meta
                            )
                        )
                        if arg.placements != expected_input_spec.placements:
                            needs_redistribute = True
                        suggestion_args.append(expected_input_spec)
                        tensor_or_list_tensor_arg_idx += 1
                    else:
                        suggestion_args.append(arg)

                suggestion_schema = None
                if needs_redistribute:
                    suggestion_schema = OpSchema(
                        op_schema.op, tuple(suggestion_args), op_schema.kwargs_schema
                    )

                output_sharding = OutputSharding(
                    tuple(out_spec_list) if out_tensor_meta is not None else None,
                    suggestion_schema,
                    needs_redistribute=needs_redistribute,
                )
            else:
                raise ValueError("Unsupported op strategy type")

            # associate the output sharding with the output tensor metadata
            self._wrap_output_spec_tensor_meta(
                op_schema.op, output_sharding.output_spec, out_tensor_meta
            )
            return output_sharding

        elif op_schema.op in self.op_to_rules:
            # propagate the sharding with rule
            sharding_prop_func = self.op_to_rules[op_schema.op]

            # step 1. there's a sharding propagation rule, run
            # sharding propagation to get the output sharding
            try:
                output_sharding = sharding_prop_func(op_schema)
            except NotImplementedError as e:
                raise e
            except Exception as e:
                raise RuntimeError(
                    f"Sharding propagation failed on op {op_schema}.\n" f"Error: {e}"
                ) from e

            # step 2. if we can't get output_spec from sharding
            # propagation (i.e. no rules apply for the input
            # placements), we return the output sharding
            # with schema suggestions, which can be used to
            # decide how to redistribute the inputs
            if output_sharding.output_spec is None:
                if output_sharding.redistribute_schema is None:
                    raise RuntimeError(
                        f"Sharding propagation failed on op {op_schema}!"
                    )
                else:
                    # we do auto redistribute on inputs if necessary
                    # run sharding propagation again with the suggested schema
                    propagation_res = sharding_prop_func(
                        output_sharding.redistribute_schema
                    )
                    # we set the output sharding with the new propagation result
                    # so that dispatching knows both output_spec and redistribute_schema
                    # exist, which indicates a reshard is needed
                    output_sharding.output_spec = propagation_res.output_spec
                    output_sharding.needs_redistribute = True

            # associate the output sharding with the output tensor metadata
            self._wrap_output_spec_tensor_meta(
                op_schema.op, output_sharding.output_spec, out_tensor_meta
            )

            return output_sharding
        else:
            raise NotImplementedError(
                f"Operator {op_schema.op} does not have a sharding strategy registered."
            )

    def _select_strategy(self, strategy: OpStrategy) -> PlacementStrategy:
        if len(strategy.strategies) == 1:
            # short cut with only one possible strategy
            return strategy.strategies[0]

        strategy_costs: List[float] = []
        for strtg in strategy.strategies:
            assert (
                strtg.redistribute_cost is not None
            ), "must set redistribute cost for each strategy!"
            redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost))
            strategy_costs.append(redistribute_cost)

        # for eager execution, we just select the one with the minimal redistribute cost
        return strategy.strategies[strategy_costs.index(min(strategy_costs))]

    def _adjust_shape_and_stride_args(
        self,
        out_tensor_meta: TensorMeta,
        schema: OpSchema,
        spec: DTensorSpec,
        mesh: DeviceMesh,
    ) -> OpSchema:
        shape_stride_idx = self.op_to_shape_and_stride_idx[schema.op]
        if isinstance(shape_stride_idx, tuple):
            shape_idx, stride_idx = shape_stride_idx
        else:
            shape_idx = shape_stride_idx
            stride_idx = None

        expected_input_schema = list(schema.args_schema)
        # adjust shape to be the same as that of the _local_tensor
        # of the DTensor input arg at index 0, which is inferred
        expected_input_schema[shape_idx] = compute_local_shape(
            out_tensor_meta.shape, mesh, spec.placements
        )

        # adjust the stride arg for aten.new_empty_strided.default
        if stride_idx:
            expected_input_schema[stride_idx] = compute_local_stride(
                out_tensor_meta.stride, mesh, spec.placements
            )

        return OpSchema(schema.op, tuple(expected_input_schema), schema.kwargs_schema)
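

# Example (a minimal, illustrative sketch): how an op strategy function plugs
# into ``register_op_strategy``. The helper name and the choice of
# aten.clone.default are hypothetical; real registrations typically live in the
# op modules under torch.distributed.tensor._ops. A strategy function receives
# the DeviceMesh and an OpSchema whose DTensorSpec args have been swapped for
# OpStrategy objects, and returns a StrategyType. This sketch simply follows
# the input placements; a full implementation would also populate
# redistribute_cost on each PlacementStrategy so that _select_strategy can
# rank the candidates.
#
#   def _clone_strategy_sketch(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
#       input_strategy = op_schema.args_schema[0]
#       assert isinstance(input_strategy, OpStrategy)
#       # propose one candidate per input candidate, reusing its output spec
#       return OpStrategy(
#           [PlacementStrategy(s.output_spec) for s in input_strategy.strategies]
#       )
#
#   propagator = ShardingPropagator()
#   propagator.register_op_strategy(aten.clone.default, _clone_strategy_sketch)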