# mypy: allow-untyped-defs
import threading
from functools import lru_cache
from itertools import chain
from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch._ops import OpOverload
from torch._subclasses import FakeTensorMode
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import (
    OpInfo,
    OpSchema,
    OpStrategy,
    OutputSharding,
    OutputSpecType,
    PlacementStrategy,
    RuntimeSchemaInfo,
    StrategyType,
    TupleStrategy,
)
from torch.distributed.tensor._utils import (
    compute_local_shape,
    compute_local_stride,
    try_find_mesh_from_args,
)


aten = torch.ops.aten


def _length(obj) -> int:
    if obj is None:
        return 0
    if not isinstance(obj, Sequence):
        return 1
    return len(obj)


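# LocalLRUCache is a thread-local wrapper around functools.lru_cache: each
# thread keeps its own sharding-propagation cache, so cached OutputSharding
# results are never shared across threads.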
class LocalLRUCache(threading.local):
    def __init__(self, user_function: Callable) -> None:
        self.cache = lru_cache(None)(user_function)

    def __call__(self, *args, **kwargs) -> object:
        return self.cache(*args, **kwargs)

    def cache_info(self):
        return self.cache.cache_info()


class ShardingPropagator:
    def __init__(self) -> None:
        self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
        self.op_strategy_funcs: Dict[
            OpOverload,
            Callable[[DeviceMesh, OpSchema], StrategyType],
        ] = {}
        # op map that records static argnums, used to decide whether to reuse
        # the sharding prop cache or re-run sharding prop
        self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {}
        self.propagate_op_sharding = LocalLRUCache(
            self.propagate_op_sharding_non_cached
        )
        # op map that records the indices of the shape (and stride) args which
        # may need to be modified during sharding prop
        self.op_to_shape_and_stride_idx: Dict[
            OpOverload, Union[int, Tuple[int, int]]
        ] = {
            # new factory ops
            aten.new_empty.default: 1,
            aten.new_full.default: 1,
            aten.new_ones.default: 1,
            aten.new_zeros.default: 1,
            aten.new_empty_strided.default: (1, 2),
            # view ops
            aten.expand.default: 1,
            aten.reshape.default: 1,
            aten.view.default: 1,
            aten._unsafe_view.default: 1,
        }

    def register_sharding_prop_rule(
        self,
        op_overload: OpOverload,
        rule_func: Callable[[OpSchema], OutputSharding],
        schema_info: Optional[RuntimeSchemaInfo] = None,
    ):
        """
        Register a sharding propagation rule for an operator.
        """
        self.op_to_rules[op_overload] = rule_func
        if schema_info is not None:
            self.op_to_schema_info[op_overload] = schema_info

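    # Illustrative sketch (not part of the module): a rule function maps an
    # OpSchema with concrete DTensorSpec args to an OutputSharding. The names
    # below are hypothetical:
    #
    #   def my_rule(op_schema: OpSchema) -> OutputSharding:
    #       ...
    #
    #   propagator.register_sharding_prop_rule(aten.add.Tensor, my_rule)
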
    def register_op_strategy(
        self,
        op_overload: OpOverload,
        strategy_func: Callable[[DeviceMesh, OpSchema], StrategyType],
        schema_info: Optional[RuntimeSchemaInfo] = None,
    ):
        """
        Register a sharding strategy generator for an operator.
        """
        self.op_strategy_funcs[op_overload] = strategy_func
        if schema_info is not None:
            self.op_to_schema_info[op_overload] = schema_info

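    # Illustrative sketch (not part of the module): a strategy generator
    # receives the DeviceMesh and an OpSchema whose DTensorSpec args have been
    # wrapped into OpStrategy/TupleStrategy objects, and returns a
    # StrategyType. The names below are hypothetical:
    #
    #   def pointwise_like_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
    #       ...
    #
    #   propagator.register_op_strategy(aten.relu.default, pointwise_like_strategy)
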
    @lru_cache  # noqa: B019
    def _propagate_tensor_meta(
        self, op_schema: OpSchema
    ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
        """
        Propagate the tensor metadata. This returns either a single TensorMeta
        or a list/tuple of TensorMetas.
        """
        if op_schema.op == aten.equal.default:
            # data dependent ops can't be used for fake propagation
            return None

        # NOTE: We must run the op in fake tensor mode so that it
        # avoids materializing memory
        with FakeTensorMode():
            fake_args = op_schema.gen_fake_args()
            fake_kwargs = op_schema.gen_fake_kwargs()
            fake_out = op_schema.op(*fake_args, **fake_kwargs)

        if isinstance(fake_out, torch.Tensor):
            return TensorMeta(
                shape=fake_out.shape, stride=fake_out.stride(), dtype=fake_out.dtype
            )

        elif isinstance(fake_out, (tuple, list)):
            tensor_meta_list: List[Optional[TensorMeta]] = []
            for fake_out_item in fake_out:
                if isinstance(fake_out_item, torch.Tensor):
                    tensor_meta_list.append(
                        TensorMeta(
                            shape=fake_out_item.shape,
                            stride=fake_out_item.stride(),
                            dtype=fake_out_item.dtype,
                        )
                    )
                else:
                    tensor_meta_list.append(None)
            return (
                tuple(tensor_meta_list)
                if isinstance(fake_out, tuple)
                else tensor_meta_list
            )
        else:
            # if the fake output is not a tensor or a tuple/list of tensors, return None
            return None

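    # Illustrative example (hypothetical values, not part of the module): for
    # aten.mm.default with fake inputs of shapes (8, 4) and (4, 6), the fake
    # output has shape (8, 6) and contiguous stride (6, 1), so this method
    # returns TensorMeta(shape=torch.Size([8, 6]), stride=(6, 1), dtype=...).
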
    def _wrap_output_spec_tensor_meta(
        self,
        op: OpOverload,
        output_specs: OutputSpecType,
        output_tensor_meta: Union[None, TensorMeta, Sequence[Optional[TensorMeta]]],
    ) -> None:
        """
        Wrap the output_specs with the tensor metadata from the output.
        """

        if isinstance(output_specs, DTensorSpec):
            if not isinstance(output_tensor_meta, TensorMeta):
                # Either error due to ShardingPropagator or due to incorrect OutputSpec
                if not isinstance(output_tensor_meta, (tuple, list)):
                    raise ValueError(
                        "ShardingPropagator error: output does not have an associated TensorMeta"
                    )
                raise ValueError(
                    f"For the op {op.name()}, `output_specs` has 1 output which does not equal the "
                    f"number of op outputs: {len(output_tensor_meta)}."
                )
            output_specs.tensor_meta = output_tensor_meta
        elif isinstance(output_specs, (tuple, list)):
            if not isinstance(output_tensor_meta, (tuple, list)) or len(
                output_specs
            ) != len(output_tensor_meta):
                raise ValueError(
                    f"For the op {op.name()}, `output_specs` has {len(output_specs)} outputs which does not equal the "
                    f"number of op outputs {_length(output_tensor_meta)}."
                )
            for i, spec in enumerate(output_specs):
                if isinstance(spec, DTensorSpec):
                    output_tensor_meta_i = output_tensor_meta[i]
                    if not isinstance(output_tensor_meta_i, TensorMeta):
                        raise ValueError(
                            f"ShardingPropagator error: output {i} does not have an associated TensorMeta"
                        )
                    spec.tensor_meta = output_tensor_meta_i

    def propagate(self, op_info: OpInfo) -> None:
        # We cannot use an lru cache if we know that inputs will have dynamic shapes,
        # because SymInts are not hashable.
        # This is generally ok because this only happens during tracing in torch.compile,
        # and tracing does not need to be as fast as eager-mode DTensor usage.
        if op_info.schema.has_symints:
            output_sharding = self.propagate_op_sharding_non_cached(op_info.schema)
        else:
            output_sharding = cast(
                OutputSharding, self.propagate_op_sharding(op_info.schema)
            )
        op_info.output_sharding = output_sharding

    def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding:
        """
        Propagate the sharding for an operator given the op_schema.
        """
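        # High-level flow (descriptive note): first propagate output tensor
        # metadata via fake execution, then dispatch to either the
        # strategy-based path (self.op_strategy_funcs) or the rule-based path
        # (self.op_to_rules); ops registered in neither raise NotImplementedError.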
        # special case op: we don't need to propagate sharding for the local
        # scalar op. TODO: figure out a better way to handle this
        if op_schema.op is aten._local_scalar_dense.default:
            return OutputSharding(None, op_schema)

        out_tensor_meta = self._propagate_tensor_meta(op_schema)

        def spec_to_strategy(spec: object) -> object:
            if isinstance(spec, DTensorSpec):
                return OpStrategy([PlacementStrategy(spec)])
            elif (
                isinstance(spec, (list, tuple))
                and len(spec) > 0
                and isinstance(spec[0], DTensorSpec)
            ):
                # a list of tensors creates a TupleStrategy
                tuple_strategy = [spec_to_strategy(s) for s in spec]
                tuple_strategy = cast(Sequence[StrategyType], tuple_strategy)
                return TupleStrategy(
                    tuple(tuple_strategy) if isinstance(spec, tuple) else tuple_strategy
                )
            else:
                return spec

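        # Note: spec_to_strategy wraps each concrete DTensorSpec into a
        # single-choice OpStrategy (and a list/tuple of specs into a
        # TupleStrategy), so registered strategy functions operate on
        # strategies rather than on concrete placements.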
        if op_schema.op in self.op_strategy_funcs:
            # generate op strategy for the op.
            mesh = try_find_mesh_from_args(op_schema.op, op_schema.args_schema)
            # swap the args spec with args strategies
            args_op_strategy = [spec_to_strategy(i) for i in op_schema.args_schema]

            kwargs_op_strategy = {
                k: spec_to_strategy(v) for k, v in op_schema.kwargs_schema.items()
            }

            # construct a new OpSchema on args for strategy based propagation
            strategy_schema: OpSchema = OpSchema(
                op=op_schema.op,
                args_schema=tuple(args_op_strategy),
                kwargs_schema=kwargs_op_strategy,
            )

            op_strategy = self.op_strategy_funcs[op_schema.op](mesh, strategy_schema)

            if isinstance(op_strategy, OpStrategy):
                # single Op strategy
                output_strategy = self._select_strategy(op_strategy)

                # check if we need to redistribute the input
                needs_redistribute = False
                expected_input_specs: List[DTensorSpec] = []

                # in the case where the op does not specify input_specs and
                # output_specs is a DTensorSpec, we use output_specs as the
                # spec for each DTensor input arg.
                if output_strategy.input_specs is None:
                    assert isinstance(output_strategy.output_specs, DTensorSpec)

                for idx, input_spec in enumerate(op_schema.args_spec):
                    desired_spec = (
                        output_strategy.output_spec
                        if output_strategy.input_specs is None
                        else output_strategy.input_specs[idx]
                    )
                    expected_input_specs.append(
                        desired_spec.shallow_copy_with_tensor_meta(
                            input_spec.tensor_meta
                        )
                    )
                    if input_spec.placements != desired_spec.placements:
                        needs_redistribute = True

                suggestion_schema = None
                if needs_redistribute:
                    suggestion_schema = OpSchema(
                        op_schema.op, tuple(expected_input_specs), {}
                    )
                    suggestion_schema._inplace_rewrap_schema_suggestion(op_schema)

                # shape and stride args potentially need to be modified for
                # view ops and new factory ops
                if op_schema.op in self.op_to_shape_and_stride_idx:
                    assert isinstance(output_strategy.output_spec, DTensorSpec)
                    # This happens when the output has the same shape as the input
                    # and the input placements are not all Replicate().
                    if output_strategy.output_spec.is_sharded():
                        schema = suggestion_schema or op_schema
                        assert isinstance(out_tensor_meta, TensorMeta)
                        suggestion_schema = self._adjust_shape_and_stride_args(
                            out_tensor_meta, schema, output_strategy.output_spec, mesh
                        )
                        needs_redistribute = True

                # construct output spec for the op
                if op_schema.return_type_tuple_tensor_like():
                    # for ops that return multiple tensors where output_specs is
                    # not a tuple, we use a tuple of that single output spec as
                    # the new output_specs
                    output_specs: OutputSpecType = output_strategy.output_specs
                    if isinstance(output_specs, DTensorSpec):
                        output_specs = tuple(
                            [
                                # create a new DTensorSpec with the same placement as the
                                # output_specs in output_strategy
                                DTensorSpec(
                                    mesh=output_specs.mesh,
                                    placements=output_specs.placements,
                                    tensor_meta=output_specs.tensor_meta,
                                )
                                for _ in range(len(op_schema.op._schema.returns))
                            ]
                        )
                elif op_schema.return_type_tensor():
                    output_specs = output_strategy.output_specs
                else:
                    output_specs = None

                output_sharding = OutputSharding(
                    output_specs,
                    suggestion_schema,
                    needs_redistribute=needs_redistribute,
                )
            elif isinstance(op_strategy, TupleStrategy):
                # TupleStrategy output sharding processing:
                # select a placement strategy at runtime for each TupleStrategy input arg
                selected_strategies: List[PlacementStrategy] = []
                out_spec_list: List[DTensorSpec] = []
                for strategy in op_strategy.childs:
                    assert isinstance(strategy, OpStrategy)
                    selected_strategy = self._select_strategy(strategy)
                    selected_strategies.append(selected_strategy)
                    out_spec_list.append(selected_strategy.output_spec)

                needs_redistribute = False
                suggestion_args: List[object] = []
                tensor_or_list_tensor_arg_idx = 0

                for arg in op_schema.args_schema:
                    if (
                        arg
                        and isinstance(arg, (list, tuple))
                        and isinstance(arg[0], DTensorSpec)
                    ):
                        expected_input_spec_list: List[DTensorSpec] = []
                        for idx, arg_spec in enumerate(arg):
                            expected_input_spec = selected_strategies[idx].input_spec(
                                tensor_or_list_tensor_arg_idx
                            )
                            expected_input_spec = (
                                expected_input_spec.shallow_copy_with_tensor_meta(
                                    arg_spec.tensor_meta
                                )
                            )
                            if arg_spec.placements != expected_input_spec.placements:
                                needs_redistribute = True
                            expected_input_spec_list.append(expected_input_spec)
                        suggestion_args.append(
                            tuple(expected_input_spec_list)
                            if isinstance(arg, tuple)
                            else expected_input_spec_list
                        )
                        tensor_or_list_tensor_arg_idx += 1

                    elif isinstance(arg, DTensorSpec):
                        expected_input_spec = selected_strategies[0].input_spec(
                            tensor_or_list_tensor_arg_idx
                        )
                        expected_input_spec = (
                            expected_input_spec.shallow_copy_with_tensor_meta(
                                arg.tensor_meta
                            )
                        )
                        if arg.placements != expected_input_spec.placements:
                            needs_redistribute = True
                        suggestion_args.append(expected_input_spec)
                        tensor_or_list_tensor_arg_idx += 1
                    else:
                        suggestion_args.append(arg)

                suggestion_schema = None
                if needs_redistribute:
                    suggestion_schema = OpSchema(
                        op_schema.op, tuple(suggestion_args), op_schema.kwargs_schema
                    )

                output_sharding = OutputSharding(
                    tuple(out_spec_list) if out_tensor_meta is not None else None,
                    suggestion_schema,
                    needs_redistribute=needs_redistribute,
                )
            else:
                raise ValueError("Unsupported op strategy type")

            # associate the output sharding with the output tensor metadata
            self._wrap_output_spec_tensor_meta(
                op_schema.op, output_sharding.output_spec, out_tensor_meta
            )
            return output_sharding
        elif op_schema.op in self.op_to_rules:
            # propagate the sharding with the registered rule
            sharding_prop_func = self.op_to_rules[op_schema.op]

            # step 1. there's a sharding propagation rule, run
            # sharding propagation to get the output sharding
            try:
                output_sharding = sharding_prop_func(op_schema)
            except NotImplementedError as e:
                raise e
            except Exception as e:
                raise RuntimeError(
                    f"Sharding propagation failed on op {op_schema}.\n" f"Error: {e}"
                ) from e

            # step 2. if we can't get an output_spec from sharding
            # propagation (i.e. no rules apply for the input
            # placements), we return the output sharding
            # with schema suggestions, which can be used to
            # decide how to redistribute the inputs
            if output_sharding.output_spec is None:
                if output_sharding.redistribute_schema is None:
                    raise RuntimeError(
                        f"Sharding propagation failed on op {op_schema}!"
                    )
                else:
                    # we auto-redistribute the inputs if necessary and
                    # run sharding propagation again with the suggested schema
                    propagation_res = sharding_prop_func(
                        output_sharding.redistribute_schema
                    )
                    # we set the output sharding with the new propagation result
                    # so that dispatching knows both output_spec and redistribute_schema
                    # exist, which indicates that a reshard is needed
                    output_sharding.output_spec = propagation_res.output_spec
                    output_sharding.needs_redistribute = True

            # associate the output sharding with the output tensor metadata
            self._wrap_output_spec_tensor_meta(
                op_schema.op, output_sharding.output_spec, out_tensor_meta
            )

            return output_sharding
        else:
            raise NotImplementedError(
                f"Operator {op_schema.op} does not have a sharding strategy registered."
            )

    def _select_strategy(self, strategy: OpStrategy) -> PlacementStrategy:
        if len(strategy.strategies) == 1:
            # short cut with only one possible strategy
            return strategy.strategies[0]

        strategy_costs: List[float] = []
        for strtg in strategy.strategies:
            assert (
                strtg.redistribute_cost is not None
            ), "must set redistribute cost for each strategy!"
            redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost))
            strategy_costs.append(redistribute_cost)

        # for eager execution, we just select the one with the minimal redistribute cost
        return strategy.strategies[strategy_costs.index(min(strategy_costs))]

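    # Illustrative example for _select_strategy (hypothetical costs, not part
    # of the module): if one strategy has redistribute_cost [[0.0, 4.0]]
    # (sum 4.0) and another has [[2.0, 0.0], [1.0, 0.0]] (sum 3.0), the second
    # strategy is selected.
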
    def _adjust_shape_and_stride_args(
        self,
        out_tensor_meta: TensorMeta,
        schema: OpSchema,
        spec: DTensorSpec,
        mesh: DeviceMesh,
    ) -> OpSchema:
        shape_stride_idx = self.op_to_shape_and_stride_idx[schema.op]
        if isinstance(shape_stride_idx, tuple):
            shape_idx, stride_idx = shape_stride_idx
        else:
            shape_idx = shape_stride_idx
            stride_idx = None

        expected_input_schema = list(schema.args_schema)
        # adjust the shape arg to the local shape, computed from the global
        # output shape and the output placements
        expected_input_schema[shape_idx] = compute_local_shape(
            out_tensor_meta.shape, mesh, spec.placements
        )

        # adjust the stride arg for aten.new_empty_strided.default
        if stride_idx:
            expected_input_schema[stride_idx] = compute_local_stride(
                out_tensor_meta.stride, mesh, spec.placements
            )

        return OpSchema(schema.op, tuple(expected_input_schema), schema.kwargs_schema)
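    # Illustrative example for _adjust_shape_and_stride_args (hypothetical
    # values, not part of the module): for aten.view.default with a global
    # output shape of (8, 4) sharded as [Shard(0)] on a mesh of size 4,
    # compute_local_shape returns (2, 4), which replaces the global shape arg
    # so that each rank views its local tensor with the local shape.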
498