# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import List, Sequence, Tuple

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import (
    _is_inplace_op,
    _is_out_variant_op,
    OpSchema,
    OpStrategy,
    PlacementStrategy,
    RuntimeSchemaInfo,
    StrategyType,
    TupleStrategy,
)
from torch.distributed.tensor._ops.utils import (
    generate_redistribute_costs,
    infer_broadcast_dims_map,
    map_placements_after_broadcast,
    normalize_dim,
    register_op_strategy,
)
from torch.distributed.tensor.placement_types import (
    Partial,
    Placement,
    Replicate,
    Shard,
)


aten = torch.ops.aten
# The commented-out pointwise_ops list below is kept here for convenience;
# these ops are not yet supported, and the list may not be complete.
# pointwise_ops = [
#     "fake_quantize_per_channel_affine",
#     "fake_quantize_per_tensor_affine",
#     "floor_divide",  # floor_divide is deprecated
#     "frexp",  # multiple output pointwise op, need to add support
#     "gradient",  #  need investigation on this op
#     "imag",  # complex data type only
#     "quantized_batch_norm",
#     "quantized_max_pool1d",
#     "quantized_max_pool2d",
#     "real",  # complex data type only
# ]


linear_pointwise_ops = [
    aten.div.Scalar,  # linear on the first argument; the second argument is a scalar
    aten.div_.Scalar,  # linear on the first argument; the second argument is a scalar
    aten.to.dtype,
    aten.add.Tensor,
    aten.add_.Tensor,
]
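# Linearity here means the op distributes over a pending (Partial) sum: if a
# DTensor holds partial values x_1 + x_2, then div(x_1, s) + div(x_2, s) ==
# div(x_1 + x_2, s), so the Partial placement can be propagated through the
# op instead of being replicated first. See linear_pointwise_strategy below.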


pointwise_ops = [
    # please keep the entries below alphabetically sorted
    aten.__ilshift__.Scalar,
    aten.__ilshift__.Tensor,
    aten.__irshift__.Scalar,
    aten.__irshift__.Tensor,
    aten.__lshift__.Scalar,
    aten.__lshift__.Tensor,
    aten.__rshift__.Scalar,
    aten.__rshift__.Tensor,
    aten._conj.default,
    aten.abs.default,
    aten.abs.out,
    aten.abs_.default,
    aten.acos.default,
    aten.acos.out,
    aten.acos_.default,
    aten.acosh.default,
    aten.acosh.out,
    aten.acosh_.default,
    aten.add.Scalar,
    aten.add.out,
    aten.add_.Scalar,
    aten.addcdiv.default,
    aten.addcdiv.out,
    aten.addcdiv_.default,
    aten.addcmul.default,
    aten.addcmul.out,
    aten.addcmul_.default,
    aten.angle.default,
    aten.angle.out,
    aten.asin.default,
    aten.asin.out,
    aten.asin_.default,
    aten.asinh.default,
    aten.asinh.out,
    aten.asinh_.default,
    aten.atan.default,
    aten.atan.out,
    aten.atan2.default,
    aten.atan2.out,
    aten.atan2_.default,
    aten.atan_.default,
    aten.atanh.default,
    aten.atanh.out,
    aten.atanh_.default,
    aten.bitwise_and.Scalar,
    aten.bitwise_and.Scalar_Tensor,
    aten.bitwise_and.Scalar_out,
    aten.bitwise_and.Tensor,
    aten.bitwise_and.Tensor_out,
    aten.bitwise_and_.Scalar,
    aten.bitwise_and_.Tensor,
    aten.bitwise_left_shift.Scalar_Tensor,
    aten.bitwise_left_shift.Tensor,
    aten.bitwise_left_shift.Tensor_Scalar,
    aten.bitwise_left_shift.Tensor_Scalar_out,
    aten.bitwise_left_shift.Tensor_out,
    aten.bitwise_left_shift_.Tensor,
    aten.bitwise_left_shift_.Tensor_Scalar,
    aten.bitwise_not.default,
    aten.bitwise_not.out,
    aten.bitwise_not_.default,
    aten.bitwise_or.Scalar,
    aten.bitwise_or.Scalar_Tensor,
    aten.bitwise_or.Scalar_out,
    aten.bitwise_or.Tensor,
    aten.bitwise_or.Tensor_out,
    aten.bitwise_or_.Scalar,
    aten.bitwise_or_.Tensor,
    aten.bitwise_right_shift.Scalar_Tensor,
    aten.bitwise_right_shift.Tensor,
    aten.bitwise_right_shift.Tensor_Scalar,
    aten.bitwise_right_shift.Tensor_Scalar_out,
    aten.bitwise_right_shift.Tensor_out,
    aten.bitwise_right_shift_.Tensor,
    aten.bitwise_right_shift_.Tensor_Scalar,
    aten.bitwise_xor.Scalar,
    aten.bitwise_xor.Scalar_Tensor,
    aten.bitwise_xor.Scalar_out,
    aten.bitwise_xor.Tensor,
    aten.bitwise_xor.Tensor_out,
    aten.bitwise_xor_.Scalar,
    aten.bitwise_xor_.Tensor,
    aten.ceil.default,
    aten.ceil.out,
    aten.ceil_.default,
    aten.clamp.default,
    aten.clamp.out,
    aten.clamp_.default,
    aten.clip.default,
    aten.clip.out,
    aten.clip_.default,
    aten.conj_physical.default,
    aten.conj_physical.out,
    aten.conj_physical_.default,
    aten.copysign.Scalar,
    aten.copysign.Scalar_out,
    aten.copysign.Tensor,
    aten.copysign.out,
    aten.copysign_.Scalar,
    aten.copysign_.Tensor,
    aten.cos.default,
    aten.cos.out,
    aten.cos_.default,
    aten.cosh.default,
    aten.cosh.out,
    aten.cosh_.default,
    aten.deg2rad.default,
    aten.deg2rad.out,
    aten.deg2rad_.default,
    aten.digamma.default,
    aten.digamma.out,
    aten.digamma_.default,
    aten.div.Tensor,
    aten.div.Tensor_mode,
    aten.div.out,
    aten.div.out_mode,
    aten.div_.Tensor,
    aten.div_.Tensor_mode,
    aten.eq.Tensor,
    aten.eq.Tensor_out,
    aten.eq.Scalar,
    aten.eq.Scalar_out,
    aten.erf.default,
    aten.erf.out,
    aten.erf_.default,
    aten.erfc.default,
    aten.erfc.out,
    aten.erfc_.default,
    aten.erfinv.default,
    aten.erfinv.out,
    aten.erfinv_.default,
    aten.exp.default,
    aten.exp.out,
    aten.exp2.default,
    aten.exp2.out,
    aten.exp2_.default,
    aten.exp_.default,
    aten.expm1.default,
    aten.expm1.out,
    aten.expm1_.default,
    aten.float_power.Scalar,
    aten.float_power.Scalar_out,
    aten.float_power.Tensor_Scalar,
    aten.float_power.Tensor_Scalar_out,
    aten.float_power.Tensor_Tensor,
    aten.float_power.Tensor_Tensor_out,
    aten.float_power_.Scalar,
    aten.float_power_.Tensor,
    aten.floor.default,
    aten.floor.out,
    aten.floor_.default,
    aten.fmod.Scalar,
    aten.fmod.Scalar_out,
    aten.fmod.Tensor,
    aten.fmod.Tensor_out,
    aten.fmod_.Scalar,
    aten.fmod_.Tensor,
    aten.frac.default,
    aten.frac.out,
    aten.frac_.default,
    aten.ge.Scalar,
    aten.ge.Tensor,
    aten.gelu.default,
    aten.gt.Tensor,
    aten.gt.Tensor_out,
    aten.gt.Scalar,
    aten.gt.Scalar_out,
    aten.hypot.default,
    aten.hypot.out,
    aten.hypot_.default,
    aten.i0.default,
    aten.i0.out,
    aten.i0_.default,
    aten.igamma.default,
    aten.igamma.out,
    aten.igamma_.default,
    aten.igammac.default,
    aten.igammac.out,
    aten.igammac_.default,
    aten.isinf.default,
    aten.isnan.default,
    aten.isneginf.default,
    aten.isneginf.out,
    aten.isposinf.default,
    aten.isposinf.out,
    aten.ldexp.default,
    aten.ldexp.out,
    aten.ldexp_.default,
    aten.lt.Tensor,
    aten.lt.Tensor_out,
    aten.lt.Scalar,
    aten.lt.Scalar_out,
    aten.le.Scalar,
    aten.le.Tensor,
    aten.lerp.Scalar,
    aten.lerp.Scalar_out,
    aten.lerp.Tensor,
    aten.lerp.Tensor_out,
    aten.lerp_.Scalar,
    aten.lerp_.Tensor,
    aten.lgamma.default,
    aten.lgamma.out,
    aten.lgamma_.default,
    aten.log.default,
    aten.log.out,
    aten.log10.default,
    aten.log10.out,
    aten.log10_.default,
    aten.log1p.default,
    aten.log1p.out,
    aten.log1p_.default,
    aten.log2.default,
    aten.log2.out,
    aten.log2_.default,
    aten.log_.default,
    aten.logaddexp.default,
    aten.logaddexp.out,
    aten.logaddexp2.default,
    aten.logaddexp2.out,
    aten.logical_and.default,
    aten.logical_and.out,
    aten.logical_and_.default,
    aten.logical_not.default,
    aten.logical_not.out,
    aten.logical_not_.default,
    aten.logical_or.default,
    aten.logical_or.out,
    aten.logical_or_.default,
    aten.logical_xor.default,
    aten.logical_xor.out,
    aten.logical_xor_.default,
    aten.logit.default,
    aten.logit.out,
    aten.logit_.default,
    aten.masked_fill.Scalar,
    aten.maximum.out,
    aten.mul.Scalar,
    aten.mul.Tensor,
    aten.mul.out,
    aten.mul_.Scalar,
    aten.mul_.Tensor,
    aten.mvlgamma.default,
    aten.mvlgamma.out,
    aten.mvlgamma_.default,
    aten.native_dropout_backward.default,
    aten.native_dropout_backward.out,
    aten.nan_to_num.default,
    aten.nan_to_num.out,
    aten.nan_to_num_.default,
    aten.ne.Scalar,
    aten.neg.default,
    aten.neg.out,
    aten.neg_.default,
    aten.nextafter.default,
    aten.nextafter.out,
    aten.nextafter_.default,
    aten.polygamma.default,
    aten.polygamma.out,
    aten.polygamma_.default,
    aten.positive.default,
    aten.pow.Scalar,
    aten.pow.Scalar_out,
    aten.pow.Tensor_Scalar,
    aten.pow.Tensor_Scalar_out,
    aten.pow.Tensor_Tensor,
    aten.pow.Tensor_Tensor_out,
    aten.pow_.Scalar,
    aten.pow_.Tensor,
    aten.reciprocal.default,
    aten.reciprocal.out,
    aten.reciprocal_.default,
    aten.rad2deg.default,
    aten.rad2deg.out,
    aten.rad2deg_.default,
    aten.relu.default,
    aten.relu_.default,
    aten.remainder.Scalar,
    aten.remainder.Scalar_Tensor,
    aten.remainder.Scalar_out,
    aten.remainder.Tensor,
    aten.remainder.Tensor_out,
    aten.remainder_.Scalar,
    aten.remainder_.Tensor,
    aten.round.decimals,
    aten.round.decimals_out,
    aten.round.default,
    aten.round.out,
    aten.round_.decimals,
    aten.round_.default,
    aten.rsqrt.default,
    aten.rsqrt.out,
    aten.rsqrt_.default,
    aten.rsub.Scalar,
    aten.sgn.default,
    aten.sgn.out,
    aten.sgn_.default,
    aten.sigmoid.default,
    aten.sigmoid.out,
    aten.sigmoid_.default,
    aten.sign.default,
    aten.sign.out,
    aten.sign_.default,
    aten.signbit.default,
    aten.signbit.out,
    aten.silu.default,
    aten.silu.out,
    aten.sin.default,
    aten.sin.out,
    aten.sin_.default,
    aten.sinc.default,
    aten.sinc.out,
    aten.sinc_.default,
    aten.sinh.default,
    aten.sinh.out,
    aten.sinh_.default,
    aten.sqrt.default,
    aten.sqrt.out,
    aten.sqrt_.default,
    aten.square.default,
    aten.square.out,
    aten.square_.default,
    aten.sub.Scalar,
    aten.sub.Tensor,
    aten.sub.out,
    aten.sub_.Scalar,
    aten.sub_.Tensor,
    aten.tan.default,
    aten.tan.out,
    aten.tan_.default,
    aten.tanh.default,
    aten.tanh.out,
    aten.tanh_.default,
    aten.true_divide.Tensor,
    aten.trunc.default,
    aten.trunc.out,
    aten.trunc_.default,
    aten.where.self,
    aten.where.self_out,
    aten.xlogy.OutScalar_Self,
    aten.xlogy.OutScalar_Other,
    aten.xlogy.OutTensor,
    aten.xlogy.Scalar_Other,
    aten.xlogy.Scalar_Self,
    aten.xlogy.Tensor,
    aten.xlogy_.Scalar_Other,
    aten.xlogy_.Tensor,
    # backward point-wise ops
    # please keep the entries below alphabetically sorted
    aten.gelu_backward.default,
    aten.sigmoid_backward.default,
    aten.silu_backward.default,
    aten.tanh_backward.default,
    aten.threshold_backward.default,
]


def pointwise_strategy(
    mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False
) -> OpStrategy:
    max_shards_strategy_index = -1
    max_shards = -1

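    # Illustrative example (hypothetical operands): for aten.add.Tensor(a, b)
    # where a is placed as [Shard(0)] over a 4-rank mesh and b as [Replicate()],
    # a has the most shards, so its strategy is followed below and b is the
    # operand that may need to be redistributed.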
    if _is_inplace_op(op_schema.op):
        # inplace op should follow the first arg strategy
        followed_strategy = op_schema.args_schema[0]
    elif _is_out_variant_op(op_schema.op):
        # out variant op should follow the out kwarg strategy
        followed_strategy = op_schema.kwargs_schema["out"]
    else:
        # normal pointwise op: we choose to follow the arg with the most
        # shards, in case the other operands need to be resharded
        for idx, arg_strategy in enumerate(op_schema.args_schema):
            if not isinstance(arg_strategy, OpStrategy):
                continue

            arg_max_shards = arg_strategy.max_num_shards()
            if arg_max_shards > max_shards:
                max_shards_strategy_index = idx
                max_shards = arg_max_shards

        followed_strategy = op_schema.args_schema[max_shards_strategy_index]

    assert isinstance(
        followed_strategy, OpStrategy
    ), f"no strategy to follow for {op_schema}!"
    return common_pointwise_strategy(
        mesh, op_schema.args_schema, followed_strategy, linearity
    )


def common_pointwise_strategy(
    mesh: DeviceMesh,
    args_schema: Sequence[object],
    followed_strategy: OpStrategy,
    linearity: bool,
) -> OpStrategy:
    # handle broadcasting
    common_shape = torch.broadcast_shapes(
        *[arg.shape for arg in args_schema if isinstance(arg, OpStrategy)]
    )
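    # Illustrative example (hypothetical shapes): broadcasting a (4,) operand
    # against an (8, 4) operand gives common_shape == (8, 4); a Shard(0)
    # placement on a followed spec of shape (4,) is right-aligned below and
    # becomes Shard(1) in the output placements.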
    pointwise_strategy = OpStrategy([])

    for placement_strategy in followed_strategy.strategies:
        spec_to_follow = placement_strategy.output_spec
        out_placements: List[Placement] = []
        for placement in spec_to_follow.placements:
            if isinstance(placement, Shard):
                shard_dim = normalize_dim(placement.dim, len(spec_to_follow.shape))
                common_ndim = len(common_shape)
                new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim
                out_placements.append(Shard(new_shard_dim))
            elif isinstance(placement, Partial) and not linearity:
                # clear the Partial placement if the op does not support
                # linearity; by default we just replicate the Partial. We still
                # need to verify whether this is optimal for all cases.
                out_placements.append(Replicate())
            else:
                out_placements.append(placement)

        input_specs: List[DTensorSpec] = []
        redistribute_costs: List[List[float]] = []
        for input_arg in args_schema:
            if isinstance(input_arg, OpStrategy):
                # every arg follows the out_placements, but broadcasting
                # still needs to be handled
                input_arg_spec = input_arg.strategies[0].output_spec
                input_arg_dims_map = infer_broadcast_dims_map(
                    common_shape, input_arg_spec.shape
                )
                input_target_placements = map_placements_after_broadcast(
                    tuple(out_placements),
                    common_shape,
                    input_arg_dims_map,
                )
                input_arg_target_spec = DTensorSpec(
                    mesh=mesh,
                    placements=input_target_placements,
                    tensor_meta=input_arg_spec.tensor_meta,
                )
                input_specs.append(input_arg_target_spec)
                redistribute_costs.append(
                    generate_redistribute_costs(input_arg, input_arg_target_spec)
                )

        pointwise_strategy.strategies.append(
            PlacementStrategy(
                output_specs=DTensorSpec(
                    mesh=mesh,
                    placements=tuple(out_placements),
                ),
                input_specs=input_specs,
                redistribute_cost=redistribute_costs,
            )
        )
    return pointwise_strategy


def linear_pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
    """
    Linear pointwise operators can propagate pending reductions.
    For example, c = add(a, b); if a is pending sum, then c will be
    pending sum as well without any communication overhead.
    """
    return pointwise_strategy(mesh, op_schema, linearity=True)


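# Registration note: ops in linear_pointwise_ops keep Partial (pending-sum)
# inputs as Partial outputs, so no reduction is triggered; ops in
# pointwise_ops use pointwise_strategy with linearity=False, which replicates
# any Partial placement first (see common_pointwise_strategy above).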
for op in linear_pointwise_ops:
    register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
        linear_pointwise_strategy
    )

for op in pointwise_ops:
    register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
        pointwise_strategy
    )


# TODO: add all for_each ops
for_each_ops = [
    aten._foreach_abs.default,
    aten._foreach_abs_.default,
    aten._foreach_addcdiv_.Scalar,
    aten._foreach_addcdiv_.ScalarList,
    aten._foreach_addcdiv_.Tensor,
    aten._foreach_addcmul.Scalar,
    aten._foreach_addcmul_.Scalar,
    aten._foreach_addcmul_.ScalarList,
    aten._foreach_addcmul_.Tensor,
    aten._foreach_clamp_max_.Scalar,
    aten._foreach_clamp_min_.Scalar,
    aten._foreach_div_.List,
    aten._foreach_div_.Scalar,
    aten._foreach_div_.ScalarList,
    aten._foreach_div_.Tensor,
    aten._foreach_div.List,
    aten._foreach_div.Scalar,
    aten._foreach_div.ScalarList,
    aten._foreach_div.Tensor,
    aten._foreach_lerp_.Scalar,
    aten._foreach_maximum_.List,
    aten._foreach_mul.Scalar,
    aten._foreach_mul.ScalarList,
    aten._foreach_mul.Tensor,
    aten._foreach_mul.List,
    aten._foreach_mul_.Scalar,
    aten._foreach_mul_.ScalarList,
    aten._foreach_mul_.Tensor,
    aten._foreach_mul_.List,
    aten._foreach_neg.default,
    aten._foreach_neg_.default,
    aten._foreach_reciprocal_.default,
    aten._foreach_sub.Scalar,
    aten._foreach_sub_.Scalar,
    aten._foreach_sub.List,
    aten._foreach_sub_.List,
    aten._foreach_sub.ScalarList,
    aten._foreach_sub_.ScalarList,
    aten._foreach_sqrt.default,
    aten._foreach_sqrt_.default,
    aten._foreach_zero_.default,
    aten._foreach_exp.default,
    aten._foreach_exp_.default,
    aten._foreach_cos.default,
    aten._foreach_cos_.default,
    aten._foreach_log.default,
    aten._foreach_log_.default,
    aten._amp_foreach_non_finite_check_and_unscale_.default,
]

for_each_linearity_ops = [
    aten._foreach_add.Scalar,
    aten._foreach_add_.Scalar,
    aten._foreach_add_.ScalarList,
    aten._foreach_add.List,
    aten._foreach_add_.List,
]


def list_pointwise_strategy(
    mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False
) -> StrategyType:
    """
    Apply the pointwise strategy to the zipped arguments. For example, if we
    run a foreach add of two lists l1 and l2, then we apply the pointwise
    strategy on each pair (l1[i], l2[i]). If the first argument is a list but
    the second (or later) one is a tensor, then we broadcast the tensor by
    replicating it into a list with the length of the first argument.

    Args:
        mesh (DeviceMesh): device mesh for pointwise ops
        op_schema (OpSchema): schema of the operator to generate strategy for
        linearity (bool): specify whether op(a) + op(b) = op(a + b)

    Returns:
        TupleStrategy: generated strategy, one OpStrategy per list element
    """

    def args_tuple_strategies(args_schema: Tuple[object, ...]) -> List[TupleStrategy]:
        first_arg = args_schema[0]
        assert isinstance(first_arg, TupleStrategy)
        strategy_len = len(first_arg.childs)
        tuple_strategies: List[TupleStrategy] = []
        for arg_idx, arg in enumerate(args_schema):
            if isinstance(arg, TupleStrategy):
                # every tuple strategy should have the same length
                assert len(arg.childs) == strategy_len
                tuple_strategies.append(arg)
            elif isinstance(arg, OpStrategy):
                if arg_idx > 0:  # implicitly broadcast
                    tuple_strategies.append(
                        TupleStrategy([arg for _ in range(strategy_len)])
                    )
                else:
                    raise RuntimeError(
                        f"list op only supports tuple strategy! {op_schema}"
                    )
        return tuple_strategies

    args_strategies = args_tuple_strategies(op_schema.args_schema)
    follow_strategy: TupleStrategy = args_strategies[0]
    list_strategy: List[OpStrategy] = []
    for child_idx, child_strtgy in enumerate(follow_strategy.childs):
        assert isinstance(child_strtgy, OpStrategy)
        args_schema: List[StrategyType] = [
            arg_strategy.childs[child_idx] for arg_strategy in args_strategies
        ]
        pointwise_strategy: OpStrategy = common_pointwise_strategy(
            mesh, args_schema, child_strtgy, linearity
        )
        list_strategy.append(pointwise_strategy)
    return TupleStrategy(list_strategy)


def list_linear_pointwise_strategy(
    mesh: DeviceMesh, op_schema: OpSchema
) -> StrategyType:
    """
    List pointwise strategy for foreach ops that support linearity.
    """
    return list_pointwise_strategy(mesh, op_schema, linearity=True)


for op in for_each_ops:
    register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(
        list_pointwise_strategy
    )

for op in for_each_linearity_ops:
    register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(
        list_linear_pointwise_strategy
    )

fused_ops = [
    aten._fused_adam_.default,
    aten._fused_adam.default,
    aten._fused_adam.tensor_lr,
    aten._fused_adam_.tensor_lr,
    aten._fused_adamw_.default,
    aten._fused_adamw.default,
    aten._fused_adamw.tensor_lr,
    aten._fused_adamw_.tensor_lr,
]

for op in fused_ops:
    register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(
        list_pointwise_strategy
    )