1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <numeric>
21 #include <set>
22 #include <string>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 #include "absl/strings/str_join.h"
29 #include "absl/strings/string_view.h"
30 #include "tensorflow/compiler/xla/permutation_util.h"
31 #include "tensorflow/compiler/xla/primitive_util.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/compiler/xla/window_util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/math/math_util.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/protobuf.h"
42 #include "tensorflow/core/platform/statusor.h"
43 
44 namespace xla {
45 namespace {
46 
47 using absl::StrFormat;
48 using absl::StrJoin;
49 
50 // Returns true if no element is present in slice more than once.
51 bool AllUnique(absl::Span<const int64_t> slice) {
52   return std::set<int64_t>(slice.begin(), slice.end()).size() == slice.size();
53 }
54 
55 Status ExpectArray(const Shape& shape, absl::string_view op_type) {
56   if (!shape.IsArray()) {
57     return InvalidArgument("Expected array argument for %s, but got %s.",
58                            std::string(op_type), ShapeUtil::HumanString(shape));
59   }
60   return OkStatus();
61 }
62 
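// Checks that a reducer computation has the expected signature: for N reduced
// operands it must take 2*N parameters (the N accumulators followed by the N
// input elements) and return N scalars (as a tuple when N > 1). For example,
// reducing two arrays with element types F32 and S32 requires a reducer of
// shape (f32[], s32[], f32[], s32[]) -> (f32[], s32[]).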
63 Status VerifyReducerShape(const ProgramShape& reducer_shape,
64                           absl::Span<const Shape* const> init_value_shapes,
65                           absl::Span<const PrimitiveType> input_element_types,
66                           int64_t inputs) {
67   if (reducer_shape.parameters_size() != inputs * 2) {
68     return InvalidArgument(
69         "Reduction function must take %d parameters, but "
70         "takes %d parameter(s).",
71         inputs * 2, reducer_shape.parameters_size());
72   }
73 
74   const Shape& accumulator_shape = reducer_shape.result();
75   std::vector<const Shape*> accumulator_subshapes;
76   if (accumulator_shape.IsArray()) {
77     if (inputs != 1) {
78       return InvalidArgument(
79           "Reduction function must produce a tuple with %d elements, but "
80           "produces a scalar",
81           inputs);
82     }
83     accumulator_subshapes.push_back(&accumulator_shape);
84   } else if (accumulator_shape.IsTuple()) {
85     if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) {
86       return InvalidArgument(
87           "Reduction function must produce a tuple with %d elements, but has "
88           "%d elements",
89           inputs, ShapeUtil::TupleElementCount(accumulator_shape));
90     }
91     for (const Shape& element_shape : accumulator_shape.tuple_shapes()) {
92       accumulator_subshapes.push_back(&element_shape);
93     }
94   } else {
95     return InvalidArgument(
96         "Reduction function must produce a scalar or tuple of scalars, but has "
97         "shape: %s",
98         ShapeUtil::HumanString(accumulator_shape));
99   }
100 
101   for (const Shape* element_shape : accumulator_subshapes) {
102     if (element_shape->rank() != 0) {
103       return InvalidArgument(
104           "Reduction function must return a scalar or tuple of scalars but "
105           "returns shape: %s",
106           ShapeUtil::HumanString(accumulator_shape));
107     }
108   }
109 
110   for (int64_t i = 0; i < inputs; ++i) {
111     // Check that the accumulator can be passed in as the first argument.
112     // Note: comparing here and below with Compatible since we don't care about
113     // layout in scalars - see b/26668201 for a longer-term vision.
114     if (!ShapeUtil::Compatible(*accumulator_subshapes[i],
115                                reducer_shape.parameters(i))) {
116       return InvalidArgument(
117           "Reduction function's %d-th parameter shape differs from the "
118           "result shape: %s vs %s",
119           i, ShapeUtil::HumanString(reducer_shape.parameters(i)),
120           ShapeUtil::HumanString(*accumulator_subshapes[i]));
121     }
122     // Check that init_value's shapes are suitable for reducer_shape.
123     if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i],
124                                                   *init_value_shapes[i])) {
125       return InvalidArgument(
126           "Reduction function's accumulator shape at index %d differs from "
127           "the init_value shape: %s vs %s",
128           i, ShapeUtil::HumanString(*accumulator_subshapes[i]),
129           ShapeUtil::HumanString(*init_value_shapes[i]));
130     }
131     // Check that the inputs can be passed in as the non-accumulator arguments.
132     const Shape input_element_shape =
133         ShapeUtil::MakeShape(input_element_types[i], {});
134     if (!ShapeUtil::CompatibleIgnoringFpPrecision(
135             input_element_shape, reducer_shape.parameters(inputs + i))) {
136       return InvalidArgument(
137           "Reduction function's %d-th parameter shape differs from the "
138           "input element type: %s vs %s",
139           inputs + i,
140           ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
141           ShapeUtil::HumanString(input_element_shape));
142     }
143     // Check that the accumulator and inputs to the reducer function match.
144     // If the accumulator is scalar, it must have the same type as the inputs
145     // (up to fp precision). If it is a tuple, then the k-th element of the
146     // tuple must have the same type as the k-th input (again, up to fp
147     // precision).
148     if (!ShapeUtil::CompatibleIgnoringFpPrecision(
149             *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) {
150       return InvalidArgument(
151           "Reduction function's %d-th parameter shape must "
152           "match the result shape, but got %s vs %s.",
153           inputs + i,
154           ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
155           ShapeUtil::HumanString(*accumulator_subshapes[i]));
156     }
157   }
158 
159   return OkStatus();
160 }
161 
162 StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
163                                        const Window& window,
164                                        PrimitiveType element_type) {
165   if (window.dimensions_size() != base_shape.rank()) {
166     return InvalidArgument(
167         "Window has dimension %d but base shape has dimension %d.",
168         window.dimensions_size(), base_shape.rank());
169   }
170 
171   std::vector<int64_t> output_dimensions(window.dimensions_size());
172   std::vector<bool> output_is_dynamic(window.dimensions_size());
173   for (int64_t i = 0; i < window.dimensions_size(); ++i) {
174     const auto& dim = window.dimensions(i);
175     if (dim.size() <= 0) {
176       return InvalidArgument("Window %s has a non-positive dimension.",
177                              window.DebugString());
178     }
179     if (dim.stride() <= 0) {
180       return InvalidArgument("Window %s has a non-positive stride.",
181                              window.DebugString());
182     }
183     if (dim.base_dilation() < 1) {
184       return InvalidArgument(
185           "Window %s has a non-positive base area dilation factor.",
186           window.DebugString());
187     }
188     if (dim.window_dilation() < 1) {
189       return InvalidArgument(
190           "Window %s has a non-positive window dilation factor.",
191           window.DebugString());
192     }
193 
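    // The output size along this dimension is the number of valid window
    // positions over the dilated, padded base area. For example, a base size
    // of 10 with window size 3, stride 2, and no padding or dilation yields
    // (10 - 3) / 2 + 1 = 4 output elements.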
194     const int64_t dilated_base = window_util::DilatedBound(
195         ShapeUtil::GetDimension(base_shape, i), dim.base_dilation());
196     const int64_t padded_dilated_base =
197         dim.padding_low() + dilated_base + dim.padding_high();
198     const int64_t dilated_window =
199         window_util::DilatedBound(dim.size(), dim.window_dilation());
200 
201     output_dimensions[i] = window_util::StridedBound(
202         padded_dilated_base, dilated_window, dim.stride());
203     output_is_dynamic[i] = base_shape.is_dynamic_dimension(i);
204   }
205 
206   return ShapeUtil::MakeValidatedShape(element_type, output_dimensions,
207                                        output_is_dynamic);
208 }
209 
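// Returns `preferred_element_type` if it is set; otherwise returns
// `from_type`. Narrowing is only rejected for non-floating-point types, so
// for example upcasting S8 to S32 succeeds, S32 to S8 is an error, and F32 to
// BF16 is allowed.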
210 StatusOr<PrimitiveType> MaybeUpcast(
211     PrimitiveType from_type,
212     std::optional<PrimitiveType> preferred_element_type) {
213   if (!preferred_element_type.has_value() ||
214       *preferred_element_type == from_type) {
215     return from_type;
216   }
217   if (!primitive_util::IsFloatingPointType(from_type) &&
218       primitive_util::BitWidth(*preferred_element_type) <
219           primitive_util::BitWidth(from_type)) {
220     return InvalidArgument(
221         "`preferred_element_type` must not be narrower than the original "
222         "type.");
223   }
224   return *preferred_element_type;
225 }
226 
227 }  // namespace
228 
229 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
230     HloOpcode opcode, const HloInstruction* operand) {
231   return InferUnaryOpShape(opcode, operand->shape());
232 }
233 
234 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
235     HloOpcode opcode, const Shape& shape) {
236   // There is no copy operation at the proto level, so handle copy explicitly.
237   // A domain shape is the same as the input one.
238   if (opcode == HloOpcode::kCopy || opcode == HloOpcode::kDomain) {
239     return shape;
240   }
241 
242   TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation"));
243 
244   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
245   switch (opcode) {
246     case HloOpcode::kFloor:
247     case HloOpcode::kCeil:
248     case HloOpcode::kRoundNearestAfz:
249     case HloOpcode::kRoundNearestEven:
250       if (!ShapeUtil::ElementIsFloating(shape)) {
251         return InvalidArgument(
252             "Expected element type in shape to be floating for %s operation; "
253             "got %s.",
254             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
255       }
256       return shape;
257     case HloOpcode::kCos:
258     case HloOpcode::kSin:
259     case HloOpcode::kExp:
260     case HloOpcode::kExpm1:
261     case HloOpcode::kLog:
262     case HloOpcode::kLog1p:
263     case HloOpcode::kLogistic:
264     case HloOpcode::kRsqrt:
265     case HloOpcode::kSqrt:
266     case HloOpcode::kCbrt:
267     case HloOpcode::kTanh:
268       if (!ShapeUtil::ElementIsFloating(shape) &&
269           !ShapeUtil::ElementIsComplex(shape)) {
270         return InvalidArgument(
271             "Expected element type in shape to be floating or complex for %s "
272             "operation; got %s.",
273             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
274       }
275       return shape;
276     case HloOpcode::kReal:
277     case HloOpcode::kImag:
278       if (ShapeUtil::ElementIsComplex(shape)) {
279         return ShapeUtil::ComplexComponentShape(shape);
280       } else if (ShapeUtil::ElementIsFloating(shape)) {
281         return shape;
282       } else {
283         return InvalidArgument(
284             "Expected element type in shape to be floating or complex for "
285             "%s operation; got %s.",
286             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
287       }
288     case HloOpcode::kAbs:
289       if (ShapeUtil::ElementIsComplex(shape)) {
290         return ShapeUtil::ChangeElementType(
291             shape, primitive_util::ComplexComponentType(shape.element_type()));
292       } else if (ShapeUtil::ElementIsSigned(shape)) {
293         return shape;
294       } else {
295         return InvalidArgument(
296             "Expected element type in shape to be floating or complex for "
297             "%s operation; got %s.",
298             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
299       }
300     case HloOpcode::kClz:
301       if (!ShapeUtil::ElementIsIntegral(shape)) {
302         return InvalidArgument(
303             "Expected an integral element type in argument to Clz "
304             "operation; got %s.",
305             PrimitiveType_Name(shape.element_type()));
306       }
307       return shape;
308     case HloOpcode::kNegate:
309       if (!ShapeUtil::ElementIsIntegral(shape) &&
310           !ShapeUtil::ElementIsFloating(shape) &&
311           !ShapeUtil::ElementIsComplex(shape)) {
312         return InvalidArgument(
313             "Expected element type in shape to be integral, floating or "
314             "complex for %s operation; got %s.",
315             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
316       }
317       return shape;
318     case HloOpcode::kPopulationCount:
319       if (!ShapeUtil::ElementIsIntegral(shape)) {
320         return InvalidArgument(
321             "Expected an integral element type in argument to PopulationCount "
322             "operation; got %s.",
323             PrimitiveType_Name(shape.element_type()));
324       }
325       return shape;
326     case HloOpcode::kSign:
327       if (!ShapeUtil::ElementIsSigned(shape) &&
328           !ShapeUtil::ElementIsComplex(shape)) {
329         return InvalidArgument(
330             "Expected element type in shape to be signed or complex for "
331             "%s operation; got %s.",
332             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
333       }
334       return shape;
335 
336     case HloOpcode::kNot:
337       if (shape.element_type() != PRED &&
338           !primitive_util::IsIntegralType(shape.element_type())) {
339         return InvalidArgument(
340             "Expected pred or an integral element type in argument to Not "
341             "operation; got %s.",
342             PrimitiveType_Name(shape.element_type()));
343       }
344       return shape;
345 
346     case HloOpcode::kIsFinite:
347       if (!ShapeUtil::ElementIsFloating(shape)) {
348         return InvalidArgument(
349             "Expected element type in shape to be floating "
350             "point for IsFinite "
351             "operation; got %s.",
352             PrimitiveType_Name(shape.element_type()));
353       }
354       return ShapeUtil::ChangeElementType(shape, PRED);
355 
356     default:
357       return InvalidArgument(
358           "Unknown operation for unary shape inference: \"%s\".",
359           HloOpcodeString(opcode));
360   }
361 }
362 
363 /* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
364     absl::Span<const Shape* const> arg_shapes, const int64_t dimension) {
365   if (arg_shapes.empty()) {
366     return InvalidArgument("Concatenate expects at least one argument.");
367   }
368   if (dimension < 0 || dimension >= arg_shapes[0]->rank()) {
369     return InvalidArgument("Concatenate dimension out of bounds: %d.",
370                            dimension);
371   }
372   const Shape* arg_shape = nullptr;
373   PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
374   for (const Shape* shape : arg_shapes) {
375     TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation"));
376     if (!arg_shape) {
377       arg_shape = shape;
378       element_type = arg_shape->element_type();
379       continue;
380     }
381     if (arg_shape->rank() != shape->rank()) {
382       return InvalidArgument(
383           "Cannot concatenate arrays with different ranks: %d (%s) vs %d "
384           "(%s).",
385           arg_shape->rank(), ShapeUtil::HumanString(*arg_shape), shape->rank(),
386           ShapeUtil::HumanString(*shape));
387     }
388     if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
389       return InvalidArgument(
390           "Cannot concatenate arrays with different element types: %s vs %s.",
391           PrimitiveType_Name(arg_shape->element_type()),
392           PrimitiveType_Name(shape->element_type()));
393     }
394     for (int64_t dimension_number = 0; dimension_number < arg_shape->rank();
395          ++dimension_number) {
396       if (arg_shape->dimensions(dimension_number) !=
397           shape->dimensions(dimension_number)) {
398         if (dimension_number == dimension) {
399           continue;  // It's okay to differ in the dimension we're
400                      // concatenating.
401         }
402         return InvalidArgument(
403             "Cannot concatenate arrays that differ in dimensions other than "
404             "the one being concatenated. Dimension %d in both shapes must be "
405             "equal: %s vs %s.",
406             dimension_number, ShapeUtil::HumanString(*arg_shape),
407             ShapeUtil::HumanString(*shape));
408       }
409     }
410     element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
411   }
412 
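  // The result shape matches the operands except along the concatenation
  // dimension, whose size is the sum of the operand sizes; e.g. concatenating
  // f32[2,3] and f32[2,5] along dimension 1 yields f32[2,8].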
413   std::vector<int64_t> new_dimensions(arg_shape->dimensions().begin(),
414                                       arg_shape->dimensions().end());
415   for (size_t i = 1; i < arg_shapes.size(); ++i) {
416     new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
417   }
418 
419   Shape result = ShapeUtil::MakeShape(element_type, new_dimensions);
420 
421   // Set dynamic dimensions if any input has dynamic dimension.
422   for (const Shape* shape : arg_shapes) {
423     for (int64_t i = 0; i < shape->dimensions_size(); ++i) {
424       if (shape->is_dynamic_dimension(i)) {
425         result.set_dynamic_dimension(i, true);
426       }
427     }
428   }
429   return result;
430 }
431 
432 /* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
433     const Shape& operand_shape, PrimitiveType new_element_type) {
434   if (!operand_shape.IsArray() ||
435       !primitive_util::IsArrayType(new_element_type)) {
436     // Note: we may want to support tuple conversions via this operation in the
437     // future, by recursing into the tuple elements to check all sub-conversions
438     // are valid. For now we just reject them, though.
439     return InvalidArgument(
440         "Convert does not allow non-arrays, so cannot convert from %s to %s.",
441         ShapeUtil::HumanString(operand_shape),
442         PrimitiveType_Name(new_element_type));
443   }
444 
445   return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
446 }
447 
448 /* static */ StatusOr<Shape> ShapeInference::InferBitcastConvertShape(
449     const Shape& operand_shape, PrimitiveType new_element_type) {
450   auto old_element_type = operand_shape.element_type();
451   if (primitive_util::IsComplexType(old_element_type) !=
452       primitive_util::IsComplexType(new_element_type)) {
453     return InvalidArgument("Conversion between complex and real type %s => %s.",
454                            ShapeUtil::HumanString(operand_shape),
455                            PrimitiveType_Name(new_element_type));
456   }
457   if (!operand_shape.IsArray() ||
458       !primitive_util::IsArrayType(new_element_type)) {
459     // Note: we may want to support tuple conversions via this operation in the
460     // future, by recursing into the tuple elements to check all sub-conversions
461     // are valid. For now we just reject them, though.
462     return InvalidArgument(
463         "Cannot convert from or to tuple type; requested conversion: %s => %s.",
464         ShapeUtil::HumanString(operand_shape),
465         PrimitiveType_Name(new_element_type));
466   }
467 
468   int input_bitwidth = primitive_util::BitWidth(old_element_type);
469   int output_bitwidth = primitive_util::BitWidth(new_element_type);
470   if (std::max(input_bitwidth, output_bitwidth) %
471           std::min(input_bitwidth, output_bitwidth) !=
472       0) {
473     return InvalidArgument(
474         "Cannot bitcast types with indivisible bit-widths: %s => %s.",
475         PrimitiveType_Name(old_element_type),
476         PrimitiveType_Name(new_element_type));
477   }
478   int ratio = std::max(output_bitwidth, input_bitwidth) /
479               std::min(output_bitwidth, input_bitwidth);
480 
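  // For example, bitcast-converting f32[10] to U16 appends a minor dimension
  // of size `ratio` and yields u16[10,2]; converting u16[10,2] back to F32
  // removes that trailing dimension and yields f32[10].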
481   Shape new_shape = operand_shape;
482   new_shape.set_element_type(new_element_type);
483   if (input_bitwidth > output_bitwidth) {
484     ShapeUtil::AppendMinorDimension(ratio, &new_shape);
485   } else if (input_bitwidth < output_bitwidth) {
486     int last_dimension_idx = operand_shape.dimensions_size() - 1;
487     if (operand_shape.dimensions_size() < 1 ||
488         operand_shape.dimensions(last_dimension_idx) != ratio) {
489       return InvalidArgument(
490           "Last dimension of input shape=%d is not equal to ratio of "
491           "bit-widths=%d "
492           "for bitcast-convert from %s to %s",
493           operand_shape.dimensions(last_dimension_idx), ratio,
494           ShapeUtil::HumanString(operand_shape),
495           PrimitiveType_Name(new_element_type));
496     }
497     new_shape.DeleteDimension(last_dimension_idx);
498   }
499   return new_shape;
500 }
501 
502 /* static */ StatusOr<Shape> ShapeInference::InferReducePrecisionShape(
503     const Shape& operand_shape, const int exponent_bits,
504     const int mantissa_bits) {
505   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
506     return InvalidArgument(
507         "Expected element type in shape to be floating point for "
508         "ReducePrecision operation; got %s.",
509         PrimitiveType_Name(operand_shape.element_type()));
510   }
511   if (exponent_bits < 1) {
512     // One exponent bit is necessary to distinguish 0 from infinity.  Having
513     // no exponent bits doesn't produce a sensible number, so we require at
514     // least one.
515     return InvalidArgument("Expected exponent_bits >= 1; got %d.",
516                            exponent_bits);
517   }
518   if (mantissa_bits < 0) {
519     // A number with no mantissa bits is still meaningful, however.
520     return InvalidArgument("Expected non-negative mantissa_bits; got %d.",
521                            mantissa_bits);
522   }
523   return operand_shape;
524 }
525 
526 /* static */ StatusOr<Shape> ShapeInference::InferPadShape(
527     const Shape& operand_shape, const Shape& padding_value_shape,
528     const PaddingConfig& padding_config) {
529   if (!operand_shape.IsArray()) {
530     return InvalidArgument(
531         "Pad operation does not support tuple-shape operands.");
532   }
533   if (!ShapeUtil::IsScalar(padding_value_shape)) {
534     return InvalidArgument(
535         "Pad operation does not support non-scalar padding values.");
536   }
537   if (operand_shape.rank() != padding_config.dimensions_size()) {
538     return InvalidArgument(
539         "The rank of the operand and the padding configuration do not match: "
540         "%s vs %s.",
541         ShapeUtil::HumanString(operand_shape),
542         padding_config.ShortDebugString());
543   }
544   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
545                                                      padding_value_shape)) {
546     return InvalidArgument(
547         "The element types of the operands to Pad do not match.");
548   }
549   if (absl::c_any_of(padding_config.dimensions(),
550                      [](const PaddingConfig::PaddingConfigDimension& p) {
551                        return p.interior_padding() < 0;
552                      })) {
553     return InvalidArgument("Interior padding cannot be negative: %s",
554                            padding_config.ShortDebugString());
555   }
556 
557   if (!padding_value_shape.is_static()) {
558     return InvalidArgument("Dynamic padding value is not supported");
559   }
560 
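  // Each output dimension is the operand dimension plus the low and high edge
  // padding plus the interior padding inserted between adjacent elements. For
  // example, a dimension of size 3 with edge_padding_low=1, edge_padding_high=2
  // and interior_padding=1 produces 3 + 1 + 2 + (3 - 1) * 1 = 8 elements.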
561   std::vector<int64_t> dimensions(operand_shape.rank());
562   std::vector<bool> is_dynamic(operand_shape.rank());
563   for (int64_t i = 0; i < operand_shape.dimensions_size(); ++i) {
564     const auto& p = padding_config.dimensions(i);
565     dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() +
566                     p.edge_padding_high() +
567                     std::max<int64_t>(operand_shape.dimensions(i) - 1, 0LL) *
568                         p.interior_padding();
569     if (dimensions[i] < 0) {
570       return InvalidArgument("Padding results in a negative size for dimension %d",
571                              i);
572     }
573     is_dynamic[i] = operand_shape.is_dynamic_dimension(i);
574   }
575 
576   return ShapeUtil::MakeShape(
577       ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
578       dimensions, is_dynamic);
579 }
580 
581 // Current DotDimensionNumbers Requirements:
582 //
583 // Contracting Dimensions:
584 // *) Same number of contracting dimensions on both lhs and rhs.
585 // *) Contracting dimension size must be the same on both lhs and rhs.
586 //
587 // Batch Dimensions:
588 // *) Same number of batch dimensions on both lhs and rhs.
589 // *) Same batch dimension sizes on both lhs and rhs.
590 //
591 
592 namespace {
593 
594 Status ValidateDotDimensionNumbers(
595     const Shape& lhs, const Shape& rhs,
596     const DotDimensionNumbers& dimension_numbers) {
597   // Check that dimension numbers are in range.
598   auto dims_in_range = [](const int64_t rank,
599                           absl::Span<const int64_t> contracting_dims,
600                           absl::Span<const int64_t> batch_dims) -> bool {
601     auto in_range = [&rank](int64_t i) -> bool { return 0 <= i && i < rank; };
602     return absl::c_all_of(contracting_dims, in_range) &&
603            absl::c_all_of(batch_dims, in_range);
604   };
605 
606   absl::Span<const int64_t> lhs_contracting_dimensions =
607       dimension_numbers.lhs_contracting_dimensions();
608   absl::Span<const int64_t> rhs_contracting_dimensions =
609       dimension_numbers.rhs_contracting_dimensions();
610   absl::Span<const int64_t> lhs_batch_dimensions =
611       dimension_numbers.lhs_batch_dimensions();
612   absl::Span<const int64_t> rhs_batch_dimensions =
613       dimension_numbers.rhs_batch_dimensions();
614 
615   if (!dims_in_range(lhs.rank(), lhs_contracting_dimensions,
616                      lhs_batch_dimensions) ||
617       !dims_in_range(rhs.rank(), rhs_contracting_dimensions,
618                      rhs_batch_dimensions)) {
619     return InvalidArgument("A dimension number is out of range in Dot: %s.",
620                            dimension_numbers.DebugString());
621   }
622 
623   // Check that dimension numbers are unique.
624   auto dims_unique = [](absl::Span<const int64_t> contracting_dims,
625                         absl::Span<const int64_t> batch_dims) -> bool {
626     absl::flat_hash_set<int64_t> dim_set;
627     auto is_unique = [&dim_set](int64_t i) -> bool {
628       return dim_set.insert(i).second;
629     };
630     return absl::c_all_of(contracting_dims, is_unique) &&
631            absl::c_all_of(batch_dims, is_unique);
632   };
633 
634   if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
635       !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) {
636     return InvalidArgument("A dimension number is not unique in Dot: %s.",
637                            dimension_numbers.DebugString());
638   }
639 
640   return OkStatus();
641 }
642 
643 }  // namespace
644 
645 /* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
646     const Shape& lhs, const Shape& rhs,
647     const DotDimensionNumbers& dimension_numbers,
648     std::optional<PrimitiveType> preferred_element_type) {
649   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot"));
650   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
651 
652   auto fail = [lhs, rhs](const std::string& addendum) -> Status {
653     std::string message =
654         StrFormat("Cannot infer shape for dot operation: %s <dot> %s.",
655                   ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
656     if (!addendum.empty()) {
657       message += " " + addendum;
658     }
659     return InvalidArgument("%s", message);
660   };
661 
662   // Validate basic properties of dot dimension numbers.
663   TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers));
664 
665   // Check that number of contracting dimensions match.
666   if (dimension_numbers.lhs_contracting_dimensions_size() !=
667       dimension_numbers.rhs_contracting_dimensions_size()) {
668     return fail(
669         "Must specify the same number of contracting dimensions for lhs and "
670         "rhs.");
671   }
672   // Check that contracting dimension sizes match.
673   for (int64_t i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
674        ++i) {
675     const int64_t lhs_contracting_dimension =
676         dimension_numbers.lhs_contracting_dimensions(i);
677     const int64_t rhs_contracting_dimension =
678         dimension_numbers.rhs_contracting_dimensions(i);
679     if (lhs.dimensions(lhs_contracting_dimension) !=
680         rhs.dimensions(rhs_contracting_dimension)) {
681       return fail("Contracting dimension sizes do not match.");
682     }
683   }
684 
685   // Check that number of batch dimensions match.
686   if (dimension_numbers.lhs_batch_dimensions_size() !=
687       dimension_numbers.rhs_batch_dimensions_size()) {
688     return fail("Must specify the same number of batch dimensions for lhs and rhs.");
689   }
690 
691   // Check that batch dimension numbers and sizes match.
692   for (int64_t i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
693     if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
694         rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
695       return fail("Batch dimension sizes must match for lhs/rhs.");
696     }
697   }
698 
699   // The rank of the result is the sum of the lhs and rhs ranks, minus the
700   // contracted dimensions and one copy of each batch dimension pair. A scalar
701   // input contributes 0 to the rank of the result.
702   // Generate the result dimensions in order: batch dimensions first, then the
703   // lhs and rhs dimensions that are neither contracted nor batch dimensions.
704   std::vector<int64_t> dimensions;
705   std::vector<bool> is_dynamic;
706   const auto& lhs_batch_dimensions = dimension_numbers.lhs_batch_dimensions();
707   const auto lhs_batch_dimensions_size =
708       lhs.rank() - dimension_numbers.lhs_contracting_dimensions().size() +
709       rhs.rank() - dimension_numbers.rhs_contracting_dimensions().size() -
710       dimension_numbers.rhs_batch_dimensions().size();
711   dimensions.reserve(lhs_batch_dimensions_size);
712   is_dynamic.reserve(lhs_batch_dimensions_size);
713   for (const int64_t lhs_dim : lhs_batch_dimensions) {
714     dimensions.push_back(lhs.dimensions(lhs_dim));
715     is_dynamic.push_back(lhs.is_dynamic_dimension(lhs_dim));
716   }
717   for (int64_t i = 0; i < lhs.rank(); i++) {
718     if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
719                                i) &&
720         !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
721       dimensions.push_back(lhs.dimensions(i));
722       is_dynamic.push_back(lhs.is_dynamic_dimension(i));
723     }
724   }
725   for (int64_t i = 0; i < rhs.rank(); i++) {
726     if (!absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
727                                i) &&
728         !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
729       dimensions.push_back(rhs.dimensions(i));
730       is_dynamic.push_back(rhs.is_dynamic_dimension(i));
731     }
732   }
733   TF_ASSIGN_OR_RETURN(
734       PrimitiveType type,
735       MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
736                   preferred_element_type));
737   Shape result = ShapeUtil::MakeShape(type, dimensions, is_dynamic);
738 
739   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
740   VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
741   return result;
742 }
743 
744 /* static */ StatusOr<Shape>
745 ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
746                                                        const Shape& lhs,
747                                                        const Shape& rhs) {
748   TF_RET_CHECK(lhs.rank() == rhs.rank());
749 
750   // The shapes have to be compatible. That is, if some dimension d has a
751   // different size in the two shapes, one of them has to be 1 (a "degenerate"
752   // dimension). In that case, the output shape has the non-1 dimension size
753   // from the lhs/rhs pair in every index.
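  // For example, f32[2,1,5] and f32[2,3,5] broadcast to f32[2,3,5].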
754   std::vector<int64_t> output_dimensions(lhs.rank());
755   std::vector<bool> output_dimensions_is_dynamic(lhs.rank());
756   for (int64_t i = 0; i < lhs.rank(); ++i) {
757     if (lhs.dimensions(i) == rhs.dimensions(i)) {
758       output_dimensions[i] = lhs.dimensions(i);
759     } else if (lhs.dimensions(i) == 1) {
760       output_dimensions[i] = rhs.dimensions(i);
761     } else if (rhs.dimensions(i) == 1) {
762       output_dimensions[i] = lhs.dimensions(i);
763     } else {
764       return InvalidArgument(
765           "Binary op %s with incompatible shapes: %s and %s.",
766           HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
767           ShapeUtil::HumanString(rhs));
768     }
769   }
770 
771   // Merge dynamic dimensions from two shapes.
772   for (int64_t i = 0; i < rhs.rank(); ++i) {
773     if (rhs.is_dynamic_dimension(i) || lhs.is_dynamic_dimension(i)) {
774       output_dimensions_is_dynamic[i] = true;
775     }
776   }
777 
778   return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
779                               output_dimensions, output_dimensions_is_dynamic);
780 }
781 
782 /* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
783     const Shape& smaller_shape, const Shape& larger_shape,
784     absl::Span<const int64_t> broadcast_dimensions) {
785   if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
786     // Reject "magic" inference for binops on different shapes, requiring
787     // the user to provide an explicit broadcast dimension in this case.
788     // See b/25177275 for more details.
789     return InvalidArgument("Shapes must be equal rank, but are %s and %s",
790                            ShapeUtil::HumanString(smaller_shape),
791                            ShapeUtil::HumanString(larger_shape));
792   } else if (broadcast_dimensions.size() != smaller_shape.rank()) {
793     return InvalidArgument(
794         "Size of broadcast_dimensions has to match lower-rank operand's "
795         "rank; "
796         " lower-rank operand's rank is %d, size of broadcast_dimensions is "
797         "%u.",
798         smaller_shape.rank(), broadcast_dimensions.size());
799   }
800 
801   // broadcast_dimensions is a sequence of dimensions; its length is equal to
802   // the rank of the lower-rank operand. The lower-rank operand's dimensions
803   // have to be compatible with the higher-rank operand's dimensions at indices
804   // specified by broadcast_dimensions. Here compatible means the dimension
805   // sizes are equal or in one of the shapes the dimension size is
806   // one. Examples:
807   //
808   // smaller_shape   larger_shape   broadcast_dimensions   output_shape
809   //   []              [2, 3]          {}                    [2, 3]
810   //   [3]             [4, 3]          {1}                   [4, 3]
811   //   [2, 3]          [2, 3, 4]       {0, 1}                [2, 3, 4]
812   //   [2, 1]          [2, 3, 4]       {0, 2}                [2, 3, 1]
813   //   [2, 3]          [2, 1, 4]       {0, 1}                [2, 3, 4]
814   //
815   // The column output_shape may not be the final shape of the XLA
816   // operation. After the "InDim" broadcasting implemented in this function
817   // expands the rank, degenerate-dimension broadcasting (implemented in
818   // InferDegenerateDimensionBroadcastShape) broadcasts dimensions of size one
819   // up to match the dimension size of the other operand. For example, consider
820   // the row in the table above with a smaller_shape of [2, 1]. The shape
821   // returned by this function is [2, 3, 1] (output_shape) however, the result
822   // shape of the XLA operation is [2, 3, 4] after degenerate-dimension
823   // broadcasting.
824   //
825   // Invalid broadcasts:
826   //
827   // smaller_shape=[3], larger_shape=[4, 3], broadcast_dimensions={0}
828   // Reason: Dimension zero** of larger_shape (size 4) is not compatible with
829 //   dimension zero of smaller_shape (size 3). **Zero here comes from the value
830   //   in broadcast_dimensions.
831   //
832   // smaller_shape=[2, 1], larger_shape=[2, 3, 4], broadcast_dimensions={1, 2}
833   // Reason: Dimension one of larger_shape (size 3) is not compatible with
834 //   dimension zero of smaller_shape (size 2).
835 
836   // The output shape is initially the larger_shape. Sizes of dimensions
837   // specified in broadcast_dimensions are then changed to match the
838   // corresponding dimension size in smaller_shape.
839   Shape output_shape(larger_shape);
840   output_shape.set_element_type(
841       ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape));
842 
843   for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
844     int64_t dimension_to_match = broadcast_dimensions.at(i);
845     if (dimension_to_match < 0) {
846       return InvalidArgument(
847           "Broadcast dimension number (%d) cannot be negative.",
848           dimension_to_match);
849     }
850     if (dimension_to_match >= larger_shape.dimensions_size()) {
851       return InvalidArgument(
852           "Broadcast dimension number (%d) too large; higher-rank "
853           "operand has rank %d.",
854           dimension_to_match, larger_shape.dimensions_size());
855     }
856     int64_t small_dimension_size = smaller_shape.dimensions(i);
857     int64_t large_dimension_size = larger_shape.dimensions(dimension_to_match);
858     bool small_is_dynamic = smaller_shape.is_dynamic_dimension(i);
859     bool large_is_dynamic =
860         larger_shape.is_dynamic_dimension(dimension_to_match);
861     // Dimension sizes must be compatible: match or be degenerate (degenerate
862     // case is handled by degenerate dimension broadcasting which occurs after
863     // InDim broadcasting).
864     if (small_dimension_size != large_dimension_size &&
865         small_dimension_size != 1 && large_dimension_size != 1) {
866       return InvalidArgument(
867           "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i,
868           small_dimension_size, large_dimension_size,
869           ShapeUtil::HumanString(smaller_shape),
870           ShapeUtil::HumanString(larger_shape));
871     }
872     if (small_is_dynamic != large_is_dynamic) {
873       if (small_dimension_size == large_dimension_size ||
874           (small_dimension_size == 1 && !small_is_dynamic) ||
875           (large_dimension_size == 1 && !large_is_dynamic)) {
876         // Do nothing. It's OK when the size-1 dimension is not static.
877       } else {
878         return InvalidArgument(
879             "Broadcast dimension %d dynamism mismatch: %s and %s.", i,
880             ShapeUtil::HumanString(smaller_shape),
881             ShapeUtil::HumanString(larger_shape));
882       }
883     }
884     // Make sure the broadcast dimensions are listed in a strictly increasing
885     // order.
886     if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) {
887       return InvalidArgument(
888           "Broadcast dimensions order is wrong: %d comes after %d.",
889           dimension_to_match, broadcast_dimensions.at(i - 1));
890     }
891 
892     output_shape.set_dimensions(dimension_to_match, small_dimension_size);
893     output_shape.set_dynamic_dimension(dimension_to_match, small_is_dynamic);
894   }
895 
896   return output_shape;
897 }
898 
899 /* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
900     HloOpcode operation, const Shape& lhs, const Shape& rhs,
901     absl::Span<const int64_t> broadcast_dimensions) {
902   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
903   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
904 
905   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
906     return InvalidArgument(
907         "Binary op %s with different element types: %s and %s.",
908         HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
909         ShapeUtil::HumanString(rhs));
910   }
911 
912   if (lhs.rank() == rhs.rank()) {
913     std::vector<int64_t> identity_dims(lhs.rank());
914     std::iota(identity_dims.begin(), identity_dims.end(), 0);
915     if (!broadcast_dimensions.empty() &&
916         broadcast_dimensions != identity_dims) {
917       return InvalidArgument(
918           "Broadcast dimensions field must either be not set or be the "
919           "identity on binary operations with operands of the same rank.");
920     }
921   }
922 
923   if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
924     // If the shapes are the same other than layout, the output shape is the
925     // same (elementwise op).
926     Shape result = ShapeUtil::ChangeElementType(
927         lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
928 
929     for (int64_t i = 0; i < rhs.rank(); ++i) {
930       if (rhs.is_dynamic_dimension(i)) {
931         result.set_dynamic_dimension(i, true);
932       }
933     }
934 
935     return result;
936 
937   } else if (lhs.rank() == rhs.rank()) {
938     return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
939   } else {
940     // Ranks do not match, so perform InDim broadcasting using
941     // broadcast_dimensions. Scalar broadcasting is a special case of this.
942     const Shape& larger_shape = lhs.rank() > rhs.rank() ? lhs : rhs;
943     const Shape& smaller_shape = lhs.rank() > rhs.rank() ? rhs : lhs;
944 
945     // After InDim broadcasting, perform degenerate dimensions broadcasting.
946     TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape,
947                         InferInDimBroadcastShape(smaller_shape, larger_shape,
948                                                  broadcast_dimensions));
949 
950     return InferDegenerateDimensionBroadcastShape(
951         operation, indim_broadcast_shape, larger_shape);
952   }
953 }
954 
955 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
956     HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) {
957   return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(),
958                             /*broadcast_dimensions=*/{});
959 }
960 
961 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
962     HloOpcode opcode, const Shape& lhs, const Shape& rhs,
963     absl::Span<const int64_t> broadcast_dimensions) {
964   VLOG(2) << StrFormat(
965       "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
966       HloOpcodeString(opcode), ShapeUtil::HumanStringWithLayout(lhs),
967       ShapeUtil::HumanStringWithLayout(rhs),
968       StrJoin(broadcast_dimensions, ", "));
969 
970   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
971   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
972 
973   TF_RETURN_IF_ERROR(ExpectArray(
974       lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode))));
975   TF_RETURN_IF_ERROR(ExpectArray(
976       rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode))));
977   switch (opcode) {
978     case HloOpcode::kAdd:
979     case HloOpcode::kMaximum:
980     case HloOpcode::kMinimum:
981     case HloOpcode::kMultiply:
982       return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
983                                            broadcast_dimensions);
984 
985     case HloOpcode::kSubtract:
986     case HloOpcode::kAtan2:
987     case HloOpcode::kPower:
988     case HloOpcode::kDivide:
989     case HloOpcode::kRemainder:
990     case HloOpcode::kShiftLeft:
991     case HloOpcode::kShiftRightArithmetic:
992     case HloOpcode::kShiftRightLogical:
993       if (lhs.element_type() == PRED || rhs.element_type() == PRED) {
994         return InvalidArgument(
995             "Expected element type in shape to be arithmetic type for "
996             "operation %s; got PRED.",
997             HloOpcodeString(opcode));
998       }
999       return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
1000                                            broadcast_dimensions);
1001 
1002     case HloOpcode::kComplex: {
1003       if (!ShapeUtil::ElementIsFloating(lhs)) {
1004         return InvalidArgument(
1005             "Expected element type in shape to be floating for complex compose "
1006             "operation; got %s.",
1007             PrimitiveType_Name(lhs.element_type()));
1008       }
1009       TF_ASSIGN_OR_RETURN(const Shape& shape,
1010                           InferElementwiseBinaryOpShape(opcode, lhs, rhs,
1011                                                         broadcast_dimensions));
1012       if (lhs.element_type() == F32 && rhs.element_type() == F32) {
1013         return ShapeUtil::ChangeElementType(shape, C64);
1014       } else if (lhs.element_type() == F64 && rhs.element_type() == F64) {
1015         return ShapeUtil::ChangeElementType(shape, C128);
1016       } else {
1017         return Unimplemented("Complex component type is not implemented.");
1018       }
1019     }
1020     case HloOpcode::kAnd:
1021     case HloOpcode::kOr:
1022     case HloOpcode::kXor:
1023       if (lhs.element_type() != PRED &&
1024           !primitive_util::IsIntegralType(lhs.element_type())) {
1025         return InvalidArgument(
1026             "Expected pred or integral type in argument to and/or operation; "
1027             "got %s.",
1028             PrimitiveType_Name(lhs.element_type()));
1029       }
1030       return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
1031                                            broadcast_dimensions);
1032     case HloOpcode::kCompare: {
1033       TF_ASSIGN_OR_RETURN(const Shape& shape,
1034                           InferElementwiseBinaryOpShape(opcode, lhs, rhs,
1035                                                         broadcast_dimensions));
1036       return ShapeUtil::ChangeElementType(shape, PRED);
1037     }
1038     default:
1039       return Unimplemented(
1040           "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.",
1041           HloOpcodeString(opcode), lhs.ShortDebugString(),
1042           rhs.ShortDebugString());
1043   }
1044 }
1045 
1046 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1047     HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs,
1048     const HloInstruction* ehs) {
1049   return InferTernaryOpShape(opcode, lhs->shape(), rhs->shape(), ehs->shape());
1050 }
1051 
1052 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1053     HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) {
1054   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1055   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1056   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs));
1057   switch (opcode) {
1058     case HloOpcode::kClamp:
1059       return InferClampShape(lhs, rhs, ehs);
1060     case HloOpcode::kSelect:
1061       return InferSelectShape(lhs, rhs, ehs);
1062     default:
1063       return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1064   }
1065 }
1066 
1067 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1068     HloOpcode opcode, absl::Span<const HloInstruction* const> operands) {
1069   std::vector<const Shape*> operand_shapes;
1070   operand_shapes.reserve(operands.size());
1071   for (const HloInstruction* operand : operands) {
1072     operand_shapes.push_back(&operand->shape());
1073   }
1074   return InferVariadicOpShape(opcode, operand_shapes);
1075 }
1076 
1077 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1078     HloOpcode opcode, absl::Span<const Shape* const> operand_shapes) {
1079   for (const Shape* shape : operand_shapes) {
1080     TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
1081   }
1082   switch (opcode) {
1083     case HloOpcode::kTuple: {
1084       Shape result = ShapeUtil::MakeTupleShape({});
1085       result.mutable_tuple_shapes()->reserve(operand_shapes.size());
1086       for (const Shape* shape : operand_shapes) {
1087         ShapeUtil::AppendShapeToTuple(*shape, &result);
1088       }
1089       return result;
1090     }
1091     case HloOpcode::kSort: {
1092       if (operand_shapes.size() == 1) {
1093         return *operand_shapes[0];
1094       } else {
1095         for (int64_t operand = 1; operand < operand_shapes.size(); ++operand) {
1096           if (!ShapeUtil::SameDimensions(*operand_shapes[0],
1097                                          *operand_shapes[operand])) {
1098             return InvalidArgument(
1099                 "Sort keys and values dimensions must match. "
1100                 "Keys shape is: %s\n, Values shape (operand index %lld) is: %s",
1101                 ShapeUtil::HumanString(*operand_shapes[0]), operand,
1102                 ShapeUtil::HumanString(*operand_shapes[operand]));
1103           }
1104         }
1105         return ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes);
1106       }
1107       return InvalidArgument("Unexpected number of operands for sort");
1108     }
1109     default:
1110       return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1111   }
1112 }
1113 
1114 /* static */ StatusOr<Shape> ShapeInference::InferMapShape(
1115     absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
1116     absl::Span<const int64_t> dimensions) {
1117   if (arg_shapes.empty()) {
1118     return InvalidArgument("Map expects at least one argument.");
1119   }
1120 
1121   // All arguments must have the same shape ignoring the element types.
1122   const Shape* arg_shape = arg_shapes[0];
1123   for (size_t i = 1; i < arg_shapes.size(); ++i) {
1124     TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map"));
1125 
1126     if (ShapeUtil::CompatibleIgnoringElementType(*arg_shapes[i], *arg_shape)) {
1127       continue;
1128     }
1129     if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
1130                                                       *arg_shape)) {
1131       if (ShapeUtil::IsScalar(*arg_shapes[i])) {
1132         continue;
1133       }
1134       if (ShapeUtil::IsScalar(*arg_shape)) {
1135         arg_shape = arg_shapes[i];
1136         continue;
1137       }
1138     }
1139 
1140     std::vector<std::string> pieces;
1141     pieces.reserve(arg_shapes.size());
1142     for (const Shape* shape : arg_shapes) {
1143       pieces.push_back(ShapeUtil::HumanString(*shape));
1144     }
1145     return InvalidArgument(
1146         "Map operation requires all operands to have the same shape; got: "
1147         "%s.",
1148         StrJoin(pieces, ", "));
1149   }
1150 
1151   // Check that dimensions.size == arg_shape.dimensions_size() (we currently
1152   // only support mapping across all dimensions: i.e. scalar map functions).
1153   if (dimensions.size() != arg_shape->dimensions_size()) {
1154     return InvalidArgument(
1155         "Map applied to a subset of dimensions currently not supported: "
1156         "arg_dimension_size: %d, requested_map_dimensions_size: %u.",
1157         arg_shape->dimensions_size(), dimensions.size());
1158   }
1159 
1160   // Check that requested map dimensions numbers are monotonically increasing.
1161   for (int i = 0; i < dimensions.size(); ++i) {
1162     if (dimensions[i] != i) {
1163       return InvalidArgument(
1164           "Map requires monotonically increasing dimension numbers; got: %s.",
1165           StrJoin(dimensions, ", "));
1166     }
1167   }
1168 
1169   // The applied function's arity equals the number of arguments.
1170   if (arg_shapes.size() != to_apply.parameters_size()) {
1171     return InvalidArgument(
1172         "Map applied function arity must match number of arguments; got: "
1173         "arity: %d, arguments: %u.",
1174         to_apply.parameters_size(), arg_shapes.size());
1175   }
1176 
1177   // The parameters should all be scalars, and the output too.
1178   const Shape& output_shape = to_apply.result();
1179   if (!ShapeUtil::IsScalar(output_shape)) {
1180     return InvalidArgument(
1181         "Mapped computation's result has to be a scalar; got: %s.",
1182         ShapeUtil::HumanString(output_shape));
1183   }
1184 
1185   for (int i = 0; i < to_apply.parameters_size(); ++i) {
1186     const Shape& parameter_shape = to_apply.parameters(i);
1187 
1188     if (!ShapeUtil::IsScalar(parameter_shape)) {
1189       return InvalidArgument(
1190           "Mapped computation's parameter has to be a scalar; "
1191           "got parameter %d shape: %s.",
1192           i, ShapeUtil::HumanString(parameter_shape));
1193     }
1194 
1195     if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
1196                                                        *arg_shapes[i])) {
1197       return InvalidArgument(
1198           "Mapped computation's parameter type has to match argument element "
1199           "type; got parameter %d shape: %s, argument shape: %s.",
1200           i, ShapeUtil::HumanString(parameter_shape),
1201           ShapeUtil::HumanString(*arg_shapes[i]));
1202     }
1203   }
1204 
1205   return ShapeUtil::MakeShape(output_shape.element_type(),
1206                               arg_shape->dimensions());
1207 }
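// Illustrative example: mapping operands f32[2,3] and f32[2,3] with a scalar
// computation (f32, f32) -> s32 infers s32[2,3]; the result keeps the common
// operand dimensions and takes the applied computation's result element type.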
1208 
1209 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormTrainingShape(
1210     const Shape& operand_shape, const Shape& scale_shape,
1211     const Shape& offset_shape, int64_t feature_index) {
1212   TF_RETURN_IF_ERROR(
1213       ExpectArray(operand_shape, "operand of batch norm training"));
1214   TF_RETURN_IF_ERROR(
1215       ExpectArray(offset_shape, "offset input of batch norm training"));
1216   TF_RETURN_IF_ERROR(
1217       ExpectArray(scale_shape, "scale input of batch norm training"));
1218 
1219   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
1220                OkStatus());
1221   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) ==
1222                OkStatus());
1223   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
1224                OkStatus());
1225 
1226   if (feature_index >= operand_shape.rank()) {
1227     return InvalidArgument(
1228         "Expected feature_index of batch-norm-training to be "
1229         "smaller than the rank of operand_shape; "
1230         "got feature_index %d, and rank %d.",
1231         feature_index, operand_shape.rank());
1232   }
1233 
1234   if (feature_index < 0) {
1235     return InvalidArgument(
1236         "Expected feature_index of batch-norm-training to "
1237         "be a non-negative number, got %d.",
1238         feature_index);
1239   }
1240 
1241   if (operand_shape.rank() < 1) {
1242     return InvalidArgument(
1243         "Expected the rank of operand to "
1244         "batch-norm-training to be at least 1; got %d.",
1245         operand_shape.rank());
1246   }
1247 
1248   if (offset_shape.rank() != 1) {
1249     return InvalidArgument(
1250         "Offset input of batch-norm-training must have"
1251         " rank 1, but has rank %d.",
1252         offset_shape.rank());
1253   }
1254 
1255   if (scale_shape.rank() != 1) {
1256     return InvalidArgument(
1257         "Scale input of batch-norm-training must have"
1258         " rank 1, but has rank %d.",
1259         scale_shape.rank());
1260   }
1261 
1262   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1263     return InvalidArgument(
1264         "The operand to batch-norm-training must have a floating point "
1265         "element type, but the shape is %s.",
1266         PrimitiveType_Name(operand_shape.element_type()));
1267   }
1268 
1269   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1270                                                      operand_shape)) {
1271     return InvalidArgument(
1272         "The inputs should have the same element type for batch-norm-training, "
1273         "but the element type of offset factor is %s "
1274         "and the element type of operand is %s.",
1275         PrimitiveType_Name(offset_shape.element_type()),
1276         PrimitiveType_Name(operand_shape.element_type()));
1277   }
1278 
1279   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1280                                                      operand_shape)) {
1281     return InvalidArgument(
1282         "The inputs should have the same element type for batch-norm-training, "
1283         "but the element type of scale factor is %s "
1284         "and the element type of operand is %s.",
1285         PrimitiveType_Name(scale_shape.element_type()),
1286         PrimitiveType_Name(operand_shape.element_type()));
1287   }
1288 
1289   const int64_t feature_count = operand_shape.dimensions(feature_index);
1290   Shape output_shape_for_mean_and_var =
1291       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1292 
1293   if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1294     return InvalidArgument(
1295         "The size of offset factor should be the same as feature count, "
1296         "but the size of offset factor is %d "
1297         "and the feature count is %d.",
1298         ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1299   }
1300 
1301   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1302     return InvalidArgument(
1303         "The size of scale factor should be the same as feature count, "
1304         "but the size of scale factor is %d "
1305         "and the feature count is %d.",
1306         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1307   }
1308 
1309   return ShapeUtil::MakeTupleShapeWithPtrs({&operand_shape,
1310                                             &output_shape_for_mean_and_var,
1311                                             &output_shape_for_mean_and_var});
1312 }
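// Illustrative example: an operand f32[16,64,64,32] with feature_index = 3 and
// scale/offset f32[32] infers the tuple (f32[16,64,64,32], f32[32], f32[32]),
// i.e. (normalized output, batch mean, batch variance).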
1313 
1314 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
1315     const Shape& operand_shape, const Shape& scale_shape,
1316     const Shape& offset_shape, const Shape& mean_shape,
1317     const Shape& variance_shape, int64_t feature_index) {
1318   TF_RETURN_IF_ERROR(
1319       ExpectArray(operand_shape, "operand of batch norm inference"));
1320   TF_RETURN_IF_ERROR(
1321       ExpectArray(offset_shape, "offset input of batch norm inference"));
1322   TF_RETURN_IF_ERROR(
1323       ExpectArray(scale_shape, "scale input of batch norm inference"));
1324 
1325   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1326   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape));
1327   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1328   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1329   TF_RETURN_IF_ERROR(
1330       ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape));
1331 
1332   if (feature_index >= operand_shape.rank()) {
1333     return InvalidArgument(
1334         "Expected feature_index of batch-norm-inference to be "
1335         "smaller than the rank of operand_shape; "
1336         "got feature_index %d, and rank %d.",
1337         feature_index, operand_shape.rank());
1338   }
1339 
1340   if (feature_index < 0) {
1341     return InvalidArgument(
1342         "Expected feature_index of batch-norm-inference to "
1343         "be a non-negative number, got %d.",
1344         feature_index);
1345   }
1346 
1347   if (operand_shape.rank() < 1) {
1348     return InvalidArgument(
1349         "Expected the rank of operand to "
1350         "batch-norm-inference to be at least 1; got %d.",
1351         operand_shape.rank());
1352   }
1353 
1354   if (offset_shape.rank() != 1) {
1355     return InvalidArgument(
1356         "Offset input of batch-norm-inference must have"
1357         " rank 1, but has rank %d.",
1358         offset_shape.rank());
1359   }
1360 
1361   if (scale_shape.rank() != 1) {
1362     return InvalidArgument(
1363         "Scale input of batch-norm-inference must have"
1364         " rank 1, but has rank %d.",
1365         scale_shape.rank());
1366   }
1367 
1368   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1369     return InvalidArgument(
1370         "The operand to batch-norm-inference must have a floating point "
1371         "element type, but the shape is %s.",
1372         PrimitiveType_Name(operand_shape.element_type()));
1373   }
1374 
1375   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1376                                                      operand_shape)) {
1377     return InvalidArgument(
1378         "The inputs should have the same element type for "
1379         "batch-norm-inference, "
1380         "but the element type of offset factor is %s "
1381         "and the element type of operand is %s.",
1382         PrimitiveType_Name(offset_shape.element_type()),
1383         PrimitiveType_Name(operand_shape.element_type()));
1384   }
1385 
1386   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1387                                                      operand_shape)) {
1388     return InvalidArgument(
1389         "The inputs should have the same element type for "
1390         "batch-norm-inference, "
1391         "but the element type of scale factor is %s "
1392         "and the element type of operand is %s.",
1393         PrimitiveType_Name(scale_shape.element_type()),
1394         PrimitiveType_Name(operand_shape.element_type()));
1395   }
1396 
1397   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1398                                                      operand_shape)) {
1399     return InvalidArgument(
1400         "The inputs should have the same element type for "
1401         "batch-norm-inference, "
1402         "but the element type of mean is %s "
1403         "and the element type of operand is %s.",
1404         PrimitiveType_Name(mean_shape.element_type()),
1405         PrimitiveType_Name(operand_shape.element_type()));
1406   }
1407 
1408   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
1409                                                      operand_shape)) {
1410     return InvalidArgument(
1411         "The inputs should have the same element type for "
1412         "batch-norm-inference, "
1413         "but the element type of variance is %s "
1414         "and the element type of operand is %s.",
1415         PrimitiveType_Name(variance_shape.element_type()),
1416         PrimitiveType_Name(operand_shape.element_type()));
1417   }
1418 
1419   const int64_t feature_count = operand_shape.dimensions(feature_index);
1420   Shape output_shape_for_mean_and_var =
1421       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1422 
1423   if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1424     return InvalidArgument(
1425         "The size of offset factor should be the same as feature count, "
1426         "but the size of offset factor is %d "
1427         "and the feature count is %d.",
1428         ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1429   }
1430 
1431   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1432     return InvalidArgument(
1433         "The size of scale factor should be the same as feature count, "
1434         "but the size of scale factor is %d "
1435         "and the feature count is %d.",
1436         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1437   }
1438 
1439   if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1440     return InvalidArgument(
1441         "The size of mean should be the same as feature count, "
1442         "but the size of mean is %d "
1443         "and the feature count is %d.",
1444         ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1445   }
1446 
1447   if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
1448     return InvalidArgument(
1449         "The size of variance should be the same as feature count, "
1450         "but the size of variance is %d "
1451         "and the feature count is %d.",
1452         ShapeUtil::GetDimension(variance_shape, 0), feature_count);
1453   }
1454 
1455   return operand_shape;
1456 }
1457 
1458 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
1459     const Shape& operand_shape, const Shape& scale_shape,
1460     const Shape& mean_shape, const Shape& var_shape,
1461     const Shape& output_grad_shape, int64_t feature_index) {
1462   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad"));
1463   TF_RETURN_IF_ERROR(
1464       ExpectArray(scale_shape, "scale input of batch norm grad"));
1465   TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad"));
1466   TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad"));
1467   TF_RETURN_IF_ERROR(
1468       ExpectArray(output_grad_shape, "output_grad input of batch norm grad"));
1469 
1470   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1471   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1472   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1473   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(var_shape));
1474   TF_RETURN_IF_ERROR(
1475       ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape));
1476 
1477   if (feature_index >= operand_shape.rank()) {
1478     return InvalidArgument(
1479         "Expected feature_index of batch-norm-grad to be "
1480         "smaller than the rank of operand_shape; "
1481         "got feature_index %d, and rank %d.",
1482         feature_index, operand_shape.rank());
1483   }
1484 
1485   if (operand_shape.rank() != output_grad_shape.rank()) {
1486     return InvalidArgument(
1487         "Expected operand_shape of batch-norm-grad to have the same rank as"
1488         " output_grad_shape; got rank(operand_shape) %d, and"
1489         " rank(output_grad_shape) %d.",
1490         operand_shape.rank(), output_grad_shape.rank());
1491   }
1492 
1493   if (mean_shape.rank() != 1) {
1494     return InvalidArgument(
1495         "Mean input of batch-norm-grad must have"
1496         " rank 1, but has rank %d.",
1497         mean_shape.rank());
1498   }
1499 
1500   if (scale_shape.rank() != 1) {
1501     return InvalidArgument(
1502         "Scale input of batch-norm-grad must have"
1503         " rank 1, but has rank %d.",
1504         scale_shape.rank());
1505   }
1506 
1507   if (var_shape.rank() != 1) {
1508     return InvalidArgument(
1509         "Var input of batch-norm-grad must have"
1510         " rank 1, but has rank %d.",
1511         var_shape.rank());
1512   }
1513 
1514   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1515     return InvalidArgument(
1516         "The operand to batch-norm-grad must have a floating point "
1517         "element type, but the shape is %s.",
1518         PrimitiveType_Name(operand_shape.element_type()));
1519   }
1520 
1521   if (!ShapeUtil::ElementIsFloating(output_grad_shape)) {
1522     return InvalidArgument(
1523         "The output_grad to batch-norm-grad must have a floating point "
1524         "element type, but the shape is %s.",
1525         PrimitiveType_Name(output_grad_shape.element_type()));
1526   }
1527 
1528   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
1529                                                      operand_shape)) {
1530     return InvalidArgument(
1531         "The inputs should have the same element type for batch-norm-grad, "
1532         "but the element type of output_grad is %s "
1533         "and the element type of operand is %s.",
1534         PrimitiveType_Name(output_grad_shape.element_type()),
1535         PrimitiveType_Name(operand_shape.element_type()));
1536   }
1537 
1538   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1539                                                      operand_shape)) {
1540     return InvalidArgument(
1541         "The inputs should have the same element type for batch-norm-grad, "
1542         "but the element type of scale factor is %s "
1543         "and the element type of operand is %s.",
1544         PrimitiveType_Name(scale_shape.element_type()),
1545         PrimitiveType_Name(operand_shape.element_type()));
1546   }
1547 
1548   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1549                                                      operand_shape)) {
1550     return InvalidArgument(
1551         "The inputs should have the same element type for batch-norm-grad, "
1552         "but the element type of mean is %s "
1553         "and the element type of operand is %s.",
1554         PrimitiveType_Name(mean_shape.element_type()),
1555         PrimitiveType_Name(operand_shape.element_type()));
1556   }
1557 
1558   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
1559                                                      operand_shape)) {
1560     return InvalidArgument(
1561         "The inputs should have the same element type for batch-norm-grad, "
1562         "but the element type of var is %s "
1563         "and the element type of operand is %s.",
1564         PrimitiveType_Name(var_shape.element_type()),
1565         PrimitiveType_Name(operand_shape.element_type()));
1566   }
1567 
1568   const int64_t feature_count = operand_shape.dimensions(feature_index);
1569 
1570   Shape feature_shape =
1571       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1572 
1573   if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1574     return InvalidArgument(
1575         "The size of mean should be the same as feature count, "
1576         "but the size of mean is %d "
1577         "and the feature count is %d.",
1578         ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1579   }
1580 
1581   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1582     return InvalidArgument(
1583         "The size of scale factor should be the same as feature count, "
1584         "but the size of scale factor is %d "
1585         "and the feature count is %d.",
1586         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1587   }
1588 
1589   if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) {
1590     return InvalidArgument(
1591         "The size of variance should be the same as feature count, "
1592         "but the size of variance is %d "
1593         "and the feature count is %d.",
1594         ShapeUtil::GetDimension(var_shape, 0), feature_count);
1595   }
1596 
1597   // Verify operand_shape and output_grad_shape have same bounds.
1598   for (int64_t i = 0; i < operand_shape.rank(); ++i) {
1599     if (ShapeUtil::GetDimension(operand_shape, i) !=
1600         ShapeUtil::GetDimension(output_grad_shape, i)) {
1601       return InvalidArgument(
1602           "The bounds of operand shape should be the same as output_grad's, "
1603           "but the bound of operand_shape at dimension %d is %d "
1604           "and the bound of output_grad_shape is %d.",
1605           i, ShapeUtil::GetDimension(operand_shape, i),
1606           ShapeUtil::GetDimension(output_grad_shape, i));
1607     }
1608   }
1609 
1610   return ShapeUtil::MakeTupleShapeWithPtrs(
1611       {&operand_shape, &feature_shape, &feature_shape});
1612 }
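// Illustrative example: with operand and output_grad f32[16,64,64,32] and
// feature_index = 3, the inferred result is the tuple
// (f32[16,64,64,32], f32[32], f32[32]), i.e. (grad_operand, grad_scale,
// grad_offset).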
1613 
1614 /* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
1615     const Shape& lhs, const Shape& rhs, int64_t feature_group_count,
1616     int64_t batch_group_count, const Window& window,
1617     const ConvolutionDimensionNumbers& dnums,
1618     std::optional<PrimitiveType> preferred_element_type) {
1619   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
1620   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
1621 
1622   if (feature_group_count <= 0) {
1623     return InvalidArgument(
1624         "feature_group_count must be a positive number, got %d",
1625         feature_group_count);
1626   }
1627 
1628   if (batch_group_count <= 0) {
1629     return InvalidArgument(
1630         "batch_group_count must be a positive number, got %d",
1631         batch_group_count);
1632   }
1633 
1634   if (batch_group_count > 1 && feature_group_count > 1) {
1635     return InvalidArgument(
1636         "both batch_group_count %d and feature_group_count %d cannot be "
1637         "greater than 1",
1638         batch_group_count, feature_group_count);
1639   }
1640 
1641   if (dnums.input_spatial_dimensions_size() !=
1642       dnums.kernel_spatial_dimensions_size()) {
1643     return InvalidArgument(
1644         "Both arguments to convolution must have same number of dimensions.\n"
1645         "Numbers: %s",
1646         dnums.DebugString());
1647   }
1648 
1649   if (dnums.input_spatial_dimensions_size() !=
1650       dnums.output_spatial_dimensions_size()) {
1651     return InvalidArgument(
1652         "Both input and output of convolution must have same number of "
1653         "dimensions.\nNumbers: %s",
1654         dnums.DebugString());
1655   }
1656 
1657   const int num_spatial_dims = dnums.input_spatial_dimensions_size();
1658   if (window.dimensions_size() != num_spatial_dims) {
1659     return InvalidArgument(
1660         "Window must have same number of dimensions as dimension numbers.\n"
1661         "Window: %s\nDimension numbers: %s.",
1662         window.DebugString(), dnums.DebugString());
1663   }
1664 
1665   const int num_dims = num_spatial_dims + 2;
1666   if (lhs.rank() != num_dims) {
1667     return InvalidArgument(
1668         "The LHS argument to a convolution should have rank %d; lhs: %s.",
1669         num_dims, ShapeUtil::HumanString(lhs));
1670   }
1671   if (rhs.rank() != num_dims) {
1672     return InvalidArgument(
1673         "The RHS argument to a convolution should have rank %d; rhs: %s.",
1674         num_dims, ShapeUtil::HumanString(rhs));
1675   }
1676   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1677   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1678 
1679   // Verifies that the input and window dimensions are a permutation of
1680   // the dimension numbers.
1681   std::vector<int64_t> input_dnums(num_dims);
1682   input_dnums[0] = dnums.input_batch_dimension();
1683   input_dnums[1] = dnums.input_feature_dimension();
1684   std::copy(dnums.input_spatial_dimensions().begin(),
1685             dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
1686   absl::c_sort(input_dnums);
1687 
1688   std::vector<int64_t> window_dnums(num_dims);
1689   window_dnums[0] = dnums.kernel_input_feature_dimension();
1690   window_dnums[1] = dnums.kernel_output_feature_dimension();
1691   std::copy(dnums.kernel_spatial_dimensions().begin(),
1692             dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
1693   absl::c_sort(window_dnums);
1694 
1695   std::vector<int64_t> output_dnums(num_dims);
1696   output_dnums[0] = dnums.output_batch_dimension();
1697   output_dnums[1] = dnums.output_feature_dimension();
1698   std::copy(dnums.output_spatial_dimensions().begin(),
1699             dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
1700   absl::c_sort(output_dnums);
1701 
1702   std::vector<int64_t> expected_dnums(num_dims);
1703   std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
1704 
1705   const auto in_range = [num_dims](int64_t i) {
1706     return 0 <= i && i < num_dims;
1707   };
1708   if (!absl::c_all_of(input_dnums, in_range) ||
1709       !absl::c_all_of(window_dnums, in_range) ||
1710       !absl::c_all_of(output_dnums, in_range)) {
1711     return InvalidArgument(
1712         "A dimension number is out of range in convolution: %s.",
1713         dnums.DebugString());
1714   }
1715 
1716   if (input_dnums != expected_dnums) {
1717     return InvalidArgument(
1718         "Input dimensions of convolution must contain each dimension exactly "
1719         "once: %s.",
1720         dnums.DebugString());
1721   }
1722   if (window_dnums != expected_dnums) {
1723     return InvalidArgument(
1724         "Window dimensions of convolution must contain each dimension exactly "
1725         "once: %s.",
1726         dnums.DebugString());
1727   }
1728   if (output_dnums != expected_dnums) {
1729     return InvalidArgument(
1730         "Output dimensions of convolution must contain each dimension exactly "
1731         "once: %s.",
1732         dnums.DebugString());
1733   }
1734 
1735   std::vector<int64_t> input_spatial_dims(num_spatial_dims);
1736   for (int i = 0; i < num_spatial_dims; ++i) {
1737     input_spatial_dims[i] = lhs.dimensions(dnums.input_spatial_dimensions(i));
1738   }
1739   const int64_t input_features =
1740       lhs.dimensions(dnums.input_feature_dimension());
1741   const int64_t input_batch = lhs.dimensions(dnums.input_batch_dimension());
1742 
1743   std::vector<int64_t> kernel_spatial_dims(num_spatial_dims);
1744   for (int i = 0; i < num_spatial_dims; ++i) {
1745     kernel_spatial_dims[i] = rhs.dimensions(dnums.kernel_spatial_dimensions(i));
1746   }
1747   const int64_t kernel_input_features =
1748       rhs.dimensions(dnums.kernel_input_feature_dimension());
1749   const int64_t kernel_output_features =
1750       rhs.dimensions(dnums.kernel_output_feature_dimension());
1751 
1752   if (kernel_output_features % batch_group_count != 0) {
1753     return InvalidArgument(
1754         "Expected output feature dimension size (value %d) to be a multiple of "
1755         "batch group count %d; got <conv>(%s, %s)\n"
1756         "Dimension numbers: {%s}.",
1757         kernel_output_features, batch_group_count, ShapeUtil::HumanString(lhs),
1758         ShapeUtil::HumanString(rhs), dnums.DebugString());
1759   }
1760 
1761   if (input_features % feature_group_count != 0 ||
1762       input_features / feature_group_count != kernel_input_features) {
1763     return InvalidArgument(
1764         "Expected LHS feature dimension (value %d) to be a multiple of "
1765         "feature_group_count (value %d), and LHS feature dimension / "
1766         "feature_group_count = RHS feature dimension (value %d); "
1767         "got <conv>(%s, %s)\n"
1768         "Dimension numbers: {%s}.",
1769         input_features, feature_group_count, kernel_input_features,
1770         ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
1771         dnums.DebugString());
1772   }
1773 
1774   if (kernel_output_features % feature_group_count > 0) {
1775     // A depthwise/grouped filter has the shape
1776     // [space0, .. spaceN, GROUP_SIZE, NUM_OUTPUT_FEATURES]. When
1777     // [space0, .. spaceN, GROUP_SIZE] is convolved with the input, a shape
1778     // [space0, .. spaceN, feature_group_count] is formed. Therefore, the output
1779     // feature count (which is equal to kernel output features) has to be a
1780     // multiple of feature_group_count.
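    // For example, with feature_group_count = 16 a kernel whose output
    // feature dimension is 32 is valid (a channel multiplier of 2), whereas
    // an output feature count of 24 would fail this check since 24 is not
    // divisible by 16.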
1781     return InvalidArgument(
1782         "Expected output feature dimension (value %d) to be divisible by "
1783         "feature_group_count (value %d); "
1784         "got <conv>(%s, %s)\n"
1785         "Dimension numbers: {%s}.",
1786         kernel_output_features, feature_group_count,
1787         ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
1788         dnums.DebugString());
1789   }
1790 
1791   if (input_batch % batch_group_count != 0) {
1792     return InvalidArgument(
1793         "Expected input batch dimension (value %d) to be divisible by "
1794         "batch_group_count (value %d); "
1795         "got <conv>(%s, %s)\n"
1796         "Dimension numbers: {%s}.",
1797         input_batch, batch_group_count, ShapeUtil::HumanString(lhs),
1798         ShapeUtil::HumanString(rhs), dnums.DebugString());
1799   }
1800 
1801   std::vector<int64_t> window_dims(num_spatial_dims);
1802   for (int i = 0; i < num_spatial_dims; ++i) {
1803     window_dims[i] = window.dimensions(i).size();
1804   }
1805   if (kernel_spatial_dims != window_dims) {
1806     return InvalidArgument(
1807         "Window dimensions do not match RHS shape:\n\t"
1808         "RHS shape: %s\n\t"
1809         "Window: {%s}\n\t"
1810         "Dimension numbers: {%s}.",
1811         ShapeUtil::HumanString(rhs), window.ShortDebugString(),
1812         dnums.ShortDebugString());
1813   }
1814 
1815   Shape base_shape =
1816       ShapeUtil::MakeShape(lhs.element_type(), input_spatial_dims);
1817   TF_ASSIGN_OR_RETURN(
1818       Shape window_output_shape,
1819       InferWindowOutputShape(base_shape, window, lhs.element_type()));
1820 
1821   std::vector<int64_t> dimensions(num_dims);
1822   dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count;
1823   dimensions[dnums.output_feature_dimension()] = kernel_output_features;
1824 
1825   for (int i = 0; i < num_spatial_dims; ++i) {
1826     dimensions[dnums.output_spatial_dimensions(i)] =
1827         window_output_shape.dimensions(i);
1828   }
1829   std::vector<bool> is_dynamic(num_dims);
1830   for (int i = 0; i < num_dims; i++) {
1831     if (lhs.is_dynamic_dimension(i)) {
1832       if (i == dnums.input_batch_dimension()) {
1833         is_dynamic[dnums.output_batch_dimension()] = true;
1834       } else if (i == dnums.input_feature_dimension()) {
1835         // Input feature dimension is a contracting dimension, which does not
1836         // affect the output dimension size, so there is nothing to do here.
1837       } else {
1838         for (int64_t j = 0; j < dnums.output_spatial_dimensions_size(); ++j) {
1839           if (i == dnums.input_spatial_dimensions(j)) {
1840             // i is a spatial dimension, find corresponding output spatial
1841             // dimension.
1842             is_dynamic[dnums.output_spatial_dimensions(j)] = true;
1843           }
1844         }
1845       }
1846     }
1847     if (rhs.is_dynamic_dimension(i)) {
1848       if (i == dnums.kernel_input_feature_dimension()) {
1849         // Kernel feature dimension does not affect the output dimension size,
1850         // so there is nothing to do here.
1851       } else if (i == dnums.kernel_output_feature_dimension()) {
1852         return InvalidArgument(
1853             "Dynamic output feature dim on convolution kernel is not "
1854             "supported: rhs shape is %s ",
1855             rhs.ToString());
1856       } else {
1857         for (int64_t j = 0; j < dnums.kernel_spatial_dimensions_size(); ++j) {
1858           if (i == dnums.kernel_spatial_dimensions(j)) {
1859             // i is a spatial dimension, find corresponding output spatial
1860             // dimension.
1861             is_dynamic[dnums.output_spatial_dimensions(j)] = true;
1862           }
1863         }
1864       }
1865     }
1866   }
1867   TF_ASSIGN_OR_RETURN(
1868       PrimitiveType type,
1869       MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
1870                   preferred_element_type));
1871   return ShapeUtil::MakeShape(type, dimensions, is_dynamic);
1872 }
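// Illustrative example: with NHWC/HWIO-style dimension numbers, lhs
// f32[1,8,8,16], rhs f32[3,3,16,32], unit strides, no padding, and
// feature_group_count = batch_group_count = 1, each output spatial dimension
// is (8 - 3) + 1 = 6, so the inferred shape is f32[1,6,6,32].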
1873 
1874 /* static */ StatusOr<Shape> ShapeInference::InferFftShape(
1875     const Shape& in, const FftType fft_type,
1876     const absl::Span<const int64_t> fft_length) {
1877   const int64_t fft_rank = fft_length.size();
1878   if (fft_rank < 1 || fft_rank > 3) {
1879     return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank);
1880   }
1881 #define RET_CHECK_RANK(x)                            \
1882   if (x.dimensions_size() < fft_rank) {              \
1883     return InvalidArgument(                          \
1884         "FFT of rank %d requires input of at least " \
1885         "same rank; got input of rank %d",           \
1886         fft_rank, x.dimensions_size());              \
1887   }
1888   switch (fft_type) {
1889     case FFT:
1890     case IFFT:
1891       if (!primitive_util::IsComplexType(in.element_type())) {
1892         return InvalidArgument("%s requires complex input type, found %s.",
1893                                FftType_Name(fft_type),
1894                                PrimitiveType_Name(in.element_type()));
1895       }
1896       RET_CHECK_RANK(in);
1897       return in;
1898     case RFFT: {
1899       if (in.element_type() != F32 && in.element_type() != F64) {
1900         return InvalidArgument("RFFT requires F32 or F64 input type, found %s.",
1901                                PrimitiveType_Name(in.element_type()));
1902       }
1903       RET_CHECK_RANK(in);
1904       for (int i = 0; i < fft_rank; i++) {
1905         if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1906             fft_length[i]) {
1907           return InvalidArgument(
1908               "RFFT requires innermost dimensions to match fft_length, but "
1909               "dimension %d is %d and should be %d.",
1910               in.dimensions_size() - fft_rank + i,
1911               in.dimensions(in.dimensions_size() - fft_rank + i),
1912               fft_length[i]);
1913         }
1914       }
1915       Shape result = ShapeUtil::ChangeElementType(
1916           in, in.element_type() == F32 ? C64 : C128);
1917       // Preserve the size of zero-sized dimensions.
1918       if (fft_length[fft_rank - 1] != 0) {
1919         result.set_dimensions(result.dimensions_size() - 1,
1920                               fft_length[fft_rank - 1] / 2 + 1);
1921       }
1922       return result;
1923     }
1924     case IRFFT: {
1925       if (!primitive_util::IsComplexType(in.element_type())) {
1926         return InvalidArgument("IRFFT requires complex input type, found %s.",
1927                                PrimitiveType_Name(in.element_type()));
1928       }
1929       RET_CHECK_RANK(in);
1930       Shape result = ShapeUtil::ComplexComponentShape(in);
1931       for (int i = 0; i < fft_rank - 1; i++) {
1932         if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1933             fft_length[i]) {
1934           return InvalidArgument(
1935               "IRFFT requires all but the innermost dimension to match "
1936               "fft_length, but dimension %d is %d and should be %d.",
1937               in.dimensions_size() - fft_rank + i,
1938               in.dimensions(in.dimensions_size() - fft_rank + i),
1939               fft_length[i]);
1940         }
1941       }
1942       // The size of zero-sized dimensions is preserved.
1943       if ((in.dimensions(in.dimensions_size() - 1) != 0 ||
1944            fft_length[fft_rank - 1] != 0) &&
1945           in.dimensions(in.dimensions_size() - 1) !=
1946               fft_length[fft_rank - 1] / 2 + 1) {
1947         return InvalidArgument(
1948             "IRFFT requires innermost dimension matches fft_length/2+1, but "
1949             "dimension %d is %d and should be %d.",
1950             in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1),
1951             fft_length[fft_rank - 1] / 2 + 1);
1952       }
1953       result.set_dimensions(result.dimensions_size() - 1,
1954                             fft_length[fft_rank - 1]);
1955       return result;
1956     }
1957     default:
1958       LOG(FATAL) << "Unexpected fft_type: " << fft_type;
1959   }
1960 #undef RET_CHECK_RANK
1961 }
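// Illustrative examples: FFT/IFFT preserve the input shape; RFFT of f32[16]
// with fft_length {16} infers c64[9] (innermost dimension 16 / 2 + 1 = 9), and
// IRFFT of c64[9] with fft_length {16} infers f32[16].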
1962 
1963 /* static */ StatusOr<Shape> ShapeInference::InferTriangularSolveShape(
1964     const Shape& a, const Shape& b, const TriangularSolveOptions& options) {
1965   if ((!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) ||
1966       a.element_type() != b.element_type()) {
1967     return InvalidArgument(
1968         "Expected element types in shape to be floating or complex and "
1969         "identical for TriangularSolve; got %s and %s.",
1970         PrimitiveType_Name(a.element_type()),
1971         PrimitiveType_Name(b.element_type()));
1972   }
1973   if (a.rank() < 2) {
1974     return InvalidArgument(
1975         "The 'a' argument to TriangularSolve must have rank >= 2, got shape %s",
1976         a.ToString());
1977   }
1978   if (b.rank() != a.rank()) {
1979     return InvalidArgument(
1980         "Arguments to triangular solve must have equal rank; got %s and %s.",
1981         b.ToString(), a.ToString());
1982   }
1983   if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
1984     return InvalidArgument(
1985         "The two minor dimensions of 'a' must have equal size, got %s.",
1986         a.ToString());
1987   }
1988   if (a.dimensions(a.rank() - 1) !=
1989       b.dimensions(b.rank() - (options.left_side() ? 2 : 1))) {
1990     return InvalidArgument(
1991         "The shared dimension of 'a' and 'b' does not match, got shapes %s and "
1992         "%s",
1993         a.ToString(), b.ToString());
1994   }
1995   absl::Span<const int64_t> a_batch_dims(a.dimensions());
1996   absl::Span<const int64_t> b_batch_dims(b.dimensions());
1997   a_batch_dims.remove_suffix(2);
1998   b_batch_dims.remove_suffix(2);
1999   if (a_batch_dims != b_batch_dims) {
2000     return InvalidArgument(
2001         "The leading batch dimensions of the arguments to triangular solve "
2002         "must be equal; got %s and %s.",
2003         b.ToString(), a.ToString());
2004   }
2005   if (!TriangularSolveOptions_Transpose_IsValid(options.transpose_a()) ||
2006       options.transpose_a() == TriangularSolveOptions::TRANSPOSE_INVALID) {
2007     return InvalidArgument(
2008         "Invalid transpose option value for triangular solve (%d).\n",
2009         options.transpose_a());
2010   }
2011   return b;
2012 }
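// Illustrative example: a = f32[4,8,8] and b = f32[4,8,5] with
// left_side = true pass the checks above (shared dimension 8, matching batch
// dimension 4), and the inferred shape is that of 'b': f32[4,8,5].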
2013 
2014 /* static */ StatusOr<Shape> ShapeInference::InferCholeskyShape(
2015     const Shape& a) {
2016   if (!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) {
2017     return InvalidArgument(
2018         "Expected element type in shape to be floating or complex for "
2019         "Cholesky; got %s.",
2020         PrimitiveType_Name(a.element_type()));
2021   }
2022   if (a.rank() < 2) {
2023     return InvalidArgument(
2024         "The 'a' argument to Cholesky must have rank >= 2, got shape %s",
2025         a.ToString());
2026   }
2027   if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
2028     return InvalidArgument(
2029         "The two minor dimensions of 'a' must have equal size, got %s.",
2030         a.ToString());
2031   }
2032   return a;
2033 }
2034 
2035 /* static */ StatusOr<Shape> ShapeInference::InferAllGatherShape(
2036     absl::Span<const Shape* const> operand_shapes, int64_t all_gather_dimension,
2037     int64_t shard_count) {
2038   TF_RET_CHECK(all_gather_dimension >= 0);
2039   TF_RET_CHECK(shard_count > 0);
2040 
2041   std::vector<Shape> output_shapes;
2042   output_shapes.reserve(operand_shapes.size());
2043   for (const Shape* operand_shape : operand_shapes) {
2044     TF_RET_CHECK(all_gather_dimension < operand_shape->rank());
2045     TF_RETURN_IF_ERROR(ExpectArray(*operand_shape, "operand of all-gather"));
2046 
2047     Shape output_shape = *operand_shape;
2048     output_shape.set_dimensions(
2049         all_gather_dimension,
2050         shard_count * output_shape.dimensions(all_gather_dimension));
2051     output_shapes.push_back(output_shape);
2052   }
2053   if (output_shapes.size() == 1) {
2054     return output_shapes[0];
2055   }
2056   return ShapeUtil::MakeTupleShape(output_shapes);
2057 }
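// Illustrative example: an operand f32[4,16] with all_gather_dimension = 0 and
// shard_count = 8 infers f32[32,16]; with multiple operands the per-operand
// results are wrapped in a tuple.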
2058 
2059 /* static */ StatusOr<Shape> ShapeInference::InferAllGatherStartShape(
2060     absl::Span<const Shape* const> operand_shapes, int64_t all_gather_dimension,
2061     int64_t shard_count) {
2062   TF_ASSIGN_OR_RETURN(
2063       Shape ag_shape,
2064       InferAllGatherShape(operand_shapes, all_gather_dimension, shard_count));
2065   Shape input_shape;
2066   if (operand_shapes.size() == 1) {
2067     input_shape = *operand_shapes[0];
2068   } else {
2069     input_shape = ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes);
2070   }
2071   return ShapeUtil::MakeTupleShapeWithPtrs({&input_shape, &ag_shape});
2072 }
2073 
2074 /* static */ StatusOr<Shape> ShapeInference::InferAllGatherDoneShape(
2075     const Shape& all_gather_start_shape) {
2076   return ShapeUtil::GetTupleElementShape(all_gather_start_shape, 1);
2077 }
2078 
2079 /* static */ StatusOr<Shape> ShapeInference::InferAllReduceShape(
2080     absl::Span<const Shape* const> operand_shapes) {
2081   for (const Shape* operand_shape : operand_shapes) {
2082     TF_RETURN_IF_ERROR(
2083         ExpectArray(*operand_shape, "operand of cross replica sum"));
2084   }
2085   if (operand_shapes.size() == 1) {
2086     return *operand_shapes[0];
2087   }
2088   return ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes);
2089 }
2090 
2091 /* static */ StatusOr<Shape> ShapeInference::InferReduceScatterShape(
2092     absl::Span<const Shape* const> operand_shapes, int64_t scatter_dimension,
2093     int64_t shard_count) {
2094   TF_RET_CHECK(scatter_dimension >= 0);
2095   TF_RET_CHECK(shard_count > 0);
2096 
2097   std::vector<Shape> output_shapes;
2098   output_shapes.reserve(operand_shapes.size());
2099   for (const Shape* operand_shape : operand_shapes) {
2100     TF_RET_CHECK(scatter_dimension < operand_shape->rank());
2101     TF_RETURN_IF_ERROR(
2102         ExpectArray(*operand_shape, "operand of reduce-scatter"));
2103 
2104     int64_t scatter_dim_input_size =
2105         operand_shape->dimensions(scatter_dimension);
2106     if (scatter_dim_input_size % shard_count != 0) {
2107       return InvalidArgument(
2108           "ReduceScatter operand scatter dimension size %d must be "
2109           "divisible by shard_count "
2110           "%d.",
2111           scatter_dim_input_size, shard_count);
2112     }
2113 
2114     Shape output_shape = *operand_shape;
2115     output_shape.set_dimensions(scatter_dimension,
2116                                 scatter_dim_input_size / shard_count);
2117     output_shapes.push_back(output_shape);
2118   }
2119 
2120   if (output_shapes.size() == 1) {
2121     return output_shapes[0];
2122   }
2123   return ShapeUtil::MakeTupleShape(output_shapes);
2124 }
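// Illustrative example: an operand f32[32,16] with scatter_dimension = 0 and
// shard_count = 8 infers f32[4,16], the inverse of the all-gather example
// above.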
2125 
2126 /* static */ StatusOr<Shape> ShapeInference::InferAllReduceStartShape(
2127     absl::Span<const Shape* const> operand_shapes) {
2128   return InferAllReduceShape(operand_shapes);
2129 }
2130 
2131 /* static */ StatusOr<Shape> ShapeInference::InferAllReduceDoneShape(
2132     const Shape& operand_shape) {
2133   // The returned value from AllReduceDone is the operand forwarded.
2134   return operand_shape;
2135 }
2136 
2137 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllShape(
2138     const Shape& shape, int64_t split_dimension, int64_t concat_dimension,
2139     int64_t split_count) {
2140   TF_RET_CHECK(split_count > 0);
2141   if (split_dimension >= shape.rank() || split_dimension < 0) {
2142     return InvalidArgument(
2143         "AllToAll split_dimension %d is out-of-bounds in shape %s.",
2144         split_dimension, ShapeUtil::HumanString(shape));
2145   }
2146   if (concat_dimension >= shape.rank() || concat_dimension < 0) {
2147     return InvalidArgument(
2148         "AllToAll concat_dimension %d is out-of-bounds in shape %s.",
2149         concat_dimension, ShapeUtil::HumanString(shape));
2150   }
2151   if (shape.dimensions(split_dimension) % split_count != 0) {
2152     return InvalidArgument(
2153         "AllToAll split dimension size %d must be divisible by split_count "
2154         "%d.",
2155         shape.dimensions(split_dimension), split_count);
2156   }
2157   std::vector<int64_t> new_dimensions(shape.dimensions().begin(),
2158                                       shape.dimensions().end());
2159   new_dimensions[split_dimension] /= split_count;
2160   new_dimensions[concat_dimension] *= split_count;
2161   return ShapeUtil::MakeShape(shape.element_type(), new_dimensions);
2162 }
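// Illustrative example: a shape f32[8,16] with split_dimension = 0,
// concat_dimension = 1, and split_count = 4 infers f32[2,64]: the split
// dimension shrinks by split_count while the concat dimension grows by it.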
2163 
2164 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
2165     absl::Span<const Shape* const> operand_shapes) {
2166   // An Alltoall HLO instruction receives N operands (with the same shape) and
2167   // returns a tuple that contains N array shapes.
2168   TF_RET_CHECK(!operand_shapes.empty());
2169   for (int i = 0; i < operand_shapes.size(); i++) {
2170     if (!Shape::Equal().IgnoreMemorySpaceInLayout()(*operand_shapes[0],
2171                                                     *operand_shapes[i])) {
2172       return InvalidArgument(
2173           "HLO all-to-all has operands with different shapes: the 0th "
2174           "operand has shape %s, but the %dth operand has shape %s.",
2175           ShapeUtil::HumanString(*operand_shapes[0]), i,
2176           ShapeUtil::HumanString(*operand_shapes[i]));
2177     }
2178   }
2179 
2180   return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes);
2181 }
2182 
2183 /* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteShape(
2184     absl::Span<const Shape* const> operand_shapes) {
2185   if (operand_shapes.size() == 1) {
2186     TF_RETURN_IF_ERROR(
2187         ExpectArray(*(operand_shapes[0]), "operand of collective-permute"));
2188     return *(operand_shapes[0]);
2189   } else {
2190     TF_RET_CHECK(operand_shapes.size() == 4);
2191     return *(operand_shapes[1]);
2192   }
2193 }
2194 
2195 /* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteStartShape(
2196     absl::Span<const Shape* const> operand_shapes) {
2197   const Shape u32_scalar = ShapeUtil::MakeShape(U32, {});
2198   if (operand_shapes.size() == 1) {
2199     TF_RETURN_IF_ERROR(ExpectArray(*(operand_shapes[0]),
2200                                    "operand of collective-permute-start"));
2201     return ShapeUtil::MakeTupleShapeWithPtrs(
2202         {operand_shapes[0], operand_shapes[0], &u32_scalar, &u32_scalar});
2203   } else {
2204     TF_RET_CHECK(operand_shapes.size() == 4);
2205     return ShapeUtil::MakeTupleShapeWithPtrs(
2206         {operand_shapes[0], operand_shapes[1], &u32_scalar, &u32_scalar});
2207   }
2208 }
2209 
2210 /* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteDoneShape(
2211     const Shape& operand_shape) {
2212   TF_RET_CHECK(operand_shape.IsTuple());
2213   return ShapeUtil::GetTupleElementShape(operand_shape, 1);
2214 }
2215 
2216 /* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
2217     absl::Span<const Shape* const> arg_shapes,
2218     absl::Span<const int64_t> dimensions_to_reduce,
2219     const ProgramShape& to_apply) {
2220   if (arg_shapes.empty()) {
2221     return InvalidArgument("Reduce must have at least 2 arguments, has 0");
2222   }
2223   if (arg_shapes.size() % 2) {
2224     return InvalidArgument(
2225         "Reduce must have an even number of arguments, has %lu",
2226         arg_shapes.size());
2227   }
2228   int64_t num_reduced_args = arg_shapes.size() / 2;
2229   auto reduced_args = arg_shapes.subspan(0, num_reduced_args);
2230   // Check that all of the reduced tensors have the same dimensions. The element
2231   // types may be different.
2232   for (int64_t i = 1; i < num_reduced_args; ++i) {
2233     if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) {
2234       return InvalidArgument(
2235           "All reduced tensors must have the same dimensions. Tensor 0 has "
2236           "shape %s, Tensor %d has shape %s",
2237           ShapeUtil::HumanString(*reduced_args[0]), i,
2238           ShapeUtil::HumanString(*reduced_args[i]));
2239     }
2240   }
2241   // Check that the dimensions to reduce are in-bounds for the given shape.
2242   // We've already verified all reduced tensors have the same dimensions, so it
2243   // doesn't matter which one we choose.
2244   const Shape& arg = *reduced_args[0];
2245   for (int64_t dimension : dimensions_to_reduce) {
2246     if (dimension >= arg.rank() || dimension < 0) {
2247       return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.",
2248                              dimension, ShapeUtil::HumanString(arg));
2249     }
2250   }
2251 
2252   auto init_values = arg_shapes.subspan(num_reduced_args, arg_shapes.size());
2253   std::vector<PrimitiveType> element_types;
2254   element_types.reserve(reduced_args.size());
2255   for (const Shape* arg : reduced_args) {
2256     element_types.push_back(arg->element_type());
2257   }
2258   TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types,
2259                                         num_reduced_args));
2260 
2261   absl::flat_hash_set<int64_t> dimensions_to_reduce_set;
2262   for (int64_t dim_to_reduce : dimensions_to_reduce) {
2263     if (!dimensions_to_reduce_set.insert(dim_to_reduce).second) {
2264       return InvalidArgument("Duplicate reduction dimension: %d",
2265                              dim_to_reduce);
2266     }
2267   }
2268 
2269   std::vector<int64_t> new_dimensions;
2270   std::vector<bool> new_is_dynamic;
2271   for (int i = 0; i < arg.rank(); ++i) {
2272     if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) {
2273       new_dimensions.push_back(arg.dimensions(i));
2274       new_is_dynamic.push_back(arg.is_dynamic_dimension(i));
2275     }
2276   }
2277 
2278   if (ShapeUtil::IsScalar(to_apply.result())) {
2279     return ShapeUtil::MakeShape(to_apply.result().element_type(),
2280                                 new_dimensions, new_is_dynamic);
2281   } else {
2282     std::vector<Shape> result_subshapes;
2283     const auto& tuple_shapes = to_apply.result().tuple_shapes();
2284     result_subshapes.reserve(tuple_shapes.size());
2285     for (const Shape& subshape : tuple_shapes) {
2286       result_subshapes.push_back(ShapeUtil::MakeShape(
2287           subshape.element_type(), new_dimensions, new_is_dynamic));
2288     }
2289     return ShapeUtil::MakeTupleShape(result_subshapes);
2290   }
2291 }
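// Illustrative example: reducing an f32[4,5,6] operand over dimensions {0,2}
// with a scalar (f32, f32) -> f32 reducer and an f32[] init value infers
// f32[5]; reduced dimensions are dropped and the reducer's result element type
// is used.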
2292 
2293 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2294     const Shape& operand_shape, const Shape& init_value_shape,
2295     const Window& window, const ProgramShape& to_apply_shape) {
2296   TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
2297                                         {operand_shape.element_type()},
2298                                         /*inputs=*/1));
2299   return InferReduceWindowShape(operand_shape, init_value_shape, window);
2300 }
2301 
2302 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2303     absl::Span<const Shape* const> operands,
2304     absl::Span<const Shape* const> init_values, const Window& window,
2305     const ProgramShape& to_apply_shape) {
2306   auto number_of_input = operands.size();
2307   // Check that all of the reduced tensors have the same dimensions. The element
2308   // types may be different.
2309   for (int64_t i = 1; i < number_of_input; ++i) {
2310     if (!ShapeUtil::SameDimensions(*operands[0], *operands[i])) {
2311       return InvalidArgument(
2312           "All reduced tensors must have the same dimensions. Tensor 0 has "
2313           "shape %s, Tensor %d has shape %s",
2314           ShapeUtil::HumanString(*operands[0]), i,
2315           ShapeUtil::HumanString(*operands[i]));
2316     }
2317   }
2318   std::vector<PrimitiveType> operand_element_type_vec;
2319   operand_element_type_vec.reserve(operands.size());
2320   for (const Shape* s : operands) {
2321     operand_element_type_vec.push_back(s->element_type());
2322   }
2323   TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_values,
2324                                         operand_element_type_vec,
2325                                         /*inputs=*/number_of_input));
2326   std::vector<Shape> output_shape_vec;
2327   const size_t n = operands.size();
2328   output_shape_vec.reserve(n);
2329   for (size_t i = 0; i < operands.size(); ++i) {
2330     TF_ASSIGN_OR_RETURN(
2331         auto cur_output_shape,
2332         InferReduceWindowShape(*operands[i], *init_values[i], window));
2333     output_shape_vec.push_back(cur_output_shape);
2334   }
2335   if (ShapeUtil::IsScalar(to_apply_shape.result())) {
2336     CHECK_EQ(output_shape_vec.size(), 1);
2337     return output_shape_vec[0];
2338   } else {
2339     return ShapeUtil::MakeTupleShape(output_shape_vec);
2340   }
2341 }
2342 
2343 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2344     const Shape& operand_shape, const Shape& init_value_shape,
2345     const Window& window) {
2346   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window"));
2347   return InferWindowOutputShape(operand_shape, window,
2348                                 init_value_shape.element_type());
2349 }
2350 
2351 /* static */ StatusOr<Shape> ShapeInference::InferSelectAndScatterShape(
2352     const Shape& operand_shape, const ProgramShape& select_shape,
2353     const Window& window, const Shape& source_shape,
2354     const Shape& init_value_shape, const ProgramShape& scatter_shape) {
2355   TF_RETURN_IF_ERROR(
2356       ExpectArray(operand_shape, "operand of select-and-scatter"));
2357 
2358   // Check if the select function has a proper shape of (T,T) -> PRED.
2359   if (select_shape.parameters_size() != 2) {
2360     return InvalidArgument(
2361         "Select function must take 2 parameters, but "
2362         "takes %d parameter(s).",
2363         select_shape.parameters_size());
2364   }
2365   const Shape& select_result_shape = select_shape.result();
2366   if (!ShapeUtil::Compatible(select_result_shape,
2367                              ShapeUtil::MakeShape(PRED, {}))) {
2368     return InvalidArgument("Select function must have rank-0 PRED result.");
2369   }
2370   const Shape& operand_element_shape =
2371       ShapeUtil::MakeShape(operand_shape.element_type(), {});
2372   if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2373                                                 select_shape.parameters(0))) {
2374     return InvalidArgument(
2375         "Select function's first parameter shape currently must "
2376         "match the operand element shape, but got %s vs %s.",
2377         ShapeUtil::HumanString(select_shape.parameters(0)),
2378         ShapeUtil::HumanString(operand_element_shape));
2379   }
2380   if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2381                                                 select_shape.parameters(1))) {
2382     return InvalidArgument(
2383         "Select function's second parameter shape currently must "
2384         "match the operand element shape, but got %s vs %s.",
2385         ShapeUtil::HumanString(select_shape.parameters(1)),
2386         ShapeUtil::HumanString(operand_element_shape));
2387   }
2388 
2389   // Check if the scatter function has a proper shape as a reduction.
2390   TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape},
2391                                         {source_shape.element_type()},
2392                                         /*inputs=*/1));
2393 
2394   // Check if the result shape of window operation matches the source shape.
2395   TF_ASSIGN_OR_RETURN(const Shape& window_result_shape,
2396                       InferWindowOutputShape(operand_shape, window,
2397                                              operand_shape.element_type()));
2398   if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape,
2399                                                 window_result_shape)) {
2400     return InvalidArgument(
2401         "Source shape does not match the shape of window-reduced operand: "
2402         "source(%s), window-reduced operand(%s).",
2403         ShapeUtil::HumanString(source_shape),
2404         ShapeUtil::HumanString(window_result_shape));
2405   }
2406 
2407   return operand_shape;
2408 }
2409 
2410 /* static */ StatusOr<Shape> ShapeInference::InferGetDimensionSizeShape(
2411     const Shape& shape, int64_t dimension) {
2412   if (dimension < 0 || dimension >= shape.rank()) {
2413     return InvalidArgument("GetDimensionSize dimension out of bounds: %d.",
2414                            dimension);
2415   }
2416 
2417   // TODO(b/119580730): Remove this restriction when very large dimension size
2418   // is needed.
2419   if (shape.dimensions(dimension) > std::numeric_limits<int32_t>::max()) {
2420     return InvalidArgument(
2421         "GetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2422         "INT_MAX limit.",
2423         ShapeUtil::HumanString(shape), dimension);
2424   }
2425 
2426   return ShapeUtil::MakeShape(S32, {});
2427 }
2428 
2429 /* static */ StatusOr<Shape> ShapeInference::InferSetDimensionSizeShape(
2430     const Shape& shape, const Shape& val_shape, int64_t dimension) {
2431   if (dimension < 0 || dimension >= shape.rank()) {
2432     return InvalidArgument("SetDimensionSize dimension out of bounds: %d.",
2433                            dimension);
2434   }
2435 
2436   if (val_shape.rank() != 0 || val_shape.element_type() != S32) {
2437     return InvalidArgument(
2438         "SetDimensionSize's value has to be S32 scalar, got %s",
2439         val_shape.ToString());
2440   }
2441   // TODO(b/119580730): Remove this restriction when very large dimension size
2442   // is needed.
2443   if (shape.dimensions(dimension) > std::numeric_limits<int32_t>::max()) {
2444     return InvalidArgument(
2445         "SetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2446         "INT_MAX limit.",
2447         ShapeUtil::HumanString(shape), dimension);
2448   }
2449 
2450   Shape result = shape;
2451   result.set_dynamic_dimension(dimension, true);
2452   return result;
2453 }
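
// Illustrative sketch (hypothetical shapes): GetDimensionSize always infers a
// scalar s32[], while SetDimensionSize returns the operand shape with the
// chosen dimension marked dynamic.
//
//   Shape operand = ShapeUtil::MakeShape(F32, {4, 8});
//   Shape val = ShapeUtil::MakeShape(S32, {});
//   // InferGetDimensionSizeShape(operand, 1)      -> s32[]
//   // InferSetDimensionSizeShape(operand, val, 1) -> f32[4,<=8]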
2454 
2455 /* static */ StatusOr<Window> ShapeInference::InferWindowFromDimensions(
2456     absl::Span<const int64_t> window_dimensions,
2457     absl::Span<const int64_t> window_strides,
2458     absl::Span<const std::pair<int64_t, int64_t>> padding,
2459     absl::Span<const int64_t> lhs_dilation,
2460     absl::Span<const int64_t> rhs_dilation,
2461     std::optional<std::vector<bool>> window_reversal) {
2462   const auto verify_size = [&](const size_t x, const char* x_name) {
2463     if (x == 0 || x == window_dimensions.size()) {
2464       return OkStatus();
2465     } else {
2466       return InvalidArgument(
2467           "%s", absl::StrCat(
2468                     "Window has different number of window dimensions than of ",
2469                     x_name,
2470                     "\nNumber of window dimensions: ", window_dimensions.size(),
2471                     "\nNumber of ", x_name, ": ", x, "\n"));
2472     }
2473   };
2474   TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides"));
2475   TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries"));
2476   TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors"));
2477   TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors"));
2478   if (window_reversal.has_value()) {
2479     TF_RETURN_IF_ERROR(verify_size(window_reversal->size(), "window reversal"));
2480   }
2481 
2482   Window window;
2483   for (size_t i = 0; i < window_dimensions.size(); i++) {
2484     auto dim = window.add_dimensions();
2485     dim->set_size(window_dimensions[i]);
2486     if (!window_strides.empty()) {
2487       dim->set_stride(window_strides[i]);
2488     } else {
2489       dim->set_stride(1);
2490     }
2491     if (!padding.empty()) {
2492       dim->set_padding_low(padding[i].first);
2493       dim->set_padding_high(padding[i].second);
2494     } else {
2495       dim->set_padding_low(0);
2496       dim->set_padding_high(0);
2497     }
2498     if (!lhs_dilation.empty()) {
2499       dim->set_base_dilation(lhs_dilation[i]);
2500     } else {
2501       dim->set_base_dilation(1);
2502     }
2503     if (!rhs_dilation.empty()) {
2504       dim->set_window_dilation(rhs_dilation[i]);
2505     } else {
2506       dim->set_window_dilation(1);
2507     }
2508     if (window_reversal.has_value() && !window_reversal->empty()) {
2509       dim->set_window_reversal(window_reversal->at(i));
2510     } else {
2511       dim->set_window_reversal(false);
2512     }
2513   }
2514   return window;
2515 }
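
// Illustrative sketch (hypothetical arguments): empty spans fall back to the
// defaults applied above (stride 1, zero padding, dilation 1, no reversal).
//
//   TF_ASSIGN_OR_RETURN(
//       Window w, ShapeInference::InferWindowFromDimensions(
//                     /*window_dimensions=*/{3, 3}, /*window_strides=*/{},
//                     /*padding=*/{}, /*lhs_dilation=*/{},
//                     /*rhs_dilation=*/{}, /*window_reversal=*/std::nullopt));
//   // Both dimensions get size 3, stride 1, padding 0/0, base and window
//   // dilation 1, and window_reversal false.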
2516 
2517 /* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
2518     const Shape& arg, absl::Span<const int64_t> starts,
2519     absl::Span<const int64_t> limits, absl::Span<const int64_t> strides) {
2520   auto error = [&](const std::string& message) {
2521     return InvalidArgument(
2522         "%s in slice operation; argument shape: %s; starts: {%s}; limits: "
2523         "{%s}; strides: {%s}.",
2524         message, ShapeUtil::HumanString(arg), StrJoin(starts, ","),
2525         StrJoin(limits, ","), StrJoin(strides, ","));
2526   };
2527   TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice"));
2528   VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}",
2529                        ShapeUtil::HumanString(arg), StrJoin(starts, ", "),
2530                        StrJoin(limits, ", "));
2531 
2532   if (starts.size() != limits.size()) {
2533     return error(StrFormat("slice start and limit sizes differ: %u vs %u",
2534                            starts.size(), limits.size()));
2535   }
2536 
2537   if (starts.size() != strides.size()) {
2538     return error(StrFormat("slice start and strides sizes differ: %u vs %u",
2539                            starts.size(), strides.size()));
2540   }
2541 
2542   if (starts.size() != arg.rank()) {
2543     return InvalidArgument(
2544         "Slice index count does not match argument rank: %u vs %d.",
2545         starts.size(), arg.rank());
2546   }
2547 
2548   std::vector<int64_t> sizes;
2549   const auto starts_size = starts.size();
2550   sizes.reserve(starts_size);
2551   for (int64_t dimension = 0; dimension < starts_size; ++dimension) {
2552     int64_t start_index = starts[dimension];
2553     int64_t limit_index = limits[dimension];
2554     int64_t stride = strides[dimension];
2555     if (start_index < 0) {
2556       return InvalidArgument("Negative start index to slice: %d.", start_index);
2557     }
2558     if (limit_index > arg.dimensions(dimension)) {
2559       return error(
2560           StrFormat("limit index (%d) must be less than or equal to dimension "
2561                     "size (%d)",
2562                     limit_index, arg.dimensions(dimension)));
2563     }
2564     VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index);
2565     VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index);
2566     if (start_index > limit_index) {
2567       return error(
2568           StrFormat("limit index (%d) must be greater than or equal to "
2569                     "start index (%d) in slice with positive stride",
2570                     limit_index, start_index));
2571     }
2572     if (stride <= 0) {
2573       return InvalidArgument("Stride (%d) must be positive.", stride);
2574     }
2575     sizes.push_back((limit_index - start_index + stride - 1) / stride);
2576   }
2577 
2578   std::vector<bool> is_dynamic(arg.rank());
2579   for (int64_t i = 0; i < arg.dimensions_size(); ++i) {
2580     // Slicing 1 out of a dynamic dimension eliminates the dynamic dimension.
2581     if (sizes[i] == 1) {
2582       continue;
2583     }
2584     is_dynamic[i] = arg.is_dynamic_dimension(i);
2585   }
2586 
2587   return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic);
2588 }
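
// Illustrative sketch (hypothetical arguments): each output dimension is
// ceil((limit - start) / stride), computed above as
// (limit - start + stride - 1) / stride.
//
//   Shape arg = ShapeUtil::MakeShape(F32, {10});
//   StatusOr<Shape> out = ShapeInference::InferSliceShape(
//       arg, /*starts=*/{2}, /*limits=*/{9}, /*strides=*/{3});
//   // (9 - 2 + 3 - 1) / 3 == 3, so *out == f32[3]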
2589 
2590 /* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
2591     const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
2592     absl::Span<const int64_t> slice_sizes, bool allow_scalar_indices) {
2593   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
2594   auto number_of_indices = start_index_shapes.size();
2595   // TODO(b/118437727): Remove this path.
2596   if (!allow_scalar_indices ||
2597       (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
2598     if (number_of_indices != 1) {
2599       return InvalidArgument(
2600           "Dynamic slice should have exactly 1 index operand, has %d.",
2601           number_of_indices);
2602     }
2603 
2604     const Shape& start_indices_shape = start_index_shapes[0];
2605     VLOG(2) << StrFormat(
2606         "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
2607         ShapeUtil::HumanString(operand_shape),
2608         ShapeUtil::HumanString(start_indices_shape),
2609         StrJoin(slice_sizes, ", "));
2610 
2611     TF_RETURN_IF_ERROR(
2612         ExpectArray(start_indices_shape, "start indices of dynamic slice"));
2613 
2614     if (start_indices_shape.rank() != 1) {
2615       return InvalidArgument(
2616           "Dynamic slice start indices of rank %d must be rank 1.",
2617           start_indices_shape.rank());
2618     }
2619 
2620     if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2621       return InvalidArgument(
2622           "Dynamic slice start indices must be of integral type.");
2623     }
2624 
2625     const int64_t start_num_dims = start_indices_shape.dimensions(0);
2626     if (operand_shape.rank() != start_num_dims) {
2627       return InvalidArgument(
2628           "Dynamic slice start number of dimensions %d (%s) must match rank "
2629           "%d of slice input (%s).",
2630           start_num_dims, ShapeUtil::HumanString(start_indices_shape),
2631           operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
2632     }
2633   } else {
2634     VLOG(2) << StrFormat("slicing shape %s with slice_sizes={%s}",
2635                          ShapeUtil::HumanString(operand_shape),
2636                          StrJoin(slice_sizes, ", "));
2637 
2638     if (operand_shape.rank() != number_of_indices) {
2639       return InvalidArgument(
2640           "Dynamic slice start number of dimensions %d must match rank "
2641           "%d of slice input (%s).",
2642           number_of_indices, operand_shape.rank(),
2643           ShapeUtil::HumanString(operand_shape));
2644     }
2645 
2646     if (number_of_indices > 0) {
2647       const Shape& first_index_shape = start_index_shapes[0];
2648       if (!ShapeUtil::IsScalar(first_index_shape)) {
2649         return InvalidArgument("Dynamic slice indices must be scalar, not %s.",
2650                                ShapeUtil::HumanString(first_index_shape));
2651       }
2652       if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
2653         return InvalidArgument(
2654             "Dynamic slice start indices must be of integral type.");
2655       }
2656       for (const Shape& index_shape : start_index_shapes) {
2657         if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
2658           return InvalidArgument(
2659               "Dynamic slice start indices must all have the same shape, got "
2660               "mismatching indices with shapes %s and %s.",
2661               ShapeUtil::HumanString(first_index_shape),
2662               ShapeUtil::HumanString(index_shape));
2663         }
2664       }
2665     }
2666   }
2667 
2668   if (slice_sizes.size() != operand_shape.rank()) {
2669     return InvalidArgument(
2670         "Dynamic slice index count does not match argument rank: %u vs %d.",
2671         slice_sizes.size(), operand_shape.rank());
2672   }
2673 
2674   for (int64_t dim = 0; dim < slice_sizes.size(); ++dim) {
2675     const int64_t input_dim_size = operand_shape.dimensions(dim);
2676     const int64_t slice_dim_size = slice_sizes[dim];
2677     if (slice_dim_size < 0) {
2678       return InvalidArgument("Negative size index to dynamic slice: %d.",
2679                              slice_dim_size);
2680     }
2681     if (slice_dim_size > input_dim_size) {
2682       return InvalidArgument(
2683           "Slice dim size %d greater than dynamic slice dimension: %d.",
2684           slice_dim_size, input_dim_size);
2685     }
2686     VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size);
2687   }
2688 
2689   return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
2690 }
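
// Illustrative sketch (hypothetical arguments, scalar-index form): one scalar
// integral start index per operand dimension; the result shape is simply the
// requested slice_sizes with the operand element type.
//
//   Shape operand = ShapeUtil::MakeShape(F32, {10, 10});
//   Shape idx = ShapeUtil::MakeShape(S32, {});
//   StatusOr<Shape> out = ShapeInference::InferDynamicSliceShape(
//       operand, /*start_index_shapes=*/{idx, idx}, /*slice_sizes=*/{3, 4},
//       /*allow_scalar_indices=*/true);
//   // *out == f32[3,4]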
2691 
2692 /* static */ StatusOr<Shape> ShapeInference::InferDynamicUpdateSliceShape(
2693     const Shape& operand_shape, const Shape& update_shape,
2694     absl::Span<const Shape> start_index_shapes, bool allow_scalar_indices) {
2695   TF_RETURN_IF_ERROR(
2696       ExpectArray(operand_shape, "operand of dynamic update slice"));
2697   TF_RETURN_IF_ERROR(
2698       ExpectArray(update_shape, "update of dynamic update slice"));
2699 
2700   auto number_of_indices = start_index_shapes.size();
2701   // TODO(b/118437727): Remove this path.
2702   if (!allow_scalar_indices ||
2703       (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
2704     if (number_of_indices != 1) {
2705       return InvalidArgument(
2706           "Dynamic update slice should have exactly 1 index operand, has %d.",
2707           number_of_indices);
2708     }
2709     const Shape& start_indices_shape = start_index_shapes[0];
2710     TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape,
2711                                    "start indices of dynamic update slice"));
2712 
2713     VLOG(2) << StrFormat(
2714         "updating slice of shape %s at dynamic start_indices %s with update "
2715         "shape %s",
2716         ShapeUtil::HumanString(operand_shape),
2717         ShapeUtil::HumanString(start_indices_shape),
2718         ShapeUtil::HumanString(update_shape));
2719 
2720     if (start_indices_shape.rank() != 1) {
2721       return InvalidArgument(
2722           "Dynamic update slice start indices of rank %d must be rank 1.",
2723           start_indices_shape.rank());
2724     }
2725 
2726     if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2727       return InvalidArgument(
2728           "Dynamic update slice start indices must be of integral type.");
2729     }
2730 
2731     const int64_t start_num_dims = start_indices_shape.dimensions(0);
2732     if (operand_shape.rank() != start_num_dims) {
2733       return InvalidArgument(
2734           "Dynamic update slice start number of dimensions %d (%s) must match "
2735           "rank %d of slice input (%s).",
2736           start_num_dims, ShapeUtil::HumanString(start_indices_shape),
2737           operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
2738     }
2739   } else {
2740     VLOG(2) << StrFormat("updating slice of shape %s with update shape %s",
2741                          ShapeUtil::HumanString(operand_shape),
2742                          ShapeUtil::HumanString(update_shape));
2743 
2744     if (operand_shape.rank() != number_of_indices) {
2745       return InvalidArgument(
2746           "Dynamic update slice start number of dimensions %d must match "
2747           "rank %d of slice input (%s).",
2748           number_of_indices, operand_shape.rank(),
2749           ShapeUtil::HumanString(operand_shape));
2750     }
2751 
2752     if (number_of_indices > 0) {
2753       const Shape& first_index_shape = start_index_shapes[0];
2754       if (!ShapeUtil::IsScalar(first_index_shape)) {
2755         return InvalidArgument(
2756             "Dynamic update slice indices must be scalar, not %s.",
2757             ShapeUtil::HumanString(first_index_shape));
2758       }
2759       if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
2760         return InvalidArgument(
2761             "Dynamic update slice start indices must be of integral type.");
2762       }
2763       for (const Shape& index_shape : start_index_shapes) {
2764         if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
2765           return InvalidArgument(
2766               "Dynamic update slice start indices must all have the same "
2767               "shape, got mismatching indices with shapes %s and %s.",
2768               ShapeUtil::HumanString(first_index_shape),
2769               ShapeUtil::HumanString(index_shape));
2770         }
2771       }
2772     }
2773   }
2774 
2775   if (update_shape.rank() != operand_shape.rank()) {
2776     return InvalidArgument(
2777         "Dynamic update slice update rank does not match argument rank: "
2778         "%d vs %d.",
2779         update_shape.rank(), operand_shape.rank());
2780   }
2781 
2782   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
2783                                                      update_shape)) {
2784     return InvalidArgument(
2785         "Dynamic update slice update element type does not match argument. "
2786         "operand.element_type: %s vs update.element_type: %s.",
2787         PrimitiveType_Name(operand_shape.element_type()),
2788         PrimitiveType_Name(update_shape.element_type()));
2789   }
2790 
2791   for (int64_t dim = 0; dim < operand_shape.rank(); ++dim) {
2792     const int64_t input_dim_size = operand_shape.dimensions(dim);
2793     const int64_t update_dim_size = update_shape.dimensions(dim);
2794     if (update_dim_size < 0) {
2795       return InvalidArgument(
2796           "Size index %d to dynamic update slice must be >= 0.",
2797           update_dim_size);
2798     }
2799     if (update_dim_size > input_dim_size) {
2800       return InvalidArgument(
2801           "Update dim size %d greater than dynamic slice dimension: %d.",
2802           update_dim_size, input_dim_size);
2803     }
2804     VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size);
2805   }
2806 
2807   auto result_shape = operand_shape;
2808 
2809   // If any dimension of the operand shape is dynamic, the corresponding
2810   // result dimension is also dynamic.
2811   // If update shape is dynamic, only propagate dynamic dimension to result if
2812   // the update is a full update (update_shape[i] == operand_shape[i]).
2813   for (int64_t i = 0; i < update_shape.rank(); ++i) {
2814     if (operand_shape.is_dynamic_dimension(i)) {
2815       result_shape.set_dynamic_dimension(i, true);
2816     }
2817 
2818     if (update_shape.is_dynamic_dimension(i) &&
2819         update_shape.dimensions(i) == operand_shape.dimensions(i)) {
2820       // When update/replace a full dimension, propagate dynamic dimension to
2821       // the result.
2822       result_shape.set_dynamic_dimension(i, true);
2823     }
2824   }
2825 
2826   return result_shape;
2827 }
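
// Illustrative sketch (hypothetical shapes): the result keeps the operand's
// shape; a dynamic dimension in the update only propagates when the update
// covers that dimension completely.
//
//   // operand f32[10,<=20], update f32[10,<=20], scalar S32 indices
//   //   -> result f32[10,<=20]  (full-dimension update, dynamism kept)
//   // operand f32[10,20],  update f32[3,<=4],  scalar S32 indices
//   //   -> result f32[10,20]    (partial update, dynamism not propagated)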
2828 
2829 /*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
2830     const Shape& operand_shape, absl::Span<const int64_t> dimensions) {
2831   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
2832   if (!AllUnique(dimensions)) {
2833     return InvalidArgument("A dimension number is duplicated in reverse.");
2834   }
2835   for (int64_t dimension : dimensions) {
2836     if (dimension >= operand_shape.rank() || dimension < 0) {
2837       return InvalidArgument(
2838           "One of the reverse dimensions (%d) is out-of-bounds in shape %s.",
2839           dimension, ShapeUtil::HumanString(operand_shape));
2840     }
2841   }
2842   return operand_shape;
2843 }
2844 
2845 /* static */ StatusOr<Shape> ShapeInference::InferGetTupleElementShape(
2846     const Shape& arg, int64_t index) {
2847   if (!arg.IsTuple()) {
2848     return InvalidArgument(
2849         "Cannot infer shape: attempting to index into non-tuple: %s.",
2850         ShapeUtil::HumanString(arg));
2851   }
2852 
2853   if (index < 0 || index >= arg.tuple_shapes_size()) {
2854     return InvalidArgument(
2855         "Cannot infer shape: attempt to index out of tuple bounds: %d "
2856         ">= %d in shape %s.",
2857         index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg));
2858   }
2859 
2860   return arg.tuple_shapes(index);
2861 }
2862 
2863 /* static */ StatusOr<Shape> ShapeInference::InferWhileShape(
2864     const ProgramShape& condition, const ProgramShape& body,
2865     const Shape& init) {
2866   // Check the number of parameters for given computations.
2867   if (condition.parameters_size() != 1) {
2868     return InvalidArgument("Condition must take 1 argument; got %d.",
2869                            condition.parameters_size());
2870   }
2871   if (body.parameters_size() != 1) {
2872     return InvalidArgument("Body must take 1 argument; got %d.",
2873                            body.parameters_size());
2874   }
2875 
2876   auto shape_string = [&]() {
2877     return StrFormat(
2878         "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition),
2879         ShapeUtil::HumanString(body), ShapeUtil::HumanString(init));
2880   };
2881 
2882   // Check the shapes of computation parameters and return types.
2883   if (!ShapeUtil::Compatible(condition.result(),
2884                              ShapeUtil::MakeShape(PRED, {}))) {
2885     return InvalidArgument("Condition must return a boolean; got %s.",
2886                            shape_string());
2887   }
2888   if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) ||
2889       !ShapeUtil::Compatible(body.result(), body.parameters(0)) ||
2890       !ShapeUtil::Compatible(body.result(), init)) {
2891     return InvalidArgument(
2892         "The parameter of condition and body, the result of the body, and init "
2893         "must all have the same shape; got %s.",
2894         shape_string());
2895   }
2896 
2897   return init;
2898 }
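
// Illustrative sketch (hypothetical program shapes): a well-formed while loop
// threads a single state shape T through condition (T) -> pred[] and body
// (T) -> T, and the inferred result is the init shape.
//
//   // condition: (s32[]) -> pred[]
//   // body:      (s32[]) -> s32[]
//   // init:      s32[]
//   // InferWhileShape(condition, body, init) -> s32[]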
2899 
2900 /* static */ StatusOr<Shape> ShapeInference::InferConditionalShape(
2901     const Shape& branch_index,
2902     absl::Span<const ProgramShape> branch_computations,
2903     absl::Span<const Shape> branch_operands) {
2904   if (!ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(PRED, {})) &&
2905       !ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(S32, {}))) {
2906     return InvalidArgument("branch_index must be bool or int32_t; got %s.",
2907                            ShapeUtil::HumanString(branch_index));
2908   }
2909   if (branch_index.element_type() == PRED) {
2910     TF_RET_CHECK(2 == branch_computations.size());
2911   } else {
2912     TF_RET_CHECK(!branch_computations.empty());
2913   }
2914   TF_RET_CHECK(branch_computations.size() == branch_operands.size());
2915   Shape result = branch_computations[0].result();
2916   for (int j = 0; j < branch_computations.size(); ++j) {
2917     if (branch_computations[j].parameters_size() != 1) {
2918       return InvalidArgument(
2919           "branch computation %d must take 1 argument; got %d.", j,
2920           branch_computations[j].parameters_size());
2921     }
2922     if (!ShapeUtil::Compatible(branch_computations[j].parameters(0),
2923                                branch_operands[j])) {
2924       auto shape_string = [&]() {
2925         return StrFormat("operand: %s; computation: %s",
2926                          ShapeUtil::HumanString(branch_operands[j]),
2927                          ShapeUtil::HumanString(branch_computations[j]));
2928       };
2929       return InvalidArgument(
2930           "branch operand %d must match the shape of the only parameter of "
2931           "branch computation %d: got %s.",
2932           j, j, shape_string());
2933     }
2934 
2935     if (!ShapeUtil::Compatible(branch_computations[0].result(),
2936                                branch_computations[j].result())) {
2937       auto shape_string = [&]() {
2938         return StrFormat(
2939             "branch 0 computation result: %s; branch %d computation result: %s",
2940             ShapeUtil::HumanString(branch_computations[0].result()), j,
2941             ShapeUtil::HumanString(branch_computations[j].result()));
2942       };
2943       return InvalidArgument(
2944           "the result of branch 0 computation and branch %d computation must "
2945           "have the same shape: got %s.",
2946           j, shape_string());
2947     }
2948   }
2949   // For each subshape, if any branch's result is dynamic, the result is
2950   // dynamic:
2951   //
2952   //   true_branch  (s32[<=4])
2953   //   false_branch (s32[4])
2954   //
2955   // Result is s32[<=4].
2956   ShapeUtil::ForEachMutableSubshape(
2957       &result, [&](Shape* subshape, const ShapeIndex& index) {
2958         if (!subshape->IsArray()) {
2959           return;
2960         }
2961         for (int j = 0; j < branch_computations.size(); ++j) {
2962           auto branch_subshape =
2963               ShapeUtil::GetSubshape(branch_computations[j].result(), index);
2964           for (int64_t i = 0; i < branch_subshape.rank(); ++i) {
2965             if (branch_subshape.is_dynamic_dimension(i)) {
2966               subshape->set_dynamic_dimension(i, true);
2967             }
2968           }
2969         }
2970       });
2971 
2972   return result;
2973 }
2974 
2975 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2976     const Shape& operand, absl::Span<const int64_t> broadcast_sizes) {
2977   TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
2978   for (int64_t size : broadcast_sizes) {
2979     if (size < 0) {
2980       return InvalidArgument("Broadcast with negative dimension size %d.",
2981                              size);
2982     }
2983   }
2984 
2985   std::vector<int64_t> dimensions(operand.dimensions_size() +
2986                                   broadcast_sizes.size());
2987   std::copy(broadcast_sizes.begin(), broadcast_sizes.end(), dimensions.begin());
2988   std::copy(operand.dimensions().begin(), operand.dimensions().end(),
2989             dimensions.begin() + broadcast_sizes.size());
2990 
2991   Shape result = ShapeUtil::MakeShape(operand.element_type(), dimensions);
2992   for (int64_t i = 0; i < operand.dimensions_size(); ++i) {
2993     result.set_dynamic_dimension(broadcast_sizes.size() + i,
2994                                  operand.is_dynamic_dimension(i));
2995   }
2996   return result;
2997 }
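
// Illustrative sketch (hypothetical shapes): the broadcast sizes become the
// new major dimensions and the operand dimensions follow, so the operand's
// dynamism moves to the trailing dimensions.
//
//   // operand f32[2,<=3], broadcast_sizes {4, 5}
//   //   -> result f32[4,5,2,<=3]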
2998 
2999 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
3000     const Shape& operand_shape, const Shape& output_shape,
3001     absl::Span<const int64_t> broadcast_dimensions) {
3002   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast"));
3003   TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast"));
3004   const int64_t operand_rank = operand_shape.rank();
3005   const int64_t output_rank = output_shape.rank();
3006   if (operand_rank > output_rank) {
3007     return InvalidArgument(
3008         "InDim style broadcast must be to an equal or higher ranked shape; "
3009         "operand rank: %lld; output rank: %lld",
3010         operand_rank, output_rank);
3011   }
3012   if (operand_rank != broadcast_dimensions.size()) {
3013     return InvalidArgument(
3014         "Size of broadcast_dimensions has to match operand's rank; operand "
3015         "rank: %lld, size of broadcast_dimensions %u.",
3016         operand_rank, broadcast_dimensions.size());
3017   }
3018   for (int64_t i = 0; i < operand_rank; i++) {
3019     if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] >= output_rank) {
3020       return InvalidArgument("Broadcast dimension %lld is out of bounds",
3021                              broadcast_dimensions[i]);
3022     }
3023     if (operand_shape.dimensions(i) !=
3024             output_shape.dimensions(broadcast_dimensions[i]) &&
3025         operand_shape.dimensions(i) != 1) {
3026       return InvalidArgument(
3027           "Input dimension should be either 1 or equal to the output dimension "
3028           "it is broadcasting into; the %lldth operand dimension is %lld, the "
3029           "%lldth output dimension is %lld.",
3030           i, operand_shape.dimensions(i), broadcast_dimensions[i],
3031           output_shape.dimensions(broadcast_dimensions[i]));
3032     }
3033     if (operand_shape.is_dynamic_dimension(i) !=
3034         output_shape.is_dynamic_dimension(broadcast_dimensions[i])) {
3035       return InvalidArgument(
3036           "Broadcast input and output dynamism mismatch: %s and %s",
3037           operand_shape.ToString(), output_shape.ToString());
3038     }
3039     // Make sure the broadcast dimensions are listed in a strictly increasing
3040     // order.
3041     if (i > 0 && broadcast_dimensions[i - 1] >= broadcast_dimensions[i]) {
3042       return InvalidArgument(
3043           "Broadcast dimensions order is wrong: %d comes after %d.",
3044           broadcast_dimensions[i], broadcast_dimensions.at(i - 1));
3045     }
3046   }
3047 
3048   return output_shape;
3049 }
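
// Illustrative sketch (hypothetical shapes): the InDim form maps each operand
// dimension onto an output dimension via broadcast_dimensions and returns the
// output shape once those pairs are validated.
//
//   // operand f32[3], output f32[2,3,4], broadcast_dimensions {1}
//   //   -> result f32[2,3,4]  (operand dim 0 lines up with output dim 1)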
3050 
3051 /* static */ StatusOr<Shape> ShapeInference::InferDynamicReshapeShape(
3052     const Shape& operand, absl::Span<const Shape* const> dim_size_shapes,
3053     absl::Span<const int64_t> new_size_bounds,
3054     const std::vector<bool>& dims_are_dynamic) {
3055   if (new_size_bounds.size() != dims_are_dynamic.size()) {
3056     return InvalidArgument(
3057         "DynamicReshape has to have the same number of elements in new_sizes "
3058         "(%d) and dims_are_dynamic (%d)",
3059         new_size_bounds.size(), dims_are_dynamic.size());
3060   }
3061 
3062   for (const Shape* dim_size_shape : dim_size_shapes) {
3063     if (dim_size_shape->element_type() != S32 || dim_size_shape->rank() != 0) {
3064       return InvalidArgument(
3065           "DynamicReshape's dim size has to be a scalar S32, got (%s).",
3066           dim_size_shape->ToString());
3067     }
3068   }
3069 
3070   Shape inferred_shape = ShapeUtil::MakeShape(
3071       operand.element_type(), new_size_bounds, dims_are_dynamic);
3072   if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
3073     return InvalidArgument(
3074         "Reshape operation has mismatched element counts: from=%d (%s) "
3075         "to=%d (%s).",
3076         ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
3077         ShapeUtil::ElementsIn(inferred_shape),
3078         ShapeUtil::HumanString(inferred_shape));
3079   }
3080   return inferred_shape;
3081 }
3082 
3083 /* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
3084     const Shape& operand, absl::Span<const int64_t> dimensions,
3085     absl::Span<const int64_t> new_sizes, int64_t inferred_dimension) {
3086   TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));
3087 
3088   Shape inferred_shape =
3089       ShapeUtil::MakeShape(operand.element_type(), new_sizes);
3090   VLOG(3) << "Reshape inferred shape: "
3091           << ShapeUtil::HumanString(inferred_shape);
3092 
3093   if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
3094     return InvalidArgument(
3095         "Reshape operation has mismatched element counts: from=%d (%s) "
3096         "to=%d (%s).",
3097         ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
3098         ShapeUtil::ElementsIn(inferred_shape),
3099         ShapeUtil::HumanString(inferred_shape));
3100   }
3101 
3102   std::vector<int64_t> indices(operand.rank());
3103   std::iota(indices.begin(), indices.end(), 0);
3104   if (dimensions.size() != operand.rank() ||
3105       !std::is_permutation(dimensions.begin(), dimensions.end(),
3106                            indices.begin())) {
3107     return InvalidArgument(
3108         "Reshape dimensions [%s] are not a permutation of the operand "
3109         "dimensions (operand shape is %s).",
3110         StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
3111   }
3112 
3113   // Propagate dynamic dimension.
3114   auto common_factors = CommonFactors(operand.dimensions(), new_sizes);
3115   for (int64_t input_dim = 0; input_dim < operand.rank(); ++input_dim) {
3116     if (!operand.is_dynamic_dimension(input_dim)) {
3117       continue;
3118     }
3119 
3120     std::string reshape_debug_str = absl::StrFormat(
3121         "output: %s, input: %s, input_dim: "
3122         "%lld",
3123         ShapeUtil::HumanString(inferred_shape), ShapeUtil::HumanString(operand),
3124         input_dim);
3125 
3126     int64_t input_dim_start = -1;
3127     int64_t input_dim_end = -1;
3128     int64_t output_dim_start = -1;
3129     int64_t output_dim_end = -1;
3130     // Find common_factors that the input_dim belongs to.
3131     for (int64_t i = 0; i < common_factors.size() - 1; ++i) {
3132       auto start = common_factors[i];
3133       auto end = common_factors[i + 1];
3134       if (input_dim >= start.first && input_dim < end.first) {
3135         input_dim_start = start.first;
3136         input_dim_end = end.first;
3137         output_dim_start = start.second;
3138         output_dim_end = end.second;
3139         break;
3140       }
3141     }
3142     if ((input_dim_end - input_dim_start) > 1 &&
3143         (output_dim_end - output_dim_start) > 1) {
3144       // We don't support the case when a dynamic dimension is both combined
3145       // with and split into other dimensions:
3146       //
3147       //  [x, yz]
3148       //     | Reshape
3149       //  [xy, z]
3150       return Unimplemented(
3151           "Dynamic input dimension to reshape that is both split and "
3152           "combined is not supported: %s",
3153           reshape_debug_str);
3154     }
3155 
3156     for (auto common_factor : common_factors) {
3157       //
3158       // For reshapes like:
3159       //  [<=5]
3160       //    | Reshape
3161       //  [1, 5]
3162       //
3163       //  The input dynamic dimension can go into either output dimension.
3164       //  However, the return value of common factors only returns
3165       //  input: 5
3166       //  output: 5
3167       //
3168       //  We need to expand common factor to include degenerated output
3169       //  dimensions:
3170       //  input: 5
3171       //  output: 1, 5
3172       //
3173       //  such that in the logic later on we can consider both dimensions as
3174       //  candidate.
3175       if (common_factor.first == input_dim_start) {
3176         output_dim_start = std::min(output_dim_start, common_factor.second);
3177       }
3178       if (common_factor.first == input_dim_end) {
3179         output_dim_end = std::max(output_dim_end, common_factor.second);
3180       }
3181     }
3182 
3183     // Calculate output dynamic reshape dimension.
3184     int64_t output_dynamic_dimension = -1;
3185 
3186     if (operand.dimensions(input_dim) == 1 && !new_sizes.empty()) {
3187       // If dynamic dimension is size 1, it can only be most-major or
3188       // most-minor.
3189       if (input_dim == 0) {
3190         output_dynamic_dimension = 0;
3191       }
3192       if (input_dim == operand.rank() - 1) {
3193         output_dynamic_dimension = new_sizes.size() - 1;
3194       }
3195 
3196       if (output_dynamic_dimension == -1) {
3197         return Unimplemented(
3198             "Dynamic degenerated dimension that is neither most-minor nor "
3199             "most-major is not supported: %s",
3200             reshape_debug_str);
3201       }
3202     }
3203 
3204     if (output_dynamic_dimension == -1 &&
3205         output_dim_end - output_dim_start == 1) {
3206       // Only one possible output dimension.
3207       output_dynamic_dimension = output_dim_start;
3208     }
3209     if (output_dynamic_dimension == -1 &&
3210         output_dim_end - output_dim_start > 1) {
3211       // Multiple outputs can be dynamic; use "inferred_dimension" to tie-break.
3212       output_dynamic_dimension = inferred_dimension;
3213     }
3214 
3215     if (output_dynamic_dimension != -1) {
3216       // TODO(yunxing): Turn this into a CHECK.
3217       inferred_shape.set_dynamic_dimension(output_dynamic_dimension, true);
3218     } else {
3219       std::vector<int64_t> output_non_degenerated;
3220       output_non_degenerated.reserve(output_dim_end);
3221       for (int64_t i = output_dim_start; i < output_dim_end; ++i) {
3222         if (new_sizes[i] != 1) {
3223           output_non_degenerated.push_back(i);
3224         }
3225       }
3226       if (output_non_degenerated.size() == 1) {
3227         inferred_shape.set_dynamic_dimension(output_non_degenerated[0], true);
3228       }
3229     }
3230   }
3231 
3232   return inferred_shape;
3233 }
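
// Illustrative sketch (hypothetical shapes): dynamic-dimension propagation
// through a reshape that collapses a dynamic major dimension.
//
//   // operand f32[<=4,5] reshaped with dimensions {0, 1} to new_sizes {20}
//   //   -> both input dimensions map to the single output dimension,
//   //      so the result is f32[<=20]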
3234 
3235 /* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
3236     const Shape& operand, absl::Span<const int64_t> dimensions) {
3237   TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
3238 
3239   if (dimensions.size() != operand.rank() || !IsPermutation(dimensions)) {
3240     return InvalidArgument(
3241         "Transpose dimensions [%s] are not a permutation of the operand "
3242         "dimensions (operand shape is %s).",
3243         StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
3244   }
3245 
3246   // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
3247   // we need output[i]=input[dimensions[i]] which is
3248   // Permute(Inverse(dimensions),input).
3249   return ShapeUtil::PermuteDimensions(dimensions, operand);
3250 }
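
// Illustrative sketch (hypothetical shapes): output dimension i takes the
// size of operand dimension dimensions[i].
//
//   // operand f32[2,3,4], dimensions {2, 0, 1}
//   //   -> result f32[4,2,3]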
3251 
3252 /* static */ StatusOr<Shape> ShapeInference::InferClampShape(
3253     const Shape& min, const Shape& operand, const Shape& max) {
3254   TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min"));
3255   TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand"));
3256   TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max"));
3257 
3258   if (!ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
3259       !ShapeUtil::CompatibleIgnoringFpPrecision(max, operand)) {
3260     return InvalidArgument(
3261         "Clamp with different shapes: %s, %s, %s.", ShapeUtil::HumanString(min),
3262         ShapeUtil::HumanString(operand), ShapeUtil::HumanString(max));
3263   }
3264   return operand;
3265 }
3266 
3267 /* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
3268     const Shape& pred, const Shape& on_true, const Shape& on_false) {
3269   TF_RETURN_IF_ERROR(ExpectArray(pred, "select pred"));
3270   TF_RETURN_IF_ERROR(ExpectArray(on_true, "select on-true"));
3271   TF_RETURN_IF_ERROR(ExpectArray(on_false, "select on-false"));
3272 
3273   if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) {
3274     return InvalidArgument(
3275         "Operands to select must be the same shape; got %s and %s.",
3276         ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
3277   }
3278   if (pred.element_type() != PRED) {
3279     return InvalidArgument(
3280         "Select's pred operand must have PRED element type; got %s.",
3281         ShapeUtil::HumanString(pred));
3282   }
3283   if (!Shape::Equal()
3284            .IgnoreElementType()
3285            .IgnoreLayout()
3286            .IgnoreDynamicDimension()(pred, on_true)) {
3287     return InvalidArgument(
3288         "Operands to select and predicate must be the same shape; got %s and "
3289         "%s.",
3290         ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(pred));
3291   }
3292 
3293   return ShapeUtil::ChangeElementType(
3294       pred, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
3295 }
3296 
3297 /* static */ StatusOr<Shape> ShapeInference::InferCallShape(
3298     absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply) {
3299   // The applied function's arity equals the number of arguments.
3300   if (arg_shapes.size() != to_apply.parameters_size()) {
3301     std::string computation_signature = ShapeUtil::HumanString(to_apply);
3302     std::string argument_shapes =
3303         StrJoin(arg_shapes, ", ", [](std::string* out, const Shape* shape) {
3304           absl::StrAppend(out, ShapeUtil::HumanString(*shape));
3305         });
3306     return InvalidArgument(
3307         "Call applied function arity must match number of arguments; got: "
3308         "arity: %d, arguments: %u; computation signature: %s; argument "
3309         "shapes: [%s].",
3310         to_apply.parameters_size(), arg_shapes.size(), computation_signature,
3311         argument_shapes);
3312   }
3313 
3314   // All arguments must be compatible with the program shape.
3315   for (int i = 0; i < arg_shapes.size(); ++i) {
3316     const Shape& arg_shape = *arg_shapes[i];
3317     const Shape& param_shape = to_apply.parameters(i);
3318     if (!ShapeUtil::Compatible(arg_shape, param_shape)) {
3319       return InvalidArgument(
3320           "Call parameter must match argument; got parameter %d shape: %s, "
3321           "argument shape: %s.",
3322           i, ShapeUtil::HumanString(param_shape),
3323           ShapeUtil::HumanString(arg_shape));
3324     }
3325   }
3326 
3327   return to_apply.result();
3328 }
3329 
3330 static Status ValidateGatherDimensionNumbers(
3331     const Shape& input_shape, absl::Span<const int64_t> start_indices_shape,
3332     const GatherDimensionNumbers& dim_numbers) {
3333   if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
3334     return InvalidArgument(
3335         "Output window dimensions in gather op must be ascending; got: %s.",
3336         StrJoin(dim_numbers.offset_dims(), ", "));
3337   }
3338 
3339   if (absl::c_adjacent_find(dim_numbers.offset_dims()) !=
3340       dim_numbers.offset_dims().end()) {
3341     return InvalidArgument(
3342         "Output window dimensions in gather op must not repeat; got: %s.",
3343         StrJoin(dim_numbers.offset_dims(), ", "));
3344   }
3345 
3346   const int64_t output_offset_dim_count = dim_numbers.offset_dims_size();
3347   const int64_t output_shape_rank =
3348       output_offset_dim_count + start_indices_shape.size() - 1;
3349 
3350   for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) {
3351     int64_t offset_dim = dim_numbers.offset_dims(i);
3352     if (offset_dim < 0 || offset_dim >= output_shape_rank) {
3353       return InvalidArgument(
3354           "Offset dimension %d in gather op is out of bounds; got %d, but "
3355           "should "
3356           "have been in [0,%d).",
3357           i, offset_dim, output_shape_rank);
3358     }
3359   }
3360 
3361   if (dim_numbers.start_index_map_size() !=
3362       start_indices_shape[dim_numbers.index_vector_dim()]) {
3363     return InvalidArgument(
3364         "Gather op has %d elements in start_index_map and the "
3365         "bound of dimension index_vector_dim=%d of start_indices is "
3366         "%d. These two numbers must be equal.",
3367         dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(),
3368         start_indices_shape[dim_numbers.index_vector_dim()]);
3369   }
3370 
3371   for (int i = 0; i < dim_numbers.start_index_map_size(); i++) {
3372     int64_t operand_dim_for_start_index_i = dim_numbers.start_index_map(i);
3373     if (operand_dim_for_start_index_i < 0 ||
3374         operand_dim_for_start_index_i >= input_shape.dimensions_size()) {
3375       return InvalidArgument(
3376           "Invalid start_index_map; domain is [0, %d), got: %d->%d.",
3377           input_shape.dimensions_size(), i, operand_dim_for_start_index_i);
3378     }
3379   }
3380 
3381   std::vector<int64_t> sorted_start_index_map(
3382       dim_numbers.start_index_map().begin(),
3383       dim_numbers.start_index_map().end());
3384 
3385   absl::c_sort(sorted_start_index_map);
3386 
3387   if (absl::c_adjacent_find(sorted_start_index_map) !=
3388       sorted_start_index_map.end()) {
3389     return InvalidArgument(
3390         "Repeated dimensions are not allowed in start_index_map; "
3391         "got: %s.",
3392         StrJoin(dim_numbers.start_index_map(), ", "));
3393   }
3394 
3395   for (int64_t collapsed_dim : dim_numbers.collapsed_slice_dims()) {
3396     if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
3397       return InvalidArgument(
3398           "Invalid collapsed_slice_dims set in gather op; valid range is [0, "
3399           "%d), got: %d.",
3400           input_shape.dimensions_size(), collapsed_dim);
3401     }
3402   }
3403 
3404   if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) {
3405     return InvalidArgument(
3406         "collapsed_slice_dims in gather op must be sorted; got: %s",
3407         StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
3408   }
3409 
3410   if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
3411       dim_numbers.collapsed_slice_dims().end()) {
3412     return InvalidArgument(
3413         "Repeated dimensions not allowed in collapsed_slice_dims in gather op; "
3414         "got: %s.",
3415         StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
3416   }
3417 
3418   return OkStatus();
3419 }
3420 
3421 /*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
3422     const Shape& input_shape, const Shape& start_indices_shape,
3423     const GatherDimensionNumbers& gather_dim_numbers,
3424     absl::Span<const int64_t> slice_sizes) {
3425   TF_RETURN_IF_ERROR(
3426       ExpectArray(input_shape, "input tensor operand of gather op"));
3427   TF_RETURN_IF_ERROR(
3428       ExpectArray(start_indices_shape, "gather indices operand of gather op"));
3429 
3430   if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
3431     return InvalidArgument(
3432         "Gather indices parameter must be an integral tensor; got %s.",
3433         ShapeUtil::HumanString(start_indices_shape));
3434   }
3435 
3436   // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
3437   // index_vector_dim is rank(P).  The bounds of this expanded shape are
3438   // stored in expanded_start_indices_shape.
3439 
3440   if (start_indices_shape.dimensions_size() <
3441           gather_dim_numbers.index_vector_dim() ||
3442       gather_dim_numbers.index_vector_dim() < 0) {
3443     return InvalidArgument(
3444         "Gather index leaf dimension must be within [0, rank(start_indices) + "
3445         "1). rank(start_indices) is %d and gather index leaf dimension is "
3446         "%d.",
3447         start_indices_shape.dimensions_size(),
3448         gather_dim_numbers.index_vector_dim());
3449   }
3450 
3451   std::vector<int64_t> expanded_start_indices_shape;
3452   // Also tracks if an output dimension is dynamic.
3453   std::vector<bool> expanded_start_indices_shape_dynamic_dimensions;
3454   expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size());
3455   expanded_start_indices_shape_dynamic_dimensions.reserve(
3456       start_indices_shape.dimensions_size());
3457   absl::c_copy(start_indices_shape.dimensions(),
3458                std::back_inserter(expanded_start_indices_shape));
3459   absl::c_copy(
3460       start_indices_shape.dynamic_dimensions(),
3461       std::back_inserter(expanded_start_indices_shape_dynamic_dimensions));
3462   if (expanded_start_indices_shape.size() ==
3463       gather_dim_numbers.index_vector_dim()) {
3464     expanded_start_indices_shape.push_back(1);
3465     expanded_start_indices_shape_dynamic_dimensions.push_back(false);
3466   }
3467 
3468   TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
3469       input_shape, expanded_start_indices_shape, gather_dim_numbers));
3470 
3471   if (slice_sizes.size() != input_shape.dimensions_size()) {
3472     return InvalidArgument(
3473         "Gather op must have one slice size for every input dimension; got: "
3474         "len(slice_sizes)=%lu, input_shape.rank=%d.",
3475         slice_sizes.size(), input_shape.dimensions_size());
3476   }
3477 
3478   if (slice_sizes.size() !=
3479       gather_dim_numbers.offset_dims_size() +
3480           gather_dim_numbers.collapsed_slice_dims_size()) {
3481     return InvalidArgument(
3482         "All components of the offset index in a gather op must either be an "
3483         "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
3484         "output_slice_sizes=%s, collapsed_slice_dims=%s.",
3485         slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","),
3486         StrJoin(gather_dim_numbers.collapsed_slice_dims(), ","));
3487   }
3488 
3489   for (int i = 0; i < slice_sizes.size(); i++) {
3490     int64_t slice_size = slice_sizes[i];
3491     int64_t corresponding_input_size = input_shape.dimensions(i);
3492     if (slice_size < 0 || slice_size > corresponding_input_size) {
3493       return InvalidArgument(
3494           "Slice size at index %d in gather op is out of range, must be "
3495           "within [0, %d), got %d.",
3496           i, corresponding_input_size + 1, slice_size);
3497     }
3498   }
3499 
3500   for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) {
3501     if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] > 1) {
3502       return InvalidArgument(
3503           "Gather op can only collapse slice dims with bound 1 or 0, but bound "
3504           "is %d for index %d at position %d.",
3505           slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)],
3506           gather_dim_numbers.collapsed_slice_dims(i), i);
3507     }
3508   }
3509 
3510   int64_t result_rank = gather_dim_numbers.offset_dims_size() +
3511                         (expanded_start_indices_shape.size() - 1);
3512   int64_t offset_dims_seen = 0;
3513   int64_t gather_dims_seen = 0;
3514   std::vector<int64_t> output_dim_bounds;
3515   output_dim_bounds.reserve(result_rank);
3516 
3517   std::vector<bool> output_dim_is_dynamic;
3518   output_dim_is_dynamic.reserve(result_rank);
3519   for (int64_t i = 0; i < result_rank; i++) {
3520     int64_t current_bound;
3521     bool dim_dynamic = false;
3522     bool is_window_index =
3523         absl::c_binary_search(gather_dim_numbers.offset_dims(), i);
3524     if (is_window_index) {
3525       while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(),
3526                                    offset_dims_seen)) {
3527         offset_dims_seen++;
3528       }
3529       // Gathering an entire dynamic dimension creates a dynamic dimension.
3530       //
3531       // e.g.,:
3532       //
3533       // gather(input: [1,<=2,1], slice_sizes={1,2,1})
3534       //
3535       // creates
3536       //
3537       // [<=2, 1]
3538       if (slice_sizes[offset_dims_seen] ==
3539           input_shape.dimensions(offset_dims_seen)) {
3540         dim_dynamic = input_shape.is_dynamic_dimension(offset_dims_seen);
3541       }
3542       current_bound = slice_sizes[offset_dims_seen++];
3543     } else {
3544       if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) {
3545         gather_dims_seen++;
3546       }
3547       // Forward dynamic dimensions from indices.
3548       dim_dynamic =
3549           expanded_start_indices_shape_dynamic_dimensions[gather_dims_seen];
3550 
3551       current_bound = expanded_start_indices_shape[gather_dims_seen++];
3552     }
3553     output_dim_is_dynamic.push_back(dim_dynamic);
3554     output_dim_bounds.push_back(current_bound);
3555   }
3556 
3557   return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds,
3558                               output_dim_is_dynamic);
3559 }
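
// Illustrative sketch (hypothetical shapes and dimension numbers): gathering
// whole rows of a 2-D operand, one row per start index.
//
//   // operand f32[3,4], start_indices s32[2], index_vector_dim = 1
//   // offset_dims = {1}, collapsed_slice_dims = {0},
//   // start_index_map = {0}, slice_sizes = {1, 4}
//   //   -> result f32[2,4]: one batch dimension taken from the indices plus
//   //      the un-collapsed slice dimension of size 4.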
3560 
3561 namespace {
3562 
3563 Status ValidateScatterDimensionNumbers(
3564     const Shape& operand_shape, absl::Span<const int64_t> scatter_indices_shape,
3565     const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
3566   // Validate update_window_dims in ScatterDimensionNumbers.
3567   if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {
3568     return InvalidArgument(
3569         "update_window_dims in scatter op must be sorted; got: %s.",
3570         StrJoin(dim_numbers.update_window_dims(), ", "));
3571   }
3572   if (absl::c_adjacent_find(dim_numbers.update_window_dims()) !=
3573       dim_numbers.update_window_dims().end()) {
3574     return InvalidArgument(
3575         "update_window_dims in scatter op must not repeat; got: %s.",
3576         StrJoin(dim_numbers.update_window_dims(), ", "));
3577   }
3578   const int64_t updates_rank = updates_shape.rank();
3579   for (int64_t window_dim : dim_numbers.update_window_dims()) {
3580     if (window_dim < 0 || window_dim >= updates_rank) {
3581       return InvalidArgument(
3582           "Invalid update_window_dims set in scatter op; valid range is [0, "
3583           "%d). got: %d.",
3584           updates_rank, window_dim);
3585     }
3586   }
3587 
3588   // Validate inserted_window_dims in ScatterDimensionNumbers.
3589   if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) {
3590     return InvalidArgument(
3591         "inserted_window_dims in scatter op must be sorted; got: %s.",
3592         StrJoin(dim_numbers.inserted_window_dims(), ", "));
3593   }
3594   if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) !=
3595       dim_numbers.inserted_window_dims().end()) {
3596     return InvalidArgument(
3597         "inserted_window_dims in scatter op must not repeat; got: %s.",
3598         StrJoin(dim_numbers.inserted_window_dims(), ", "));
3599   }
3600   for (int64_t inserted_dim : dim_numbers.inserted_window_dims()) {
3601     if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) {
3602       return InvalidArgument(
3603           "Invalid inserted_window_dims set in scatter op; valid range is [0, "
3604           "%d), got: %d.",
3605           operand_shape.dimensions_size(), inserted_dim);
3606     }
3607   }
3608 
3609   // Validate window size.
3610   auto window_size = dim_numbers.update_window_dims_size() +
3611                      dim_numbers.inserted_window_dims_size();
3612   if (window_size != operand_shape.rank()) {
3613     return InvalidArgument(
3614         "Scatter op has window of size %d; doesn't match operand of rank %d.",
3615         window_size, operand_shape.rank());
3616   }
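  // Illustrative example (added, not in the original source): every operand
  // dimension must be accounted for exactly once, either by an update window
  // dimension or by an inserted window dimension. E.g. for an operand of rank
  // 4, update_window_dims = {1, 3} and inserted_window_dims = {0, 2} give
  // window_size = 2 + 2 = 4, which matches the operand rank.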

  // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers.
  if (dim_numbers.scatter_dims_to_operand_dims_size() !=
      scatter_indices_shape[dim_numbers.index_vector_dim()]) {
    return InvalidArgument(
        "Scatter op has %d elements in scatter_dims_to_operand_dims and the "
        "bound of dimension index_vector_dim=%d of scatter_indices is %d. "
        "These two numbers must be equal.",
        dim_numbers.scatter_dims_to_operand_dims_size(),
        dim_numbers.index_vector_dim(),
        scatter_indices_shape[dim_numbers.index_vector_dim()]);
  }
  for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) {
    int64_t scatter_dim_to_operand_dim =
        dim_numbers.scatter_dims_to_operand_dims(i);
    if (scatter_dim_to_operand_dim < 0 ||
        scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) {
      return InvalidArgument(
          "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), "
          "got: %d->%d.",
          operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim);
    }
  }
  std::vector<int64_t> sorted_scatter_dims_to_operand_dims(
      dim_numbers.scatter_dims_to_operand_dims().begin(),
      dim_numbers.scatter_dims_to_operand_dims().end());
  absl::c_sort(sorted_scatter_dims_to_operand_dims);
  if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
      sorted_scatter_dims_to_operand_dims.end()) {
    return InvalidArgument(
        "Repeated dimensions not allowed in scatter_dims_to_operand_dims; "
        "got: %s.",
        StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", "));
  }
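  // Illustrative example (added, not in the original source): with
  // scatter_indices of expanded shape [N, 2] and index_vector_dim = 1, each
  // index vector has 2 components, so scatter_dims_to_operand_dims must list
  // exactly 2 distinct operand dimensions, e.g. {0, 1}.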

  return OkStatus();
}

}  // namespace

/*static*/ StatusOr<Shape> ShapeInference::InferScatterShape(
    absl::Span<const Shape* const> arg_shapes,
    const ProgramShape& to_apply_shape,
    const ScatterDimensionNumbers& scatter_dim_numbers) {
  const int64_t operand_count = arg_shapes.size() / 2;
  if (operand_count * 2 + 1 != arg_shapes.size()) {
    return InvalidArgument(
        "Invalid argument count of scatter op: Expected %d, saw %d",
        operand_count * 2 + 1, arg_shapes.size());
  }
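  // Illustrative example (added, not in the original source): the arguments
  // are laid out as {operand_0, ..., operand_{n-1}, scatter_indices,
  // updates_0, ..., updates_{n-1}}, so a variadic scatter over 2 operands
  // takes 2 * 2 + 1 = 5 arguments.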

  const Shape& scatter_indices_shape = *arg_shapes[operand_count];
  TF_RETURN_IF_ERROR(
      ExpectArray(scatter_indices_shape, "scatter indices of scatter op"));
  if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) {
    return InvalidArgument(
        "Scatter indices parameter must be an integral tensor; got %s.",
        ShapeUtil::HumanString(scatter_indices_shape));
  }
  if (scatter_indices_shape.dimensions_size() <
          scatter_dim_numbers.index_vector_dim() ||
      scatter_dim_numbers.index_vector_dim() < 0) {
    return InvalidArgument(
        "Scatter index leaf dimension must be within [0, rank(scatter_indices)"
        " + 1). rank(scatter_indices) is %d and scatter index leaf dimension "
        "is %d.",
        scatter_indices_shape.dimensions_size(),
        scatter_dim_numbers.index_vector_dim());
  }

  std::vector<int64_t> expanded_scatter_indices_shape =
      SpanToVector(scatter_indices_shape.dimensions());
  if (expanded_scatter_indices_shape.size() ==
      scatter_dim_numbers.index_vector_dim()) {
    expanded_scatter_indices_shape.push_back(1);
  }
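  // Illustrative example (added, not in the original source): when
  // index_vector_dim equals the rank of scatter_indices, a trailing
  // degenerate dimension is implied -- indices of shape [10] with
  // index_vector_dim = 1 are treated as shape [10, 1], i.e. ten index
  // vectors of one component each.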

  auto operand_shapes = arg_shapes.first(operand_count);
  auto updates_shapes = arg_shapes.last(operand_count);
  for (int64_t operand_i = 0; operand_i < operand_count; ++operand_i) {
    const Shape& operand_shape = *operand_shapes[operand_i];
    const Shape& updates_shape = *updates_shapes[operand_i];
    TF_RETURN_IF_ERROR(ExpectArray(
        operand_shape, absl::StrCat("operand ", operand_i, " of scatter op")));
    TF_RETURN_IF_ERROR(ExpectArray(
        updates_shape, absl::StrCat("updates ", operand_i, " of scatter op")));

    int64_t inserted_dims_seen = 0;
    std::vector<int64_t> max_update_slice_sizes;
    const auto dimensions_size = operand_shape.dimensions_size();
    max_update_slice_sizes.reserve(dimensions_size);
    for (int i = 0; i < dimensions_size; ++i) {
      if (inserted_dims_seen <
              scatter_dim_numbers.inserted_window_dims_size() &&
          scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) {
        ++inserted_dims_seen;
      } else {
        max_update_slice_sizes.push_back(operand_shape.dimensions(i));
      }
    }
    int64_t expected_updates_rank =
        expanded_scatter_indices_shape.size() - 1 +
        scatter_dim_numbers.update_window_dims_size();
    if (updates_shape.rank() != expected_updates_rank) {
      return InvalidArgument("Updates tensor must be of rank %d; got %d.",
                             expected_updates_rank, updates_shape.rank());
    }
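    // Illustrative example (added, not in the original source):
    // scatter_indices of shape [10, 3] with index_vector_dim = 1 contribute
    // 2 - 1 = 1 batch dimension, so with two update_window_dims the updates
    // tensor must have rank 1 + 2 = 3.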

    TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers(
        operand_shape, expanded_scatter_indices_shape, updates_shape,
        scatter_dim_numbers));

    for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) {
      auto update_window_dim = scatter_dim_numbers.update_window_dims(i);
      if (updates_shape.dimensions(update_window_dim) >
          max_update_slice_sizes[i]) {
        return InvalidArgument(
            "Bounds of the window dimensions of updates must not exceed the "
            "bounds of the corresponding dimensions of operand. For dimension "
            "%d, updates bound is %d, operand bound is %d.",
            update_window_dim, updates_shape.dimensions(update_window_dim),
            max_update_slice_sizes[i]);
      }
    }
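    // Illustrative example (added, not in the original source): if
    // update_window_dims = {1} corresponds to a non-inserted operand
    // dimension of bound 8, the updates bound in dimension 1 may be at most
    // 8; an updates bound of 9 is rejected by the check above.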

    int64_t scatter_dims_seen = 0;
    for (int64_t i = 0; i < updates_shape.rank(); ++i) {
      bool is_update_window_dim =
          absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i);
      if (is_update_window_dim) {
        continue;
      }
      if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) {
        ++scatter_dims_seen;
      }
      if (updates_shape.dimensions(i) !=
          expanded_scatter_indices_shape[scatter_dims_seen]) {
        return InvalidArgument(
            "Bounds of the scatter dimensions of updates must be same as the "
            "bounds of the corresponding dimensions of scatter indices. For "
            "scatter dimension %d, updates bound is %d, scatter_indices "
            "bound is %d.",
            i, updates_shape.dimensions(i),
            expanded_scatter_indices_shape[scatter_dims_seen]);
      }
      ++scatter_dims_seen;
    }
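    // Illustrative example (added, not in the original source): the
    // non-window dimensions of updates must line up with the batch
    // dimensions of scatter_indices. For indices of shape [10, 3] with
    // index_vector_dim = 1, the single batch dimension has bound 10, so the
    // corresponding updates dimension must also have bound 10.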
  }

  // Check if the update computation has a proper shape as a reduction.
  absl::InlinedVector<Shape, 1> init_element_shapes;
  absl::InlinedVector<const Shape*, 1> init_element_shape_ptrs;
  absl::InlinedVector<PrimitiveType, 1> updates_element_types;
  init_element_shapes.reserve(operand_count);
  init_element_shape_ptrs.reserve(operand_count);
  updates_element_types.reserve(operand_count);
  for (int64_t i = 0; i < operand_count; ++i) {
    init_element_shapes.push_back(
        ShapeUtil::MakeShape(operand_shapes[i]->element_type(), {}));
    init_element_shape_ptrs.push_back(&init_element_shapes.back());
    updates_element_types.push_back(updates_shapes[i]->element_type());
  }
  TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_element_shape_ptrs,
                                        updates_element_types, operand_count));
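  // Illustrative example (added, not in the original source): for a single
  // F32 operand with F32 updates, to_apply must be a binary reduction of the
  // form (F32, F32) -> F32, e.g. an addition computation; mismatched
  // parameter counts or element types are rejected by VerifyReducerShape.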

  return operand_count == 1 ? *operand_shapes[0]
                            : ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes);
}

}  // namespace xla