# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import List, Sequence, Tuple

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import (
    _is_inplace_op,
    _is_out_variant_op,
    OpSchema,
    OpStrategy,
    PlacementStrategy,
    RuntimeSchemaInfo,
    StrategyType,
    TupleStrategy,
)
from torch.distributed.tensor._ops.utils import (
    generate_redistribute_costs,
    infer_broadcast_dims_map,
    map_placements_after_broadcast,
    normalize_dim,
    register_op_strategy,
)
from torch.distributed.tensor.placement_types import (
    Partial,
    Placement,
    Replicate,
    Shard,
)


aten = torch.ops.aten
# The ops below are pointwise ops that are not yet supported; this may not be
# a complete list. The commented-out list is kept here for convenience.
# pointwise_ops = [
#     "fake_quantize_per_channel_affine",
#     "fake_quantize_per_tensor_affine",
#     "floor_divide",  # floor_divide is deprecated
#     "frexp",  # multiple output pointwise op, need to add support
#     "gradient",  # need investigation on this op
#     "imag",  # complex data type only
#     "quantized_batch_norm",
#     "quantized_max_pool1d",
#     "quantized_max_pool2d",
#     "real",  # complex data type only
# ]


linear_pointwise_ops = [
    # div.Scalar and div_.Scalar are linear on the first argument, and the
    # second argument is a scalar, so they qualify as linear ops.
    aten.div.Scalar,
    aten.div_.Scalar,
    aten.to.dtype,
    aten.add.Tensor,
    aten.add_.Tensor,
]


pointwise_ops = [
    # please keep the entries below alphabetically sorted
    aten.__ilshift__.Scalar,
    aten.__ilshift__.Tensor,
    aten.__irshift__.Scalar,
    aten.__irshift__.Tensor,
    aten.__lshift__.Scalar,
    aten.__lshift__.Tensor,
    aten.__rshift__.Scalar,
    aten.__rshift__.Tensor,
    aten._conj.default,
    aten.abs.default,
    aten.abs.out,
    aten.abs_.default,
    aten.acos.default,
    aten.acos.out,
    aten.acos_.default,
    aten.acosh.default,
    aten.acosh.out,
    aten.acosh_.default,
    aten.add.Scalar,
    aten.add.out,
    aten.add_.Scalar,
    aten.addcdiv.default,
    aten.addcdiv.out,
    aten.addcdiv_.default,
    aten.addcmul.default,
    aten.addcmul.out,
    aten.addcmul_.default,
    aten.angle.default,
    aten.angle.out,
    aten.asin.default,
    aten.asin.out,
    aten.asin_.default,
    aten.asinh.default,
    aten.asinh.out,
    aten.asinh_.default,
    aten.atan.default,
    aten.atan.out,
    aten.atan2.default,
    aten.atan2.out,
    aten.atan2_.default,
    aten.atan_.default,
    aten.atanh.default,
    aten.atanh.out,
    aten.atanh_.default,
    aten.bitwise_and.Scalar,
    aten.bitwise_and.Scalar_Tensor,
    aten.bitwise_and.Scalar_out,
    aten.bitwise_and.Tensor,
    aten.bitwise_and.Tensor_out,
    aten.bitwise_and_.Scalar,
    aten.bitwise_and_.Tensor,
    aten.bitwise_left_shift.Scalar_Tensor,
    aten.bitwise_left_shift.Tensor,
    aten.bitwise_left_shift.Tensor_Scalar,
    aten.bitwise_left_shift.Tensor_Scalar_out,
    aten.bitwise_left_shift.Tensor_out,
    aten.bitwise_left_shift_.Tensor,
    aten.bitwise_left_shift_.Tensor_Scalar,
    aten.bitwise_not.default,
    aten.bitwise_not.out,
    aten.bitwise_not_.default,
    aten.bitwise_or.Scalar,
    aten.bitwise_or.Scalar_Tensor,
    aten.bitwise_or.Scalar_out,
    aten.bitwise_or.Tensor,
    aten.bitwise_or.Tensor_out,
    aten.bitwise_or_.Scalar,
    aten.bitwise_or_.Tensor,
    aten.bitwise_right_shift.Scalar_Tensor,
    aten.bitwise_right_shift.Tensor,
    aten.bitwise_right_shift.Tensor_Scalar,
    aten.bitwise_right_shift.Tensor_Scalar_out,
    aten.bitwise_right_shift.Tensor_out,
    aten.bitwise_right_shift_.Tensor,
    aten.bitwise_right_shift_.Tensor_Scalar,
    aten.bitwise_xor.Scalar,
    aten.bitwise_xor.Scalar_Tensor,
    aten.bitwise_xor.Scalar_out,
    aten.bitwise_xor.Tensor,
    aten.bitwise_xor.Tensor_out,
    aten.bitwise_xor_.Scalar,
    aten.bitwise_xor_.Tensor,
    aten.ceil.default,
    aten.ceil.out,
    aten.ceil_.default,
    aten.clamp.default,
    aten.clamp.out,
    aten.clamp_.default,
    aten.clip.default,
    aten.clip.out,
    aten.clip_.default,
    aten.conj_physical.default,
    aten.conj_physical.out,
    aten.conj_physical_.default,
    aten.copysign.Scalar,
    aten.copysign.Scalar_out,
    aten.copysign.Tensor,
    aten.copysign.out,
    aten.copysign_.Scalar,
    aten.copysign_.Tensor,
    aten.cos.default,
    aten.cos.out,
    aten.cos_.default,
    aten.cosh.default,
    aten.cosh.out,
    aten.cosh_.default,
    aten.deg2rad.default,
    aten.deg2rad.out,
    aten.deg2rad_.default,
    aten.digamma.default,
    aten.digamma.out,
    aten.digamma_.default,
    aten.div.Tensor,
    aten.div.Tensor_mode,
    aten.div.out,
    aten.div.out_mode,
    aten.div_.Tensor,
    aten.div_.Tensor_mode,
    aten.eq.Tensor,
    aten.eq.Tensor_out,
    aten.eq.Scalar,
    aten.eq.Scalar_out,
    aten.erf.default,
    aten.erf.out,
    aten.erf_.default,
    aten.erfc.default,
    aten.erfc.out,
    aten.erfc_.default,
    aten.erfinv.default,
    aten.erfinv.out,
    aten.erfinv_.default,
    aten.exp.default,
    aten.exp.out,
    aten.exp2.default,
    aten.exp2.out,
    aten.exp2_.default,
    aten.exp_.default,
    aten.expm1.default,
    aten.expm1.out,
    aten.expm1_.default,
    aten.float_power.Scalar,
    aten.float_power.Scalar_out,
    aten.float_power.Tensor_Scalar,
    aten.float_power.Tensor_Scalar_out,
    aten.float_power.Tensor_Tensor,
    aten.float_power.Tensor_Tensor_out,
    aten.float_power_.Scalar,
    aten.float_power_.Tensor,
    aten.floor.default,
    aten.floor.out,
    aten.floor_.default,
    aten.fmod.Scalar,
    aten.fmod.Scalar_out,
    aten.fmod.Tensor,
    aten.fmod.Tensor_out,
    aten.fmod_.Scalar,
    aten.fmod_.Tensor,
    aten.frac.default,
    aten.frac.out,
    aten.frac_.default,
    aten.ge.Scalar,
    aten.ge.Tensor,
    aten.gelu.default,
    aten.gt.Tensor,
    aten.gt.Tensor_out,
    aten.gt.Scalar,
    aten.gt.Scalar_out,
    aten.hypot.default,
    aten.hypot.out,
    aten.hypot_.default,
    aten.i0.default,
    aten.i0.out,
    aten.i0_.default,
    aten.igamma.default,
    aten.igamma.out,
    aten.igamma_.default,
    aten.igammac.default,
    aten.igammac.out,
    aten.igammac_.default,
    aten.isinf.default,
    aten.isnan.default,
    aten.isneginf.default,
    aten.isneginf.out,
    aten.isposinf.default,
    aten.isposinf.out,
    aten.ldexp.default,
    aten.ldexp.out,
    aten.ldexp_.default,
    aten.lt.Tensor,
    aten.lt.Tensor_out,
    aten.lt.Scalar,
    aten.lt.Scalar_out,
    aten.le.Scalar,
    aten.le.Tensor,
    aten.lerp.Scalar,
    aten.lerp.Scalar_out,
    aten.lerp.Tensor,
    aten.lerp.Tensor_out,
    aten.lerp_.Scalar,
    aten.lerp_.Tensor,
    aten.lgamma.default,
    aten.lgamma.out,
    aten.lgamma_.default,
    aten.log.default,
    aten.log.out,
    aten.log10.default,
    aten.log10.out,
    aten.log10_.default,
    aten.log1p.default,
    aten.log1p.out,
    aten.log1p_.default,
    aten.log2.default,
    aten.log2.out,
    aten.log2_.default,
    aten.log_.default,
    aten.logaddexp.default,
    aten.logaddexp.out,
    aten.logaddexp2.default,
    aten.logaddexp2.out,
    aten.logical_and.default,
    aten.logical_and.out,
    aten.logical_and_.default,
    aten.logical_not.default,
    aten.logical_not.out,
    aten.logical_not_.default,
    aten.logical_or.default,
    aten.logical_or.out,
    aten.logical_or_.default,
    aten.logical_xor.default,
    aten.logical_xor.out,
    aten.logical_xor_.default,
    aten.logit.default,
    aten.logit.out,
    aten.logit_.default,
    aten.masked_fill.Scalar,
    aten.maximum.out,
    aten.mul.Scalar,
    aten.mul.Tensor,
    aten.mul.out,
    aten.mul_.Scalar,
    aten.mul_.Tensor,
    aten.mvlgamma.default,
    aten.mvlgamma.out,
    aten.mvlgamma_.default,
    aten.native_dropout_backward.default,
    aten.native_dropout_backward.out,
    aten.nan_to_num.default,
    aten.nan_to_num.out,
    aten.nan_to_num_.default,
    aten.ne.Scalar,
    aten.neg.default,
    aten.neg.out,
    aten.neg_.default,
    aten.nextafter.default,
    aten.nextafter.out,
    aten.nextafter_.default,
    aten.polygamma.default,
    aten.polygamma.out,
    aten.polygamma_.default,
    aten.positive.default,
    aten.pow.Scalar,
    aten.pow.Scalar_out,
    aten.pow.Tensor_Scalar,
    aten.pow.Tensor_Scalar_out,
    aten.pow.Tensor_Tensor,
    aten.pow.Tensor_Tensor_out,
    aten.pow_.Scalar,
    aten.pow_.Tensor,
    aten.reciprocal.default,
    aten.reciprocal.out,
    aten.reciprocal_.default,
    aten.rad2deg.default,
    aten.rad2deg.out,
    aten.rad2deg_.default,
    aten.relu.default,
    aten.relu_.default,
    aten.remainder.Scalar,
    aten.remainder.Scalar_Tensor,
    aten.remainder.Scalar_out,
    aten.remainder.Tensor,
    aten.remainder.Tensor_out,
    aten.remainder_.Scalar,
    aten.remainder_.Tensor,
    aten.round.decimals,
    aten.round.decimals_out,
    aten.round.default,
    aten.round.out,
    aten.round_.decimals,
    aten.round_.default,
    aten.rsqrt.default,
    aten.rsqrt.out,
    aten.rsqrt_.default,
    aten.rsub.Scalar,
    aten.sgn.default,
    aten.sgn.out,
    aten.sgn_.default,
    aten.sigmoid.default,
    aten.sigmoid.out,
    aten.sigmoid_.default,
    aten.sign.default,
    aten.sign.out,
    aten.sign_.default,
    aten.signbit.default,
    aten.signbit.out,
    aten.silu.default,
    aten.silu.out,
    aten.sin.default,
    aten.sin.out,
    aten.sin_.default,
    aten.sinc.default,
    aten.sinc.out,
    aten.sinc_.default,
    aten.sinh.default,
    aten.sinh.out,
    aten.sinh_.default,
    aten.sqrt.default,
    aten.sqrt.out,
    aten.sqrt_.default,
    aten.square.default,
    aten.square.out,
    aten.square_.default,
    aten.sub.Scalar,
    aten.sub.Tensor,
    aten.sub.out,
    aten.sub_.Scalar,
    aten.sub_.Tensor,
    aten.tan.default,
    aten.tan.out,
    aten.tan_.default,
    aten.tanh.default,
    aten.tanh.out,
    aten.tanh_.default,
    aten.true_divide.Tensor,
    aten.trunc.default,
    aten.trunc.out,
    aten.trunc_.default,
    aten.where.self,
    aten.where.self_out,
    aten.xlogy.OutScalar_Self,
    aten.xlogy.OutScalar_Other,
    aten.xlogy.OutTensor,
    aten.xlogy.Scalar_Other,
    aten.xlogy.Scalar_Self,
    aten.xlogy.Tensor,
    aten.xlogy_.Scalar_Other,
    aten.xlogy_.Tensor,
    # backward point-wise ops
    # please keep the entries below alphabetically sorted
    aten.gelu_backward.default,
    aten.sigmoid_backward.default,
    aten.silu_backward.default,
    aten.tanh_backward.default,
    aten.threshold_backward.default,
]


def pointwise_strategy(
    mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False
) -> OpStrategy:
    max_shards_strategy_index = -1
    max_shards = -1

    if _is_inplace_op(op_schema.op):
        # inplace op should follow the first arg strategy
        followed_strategy = op_schema.args_schema[0]
    elif _is_out_variant_op(op_schema.op):
        # out variant op should follow the out kwarg strategy
        followed_strategy = op_schema.kwargs_schema["out"]
    else:
        # normal pointwise op, we choose to follow the arg with
        # the max shards in case the other operands need to be resharded
        for idx, arg_strategy in enumerate(op_schema.args_schema):
            if not isinstance(arg_strategy, OpStrategy):
                continue

            arg_max_shards = arg_strategy.max_num_shards()
            if arg_max_shards > max_shards:
                max_shards_strategy_index = idx
                max_shards = arg_max_shards

        followed_strategy = op_schema.args_schema[max_shards_strategy_index]

    assert isinstance(
        followed_strategy, OpStrategy
    ), f"no strategy to follow for {op_schema}!"
    return common_pointwise_strategy(
        mesh, op_schema.args_schema, followed_strategy, linearity
    )


def common_pointwise_strategy(
    mesh: DeviceMesh,
    args_schema: Sequence[object],
    followed_strategy: OpStrategy,
    linearity: bool,
) -> OpStrategy:
    # handle broadcasting
    common_shape = torch.broadcast_shapes(
        *[arg.shape for arg in args_schema if isinstance(arg, OpStrategy)]
    )
    pointwise_strategy = OpStrategy([])

    for placement_strategy in followed_strategy.strategies:
        spec_to_follow = placement_strategy.output_spec
        out_placements: List[Placement] = []
        for placement in spec_to_follow.placements:
            if isinstance(placement, Shard):
                shard_dim = normalize_dim(placement.dim, len(spec_to_follow.shape))
                common_ndim = len(common_shape)
                new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim
                out_placements.append(Shard(new_shard_dim))
            elif isinstance(placement, Partial) and not linearity:
                # clear the partial placement if the op does not support linearity;
                # by default we just replicate the partial, need to see if this
                # is optimal for all cases
                out_placements.append(Replicate())
            else:
                out_placements.append(placement)

        input_specs: List[DTensorSpec] = []
        redistribute_costs: List[List[float]] = []
        for input_arg in args_schema:
            if isinstance(input_arg, OpStrategy):
                # every arg follows the out_placements, but we need to handle
                # broadcasting here as well
                input_arg_spec = input_arg.strategies[0].output_spec
                input_arg_dims_map = infer_broadcast_dims_map(
                    common_shape, input_arg_spec.shape
                )
                input_target_placements = map_placements_after_broadcast(
                    tuple(out_placements),
                    common_shape,
                    input_arg_dims_map,
                )
                input_arg_target_spec = DTensorSpec(
                    mesh=mesh,
                    placements=input_target_placements,
                    tensor_meta=input_arg_spec.tensor_meta,
                )
                input_specs.append(input_arg_target_spec)
                redistribute_costs.append(
                    generate_redistribute_costs(input_arg, input_arg_target_spec)
                )

        pointwise_strategy.strategies.append(
            PlacementStrategy(
                output_specs=DTensorSpec(
                    mesh=mesh,
                    placements=tuple(out_placements),
                ),
                input_specs=input_specs,
                redistribute_cost=redistribute_costs,
            )
        )
    return pointwise_strategy
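

# The following is an illustrative sketch (not used by this module) of how the
# strategies above behave through the public DTensor API. It assumes a process
# group is already initialized and `mesh` is a 1-D DeviceMesh; the helper name
# `_example_pointwise_broadcast` is hypothetical and exists only as
# documentation.
def _example_pointwise_broadcast(mesh: DeviceMesh) -> None:
    # imported locally to avoid a circular import at module import time
    from torch.distributed.tensor import distribute_tensor

    # `a` is sharded on its only dim; `b` is replicated and has one extra
    # leading dim, so the two operands broadcast to the common shape (2, 8).
    a = distribute_tensor(torch.randn(8), mesh, [Shard(0)])
    b = distribute_tensor(torch.randn(2, 8), mesh, [Replicate()])

    # pointwise_strategy follows `a` (the operand with the most shards), and
    # common_pointwise_strategy remaps its Shard(0) to the matching dim of the
    # broadcast shape, so the output ends up sharded on dim 1.
    c = a + b
    assert c.placements == (Shard(1),)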


def linear_pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
    """
    Linear pointwise operators can propagate pending reductions.
    For example, c = add(a, b); if a is pending sum, then c will be
    pending sum as well without any communication overhead.
    """
    return pointwise_strategy(mesh, op_schema, linearity=True)
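

# Another illustrative sketch (not used by this module): the linearity property
# documented above, shown through the public DTensor API. It assumes `mesh` is
# a 1-D DeviceMesh and that each rank holds a partial contribution
# `partial_chunk` of some logical tensor; the helper name
# `_example_linear_pointwise` is hypothetical.
def _example_linear_pointwise(mesh: DeviceMesh, partial_chunk: torch.Tensor) -> None:
    # imported locally to avoid a circular import at module import time
    from torch.distributed.tensor import DTensor

    # `a` carries a pending sum: the logical tensor is the sum of all local chunks.
    a = DTensor.from_local(partial_chunk, mesh, [Partial()])

    # aten.add.Tensor is registered with linear_pointwise_strategy, so the
    # Partial placement propagates to the output and no reduction is issued here.
    c = a + a
    assert c.placements == (Partial(),)

    # By contrast, aten.mul.Tensor is registered as a plain pointwise op
    # (linearity=False), so the pending sum is reduced to Replicate first.
    d = a * a
    assert d.placements == (Replicate(),)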


for op in linear_pointwise_ops:
    register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
        linear_pointwise_strategy
    )

for op in pointwise_ops:
    register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
        pointwise_strategy
    )


# TODO: add all for_each ops
for_each_ops = [
    aten._foreach_abs.default,
    aten._foreach_abs_.default,
    aten._foreach_addcdiv_.Scalar,
    aten._foreach_addcdiv_.ScalarList,
    aten._foreach_addcdiv_.Tensor,
    aten._foreach_addcmul.Scalar,
    aten._foreach_addcmul_.Scalar,
    aten._foreach_addcmul_.ScalarList,
    aten._foreach_addcmul_.Tensor,
    aten._foreach_clamp_max_.Scalar,
    aten._foreach_clamp_min_.Scalar,
    aten._foreach_div_.List,
    aten._foreach_div_.Scalar,
    aten._foreach_div_.ScalarList,
    aten._foreach_div_.Tensor,
    aten._foreach_div.List,
    aten._foreach_div.Scalar,
    aten._foreach_div.ScalarList,
    aten._foreach_div.Tensor,
    aten._foreach_lerp_.Scalar,
    aten._foreach_maximum_.List,
    aten._foreach_mul.Scalar,
    aten._foreach_mul.ScalarList,
    aten._foreach_mul.Tensor,
    aten._foreach_mul.List,
    aten._foreach_mul_.Scalar,
    aten._foreach_mul_.ScalarList,
    aten._foreach_mul_.Tensor,
    aten._foreach_mul_.List,
    aten._foreach_neg.default,
    aten._foreach_neg_.default,
    aten._foreach_reciprocal_.default,
    aten._foreach_sub.Scalar,
    aten._foreach_sub_.Scalar,
    aten._foreach_sub.List,
    aten._foreach_sub_.List,
    aten._foreach_sub.ScalarList,
    aten._foreach_sub_.ScalarList,
    aten._foreach_sqrt.default,
    aten._foreach_sqrt_.default,
    aten._foreach_zero_.default,
    aten._foreach_exp.default,
    aten._foreach_exp_.default,
    aten._foreach_cos.default,
    aten._foreach_cos_.default,
    aten._foreach_log.default,
    aten._foreach_log_.default,
    aten._amp_foreach_non_finite_check_and_unscale_.default,
]

for_each_linearity_ops = [
    aten._foreach_add.Scalar,
    aten._foreach_add_.Scalar,
    aten._foreach_add_.ScalarList,
    aten._foreach_add.List,
    aten._foreach_add_.List,
]


def list_pointwise_strategy(
    mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False
) -> StrategyType:
    """
    Apply the pointwise strategy to the zipped arguments. For example, if we
    run a foreach add of two lists l1 and l2, then we apply the pointwise
    strategy on each pair (l1[i], l2[i]). If the first argument is a list but
    the second (or a later) one is a tensor, then we broadcast the tensor by
    replicating it into a list with the length of the first argument. See the
    illustrative sketch at the end of this file.

    Args:
        mesh (DeviceMesh): device mesh for pointwise ops
        op_schema (OpSchema): schema of the operator to generate strategies for
        linearity (bool): specifies whether op(a) + op(b) = op(a + b)

    Returns:
        TupleStrategy: the generated strategy, one OpStrategy per list element
    """

    def args_tuple_strategies(args_schema: Tuple[object, ...]) -> List[TupleStrategy]:
        first_arg = args_schema[0]
        assert isinstance(first_arg, TupleStrategy)
        strategy_len = len(first_arg.childs)
        tuple_strategies: List[TupleStrategy] = []
        for arg_idx, arg in enumerate(args_schema):
            if isinstance(arg, TupleStrategy):
                # every tuple strategy should have the same length
                assert len(arg.childs) == strategy_len
                tuple_strategies.append(arg)
            elif isinstance(arg, OpStrategy):
                if arg_idx > 0:  # implicitly broadcast
                    tuple_strategies.append(
                        TupleStrategy([arg for _ in range(strategy_len)])
                    )
                else:
                    raise RuntimeError(
                        f"list op only supports tuple strategy! {op_schema}"
                    )
        return tuple_strategies

    args_strategies = args_tuple_strategies(op_schema.args_schema)
    follow_strategy: TupleStrategy = args_strategies[0]
    list_strategy: List[OpStrategy] = []
    for child_idx, child_strategy in enumerate(follow_strategy.childs):
        assert isinstance(child_strategy, OpStrategy)
        args_schema: List[StrategyType] = [
            arg_strategy.childs[child_idx] for arg_strategy in args_strategies
        ]
        pointwise_strategy: OpStrategy = common_pointwise_strategy(
            mesh, args_schema, child_strategy, linearity
        )
        list_strategy.append(pointwise_strategy)
    return TupleStrategy(list_strategy)


def list_linear_pointwise_strategy(
    mesh: DeviceMesh, op_schema: OpSchema
) -> StrategyType:
    """
    For-each list op strategy that supports linearity.
    """
    return list_pointwise_strategy(mesh, op_schema, linearity=True)


for op in for_each_ops:
    register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(
        list_pointwise_strategy
    )

for op in for_each_linearity_ops:
    register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(
        list_linear_pointwise_strategy
    )

fused_ops = [
    aten._fused_adam_.default,
    aten._fused_adam.default,
    aten._fused_adam.tensor_lr,
    aten._fused_adam_.tensor_lr,
    aten._fused_adamw_.default,
    aten._fused_adamw.default,
    aten._fused_adamw.tensor_lr,
    aten._fused_adamw_.tensor_lr,
]

for op in fused_ops:
    register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(
        list_pointwise_strategy
    )
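

# A final illustrative sketch (not used by this module) of the list/foreach
# strategies registered above: a foreach op applied to a list of DTensors has
# the pointwise strategy applied per list element, so each element keeps its
# own placement. It assumes `mesh` is a 1-D DeviceMesh; the helper name
# `_example_foreach_pointwise` is hypothetical.
def _example_foreach_pointwise(mesh: DeviceMesh) -> None:
    # imported locally to avoid a circular import at module import time
    from torch.distributed.tensor import distribute_tensor

    params = [
        distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)]),
        distribute_tensor(torch.randn(4, 4), mesh, [Replicate()]),
    ]

    # aten._foreach_mul_.Scalar is in for_each_ops, so list_pointwise_strategy
    # runs once per list element and each DTensor keeps its original placement.
    torch._foreach_mul_(params, 2.0)
    assert params[0].placements == (Shard(0),)
    assert params[1].placements == (Replicate(),)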