• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <numeric>
21 #include <set>
22 #include <string>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 #include "absl/strings/str_join.h"
29 #include "absl/strings/string_view.h"
30 #include "tensorflow/compiler/xla/permutation_util.h"
31 #include "tensorflow/compiler/xla/primitive_util.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/compiler/xla/window_util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/math/math_util.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/protobuf.h"
42 
43 namespace xla {
44 namespace {
45 
46 using absl::StrFormat;
47 using absl::StrJoin;
48 
49 // Returns true if no element is present in slice more than once.
AllUnique(absl::Span<const int64> slice)50 bool AllUnique(absl::Span<const int64> slice) {
51   return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
52 }
53 
ExpectArray(const Shape & shape,absl::string_view op_type)54 Status ExpectArray(const Shape& shape, absl::string_view op_type) {
55   if (!shape.IsArray()) {
56     return InvalidArgument("Expected array argument for %s, but got %s.",
57                            string(op_type), ShapeUtil::HumanString(shape));
58   }
59   return Status::OK();
60 }
61 
VerifyReducerShape(const ProgramShape & reducer_shape,absl::Span<const Shape * const> init_value_shapes,absl::Span<const PrimitiveType> input_element_types,int64 inputs)62 Status VerifyReducerShape(const ProgramShape& reducer_shape,
63                           absl::Span<const Shape* const> init_value_shapes,
64                           absl::Span<const PrimitiveType> input_element_types,
65                           int64 inputs) {
66   if (reducer_shape.parameters_size() != inputs * 2) {
67     return InvalidArgument(
68         "Reduction function must take %d parameters, but "
69         "takes %d parameter(s).",
70         inputs * 2, reducer_shape.parameters_size());
71   }
72 
73   const Shape& accumulator_shape = reducer_shape.result();
74   std::vector<const Shape*> accumulator_subshapes;
75   if (accumulator_shape.IsArray()) {
76     if (inputs != 1) {
77       return InvalidArgument(
78           "Reduction function must produce a tuple with %d elements, but "
79           "produces a scalar",
80           inputs);
81     }
82     accumulator_subshapes.push_back(&accumulator_shape);
83   } else if (accumulator_shape.IsTuple()) {
84     if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) {
85       return InvalidArgument(
86           "Reduction function must produce a tuple with %d elements, but has "
87           "%d elements",
88           inputs, ShapeUtil::TupleElementCount(accumulator_shape));
89     }
90     for (const Shape& element_shape : accumulator_shape.tuple_shapes()) {
91       accumulator_subshapes.push_back(&element_shape);
92     }
93   } else {
94     return InvalidArgument(
95         "Reduction function must produce a scalar or tuple of scalars, but has "
96         "shape: %s",
97         ShapeUtil::HumanString(accumulator_shape));
98   }
99 
100   for (const Shape* element_shape : accumulator_subshapes) {
101     if (element_shape->rank() != 0) {
102       return InvalidArgument(
103           "Reduction function must return a scalar or tuple of scalars but "
104           "returns shape: %s",
105           ShapeUtil::HumanString(accumulator_shape));
106     }
107   }
108 
109   for (int64 i = 0; i < inputs; ++i) {
110     // Check that the accumulator can be passed in as the first argument.
111     // Note: comparing here and below with Compatible since we don't care about
112     // layout in scalars - see b/26668201 for a longer-term vision.
113     if (!ShapeUtil::Compatible(*accumulator_subshapes[i],
114                                reducer_shape.parameters(i))) {
115       return InvalidArgument(
116           "Reduction function's %d-th parameter shape differs from the "
117           "result shape: %s vs %s",
118           i, ShapeUtil::HumanString(reducer_shape.parameters(i)),
119           ShapeUtil::HumanString(*accumulator_subshapes[i]));
120     }
121     // Check that init_value's shapes are suitable for reducer_shape.
122     if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i],
123                                                   *init_value_shapes[i])) {
124       return InvalidArgument(
125           "Reduction function's accumulator shape at index %d differs from "
126           "the init_value shape: %s vs %s",
127           i, ShapeUtil::HumanString(*accumulator_subshapes[i]),
128           ShapeUtil::HumanString(*init_value_shapes[i]));
129     }
130     // Check that the inputs can be passed in as the non-accumulator arguments.
131     const Shape input_element_shape =
132         ShapeUtil::MakeShape(input_element_types[i], {});
133     if (!ShapeUtil::CompatibleIgnoringFpPrecision(
134             input_element_shape, reducer_shape.parameters(inputs + i))) {
135       return InvalidArgument(
136           "Reduction function's %d-th parameter shape differs from the "
137           "input type element type: %s vs %s",
138           inputs + i,
139           ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
140           ShapeUtil::HumanString(input_element_shape));
141     }
142     // Check that the accumulator and inputs to the reducer function match.
143     // If the accumulator is scalar, it must have the same type as the inputs
144     // (up to fp precision). If it is a tuple, then the k-th element of the
145     // tuple must have the same type as the K-th input (again, up to fp
146     // precision.)
147     if (!ShapeUtil::CompatibleIgnoringFpPrecision(
148             *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) {
149       return InvalidArgument(
150           "Reduction function's %d-th parameter shape must "
151           "match the result shape, but got %s vs %s.",
152           inputs + i,
153           ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
154           ShapeUtil::HumanString(*accumulator_subshapes[i]));
155     }
156   }
157 
158   return Status::OK();
159 }
160 
InferWindowOutputShape(const Shape & base_shape,const Window & window,PrimitiveType element_type,bool allow_negative_padding)161 StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
162                                        const Window& window,
163                                        PrimitiveType element_type,
164                                        bool allow_negative_padding) {
165   if (window.dimensions_size() != base_shape.rank()) {
166     return InvalidArgument(
167         "Window has dimension %d but base shape has dimension %d.",
168         window.dimensions_size(), base_shape.rank());
169   }
170 
171   std::vector<int64> output_dimensions(window.dimensions_size());
172   std::vector<bool> output_is_dynamic(window.dimensions_size());
173   for (int64 i = 0; i < window.dimensions_size(); ++i) {
174     const auto& dim = window.dimensions(i);
175     if (dim.size() <= 0) {
176       return InvalidArgument("Window %s has a non-positive dimension.",
177                              window.DebugString());
178     }
179     if (dim.stride() <= 0) {
180       return InvalidArgument("Window %s has a non-positive stride.",
181                              window.DebugString());
182     }
183     if (!allow_negative_padding && dim.padding_low() < 0) {
184       return InvalidArgument("Window %s has a negative low padding.",
185                              window.DebugString());
186     }
187     if (!allow_negative_padding && dim.padding_high() < 0) {
188       return InvalidArgument("Window %s has a negative high padding.",
189                              window.DebugString());
190     }
191     if (dim.base_dilation() < 1) {
192       return InvalidArgument(
193           "Window %s has a non-positive base area dilation factor.",
194           window.DebugString());
195     }
196     if (dim.window_dilation() < 1) {
197       return InvalidArgument(
198           "Window %s has a non-positive window dilation factor.",
199           window.DebugString());
200     }
201 
202     const int64 dilated_base = window_util::DilatedBound(
203         ShapeUtil::GetDimension(base_shape, i), dim.base_dilation());
204     const int64 padded_dilated_base =
205         dim.padding_low() + dilated_base + dim.padding_high();
206     const int64 dilated_window =
207         window_util::DilatedBound(dim.size(), dim.window_dilation());
208 
209     output_dimensions[i] = window_util::StridedBound(
210         padded_dilated_base, dilated_window, dim.stride());
211     output_is_dynamic[i] = base_shape.is_dynamic_dimension(i);
212   }
213 
214   return ShapeUtil::MakeValidatedShape(element_type, output_dimensions,
215                                        output_is_dynamic);
216 }
217 
MaybeUpcast(PrimitiveType from_type,absl::optional<PrimitiveType> preferred_element_type)218 StatusOr<PrimitiveType> MaybeUpcast(
219     PrimitiveType from_type,
220     absl::optional<PrimitiveType> preferred_element_type) {
221   if (!preferred_element_type.has_value() ||
222       *preferred_element_type == from_type) {
223     return from_type;
224   }
225   if (primitive_util::IsIntegralType(from_type) !=
226       primitive_util::IsIntegralType(*preferred_element_type)) {
227     return InvalidArgument(
228         "`preferred_element_type` and the original type must both be integral "
229         "or both be floating point.");
230   }
231   if (!primitive_util::IsSignedIntegralType(from_type) !=
232       !primitive_util::IsSignedIntegralType(*preferred_element_type)) {
233     return InvalidArgument(
234         "`preferred_element_type` must have the same signedness as the "
235         "original type.");
236   }
237   if (!primitive_util::IsFloatingPointType(from_type) &&
238       primitive_util::BitWidth(*preferred_element_type) <
239           primitive_util::BitWidth(from_type)) {
240     return InvalidArgument(
241         "`preferred_element_type` must not be narrower than the original "
242         "type.");
243   }
244   return *preferred_element_type;
245 }
246 
247 }  // namespace
248 
InferUnaryOpShape(HloOpcode opcode,const HloInstruction * operand)249 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
250     HloOpcode opcode, const HloInstruction* operand) {
251   return InferUnaryOpShape(opcode, operand->shape());
252 }
253 
InferUnaryOpShape(HloOpcode opcode,const Shape & shape)254 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
255     HloOpcode opcode, const Shape& shape) {
256   // There is no copy operation at the proto level, so handle copy explicitly.
257   // A domain shape is the same as the input one.
258   if (opcode == HloOpcode::kCopy || opcode == HloOpcode::kDomain) {
259     return shape;
260   }
261 
262   TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation"));
263 
264   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
265   switch (opcode) {
266     case HloOpcode::kFloor:
267     case HloOpcode::kCeil:
268     case HloOpcode::kRoundNearestAfz:
269       if (!ShapeUtil::ElementIsFloating(shape)) {
270         return InvalidArgument(
271             "Expected element type in shape to be floating for %s operation; "
272             "got %s.",
273             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
274       }
275       return shape;
276     case HloOpcode::kCos:
277     case HloOpcode::kSin:
278     case HloOpcode::kExp:
279     case HloOpcode::kExpm1:
280     case HloOpcode::kLog:
281     case HloOpcode::kLog1p:
282     case HloOpcode::kLogistic:
283     case HloOpcode::kRsqrt:
284     case HloOpcode::kSqrt:
285     case HloOpcode::kCbrt:
286     case HloOpcode::kTanh:
287       if (!ShapeUtil::ElementIsFloating(shape) &&
288           !ShapeUtil::ElementIsComplex(shape)) {
289         return InvalidArgument(
290             "Expected element type in shape to be floating or complex for %s "
291             "operation; got %s.",
292             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
293       }
294       return shape;
295     case HloOpcode::kReal:
296     case HloOpcode::kImag:
297       if (ShapeUtil::ElementIsComplex(shape)) {
298         return ShapeUtil::ComplexComponentShape(shape);
299       } else if (ShapeUtil::ElementIsFloating(shape)) {
300         return shape;
301       } else {
302         return InvalidArgument(
303             "Expected element type in shape to be floating or complex for "
304             "%s operation; got %s.",
305             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
306       }
307     case HloOpcode::kAbs:
308       if (ShapeUtil::ElementIsComplex(shape)) {
309         return ShapeUtil::ChangeElementType(
310             shape, primitive_util::ComplexComponentType(shape.element_type()));
311       } else if (ShapeUtil::ElementIsSigned(shape)) {
312         return shape;
313       } else {
314         return InvalidArgument(
315             "Expected element type in shape to be floating or complex for "
316             "%s operation; got %s.",
317             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
318       }
319     case HloOpcode::kClz:
320       if (!ShapeUtil::ElementIsIntegral(shape)) {
321         return InvalidArgument(
322             "Expected an integral element type in argument to Clz "
323             "operation; got %s.",
324             PrimitiveType_Name(shape.element_type()));
325       }
326       return shape;
327     case HloOpcode::kNegate:
328       if (!ShapeUtil::ElementIsIntegral(shape) &&
329           !ShapeUtil::ElementIsFloating(shape) &&
330           !ShapeUtil::ElementIsComplex(shape)) {
331         return InvalidArgument(
332             "Expected element type in shape to be integral, floating or "
333             "complex for %s operation; got %s.",
334             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
335       }
336       return shape;
337     case HloOpcode::kPopulationCount:
338       if (!ShapeUtil::ElementIsIntegral(shape)) {
339         return InvalidArgument(
340             "Expected an integral element type in argument to PopulationCount "
341             "operation; got %s.",
342             PrimitiveType_Name(shape.element_type()));
343       }
344       return shape;
345     case HloOpcode::kSign:
346       if (!ShapeUtil::ElementIsSigned(shape) &&
347           !ShapeUtil::ElementIsComplex(shape)) {
348         return InvalidArgument(
349             "Expected element type in shape to be signed or complex for "
350             "%s operation; got %s.",
351             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
352       }
353       return shape;
354 
355     case HloOpcode::kNot:
356       if (shape.element_type() != PRED &&
357           !primitive_util::IsIntegralType(shape.element_type())) {
358         return InvalidArgument(
359             "Expected pred or an integral element type in argument to Not "
360             "operation; got %s.",
361             PrimitiveType_Name(shape.element_type()));
362       }
363       return shape;
364 
365     case HloOpcode::kIsFinite:
366       if (!ShapeUtil::ElementIsFloating(shape)) {
367         return InvalidArgument(
368             "Expected element type in shape to be floating "
369             "point for IsFinite "
370             "operation; got %s.",
371             PrimitiveType_Name(shape.element_type()));
372       }
373       return ShapeUtil::ChangeElementType(shape, PRED);
374 
375     default:
376       return InvalidArgument(
377           "Unknown operation for unary shape inference: \"%s\".",
378           HloOpcodeString(opcode));
379   }
380 }
381 
InferConcatOpShape(absl::Span<const Shape * const> arg_shapes,const int64 dimension)382 /* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
383     absl::Span<const Shape* const> arg_shapes, const int64 dimension) {
384   if (arg_shapes.empty()) {
385     return InvalidArgument("Concatenate expects at least one argument.");
386   }
387   if (dimension < 0 || dimension >= arg_shapes[0]->rank()) {
388     return InvalidArgument("Concatenate dimension out of bounds: %d.",
389                            dimension);
390   }
391   const Shape* arg_shape = nullptr;
392   PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
393   for (const Shape* shape : arg_shapes) {
394     TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation"));
395     if (!arg_shape) {
396       arg_shape = shape;
397       element_type = arg_shape->element_type();
398       continue;
399     }
400     if (arg_shape->rank() != shape->rank()) {
401       return InvalidArgument(
402           "Cannot concatenate arrays with different ranks: %d (%s) vs %d "
403           "(%s).",
404           arg_shape->rank(), ShapeUtil::HumanString(*arg_shape), shape->rank(),
405           ShapeUtil::HumanString(*shape));
406     }
407     if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
408       return InvalidArgument(
409           "Cannot concatenate arrays with different element types: %s vs %s.",
410           PrimitiveType_Name(arg_shape->element_type()),
411           PrimitiveType_Name(shape->element_type()));
412     }
413     for (int64 dimension_number = 0; dimension_number < arg_shape->rank();
414          ++dimension_number) {
415       if (arg_shape->dimensions(dimension_number) !=
416           shape->dimensions(dimension_number)) {
417         if (dimension_number == dimension) {
418           continue;  // It's okay to differ in the dimension we're
419                      // concatenating.
420         }
421         return InvalidArgument(
422             "Cannot concatenate arrays that differ in dimensions other than "
423             "the one being concatenated (the other array dimensions must be "
424             "the same): %s vs %s in dimension %d.",
425             ShapeUtil::HumanString(*arg_shape), ShapeUtil::HumanString(*shape),
426             dimension);
427       }
428     }
429     element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
430   }
431 
432   std::vector<int64> new_dimensions(arg_shape->dimensions().begin(),
433                                     arg_shape->dimensions().end());
434   for (size_t i = 1; i < arg_shapes.size(); ++i) {
435     new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
436   }
437 
438   Shape result = ShapeUtil::MakeShape(element_type, new_dimensions);
439 
440   // Set dynamic dimensions if any input has dynamic dimension.
441   for (const Shape* shape : arg_shapes) {
442     for (int64 i = 0; i < shape->dimensions_size(); ++i) {
443       if (shape->is_dynamic_dimension(i)) {
444         result.set_dynamic_dimension(i, true);
445       }
446     }
447   }
448   return result;
449 }
450 
InferConvertShape(const Shape & operand_shape,PrimitiveType new_element_type)451 /* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
452     const Shape& operand_shape, PrimitiveType new_element_type) {
453   auto old_element_type = operand_shape.element_type();
454   if (primitive_util::IsComplexType(old_element_type) &&
455       !primitive_util::IsComplexType(new_element_type)) {
456     return Unimplemented(
457         "Conversion from complex to real type %s => %s is not implemented.",
458         ShapeUtil::HumanString(operand_shape),
459         PrimitiveType_Name(new_element_type));
460   }
461   if (!operand_shape.IsArray() ||
462       !primitive_util::IsArrayType(new_element_type)) {
463     // Note: we may want to support tuple conversions via this operation in the
464     // future, by recursing into the tuple elements to check all sub-conversions
465     // are valid. For now we just reject them, though.
466     return InvalidArgument(
467         "Convert does not allow non-arrays, so cannot convert from %s to %s.",
468         ShapeUtil::HumanString(operand_shape),
469         PrimitiveType_Name(new_element_type));
470   }
471 
472   return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
473 }
474 
InferBitcastConvertShape(const Shape & operand_shape,PrimitiveType new_element_type)475 /* static */ StatusOr<Shape> ShapeInference::InferBitcastConvertShape(
476     const Shape& operand_shape, PrimitiveType new_element_type) {
477   auto old_element_type = operand_shape.element_type();
478   if (primitive_util::IsComplexType(old_element_type) !=
479       primitive_util::IsComplexType(new_element_type)) {
480     return InvalidArgument("Conversion from complex to real type %s => %s.",
481                            ShapeUtil::HumanString(operand_shape),
482                            PrimitiveType_Name(new_element_type));
483   }
484   if (!operand_shape.IsArray() ||
485       !primitive_util::IsArrayType(new_element_type)) {
486     // Note: we may want to support tuple conversions via this operation in the
487     // future, by recursing into the tuple elements to check all sub-conversions
488     // are valid. For now we just reject them, though.
489     return InvalidArgument(
490         "Cannot convert from or to tuple type; requested conversion: %s => %s.",
491         ShapeUtil::HumanString(operand_shape),
492         PrimitiveType_Name(new_element_type));
493   }
494   if (primitive_util::BitWidth(old_element_type) !=
495       primitive_util::BitWidth(new_element_type)) {
496     return InvalidArgument(
497         "Cannot bitcast types with different bit-widths: %s => %s.",
498         PrimitiveType_Name(old_element_type),
499         PrimitiveType_Name(new_element_type));
500   }
501 
502   return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
503 }
504 
InferReducePrecisionShape(const Shape & operand_shape,const int exponent_bits,const int mantissa_bits)505 /* static */ StatusOr<Shape> ShapeInference::InferReducePrecisionShape(
506     const Shape& operand_shape, const int exponent_bits,
507     const int mantissa_bits) {
508   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
509     return InvalidArgument(
510         "Expected element type in shape to be floating point for "
511         "ReducePrecision operation; got %s.",
512         PrimitiveType_Name(operand_shape.element_type()));
513   }
514   if (exponent_bits < 1) {
515     // One exponent bit is necessary to distinguish 0 from infinity.  Having
516     // no exponent bits doesn't produce a sensible number, so we require at
517     // least one.
518     return InvalidArgument("Expected exponent_bits >= 1; got %d.",
519                            exponent_bits);
520   }
521   if (mantissa_bits < 0) {
522     // A number with no mantissa bits is still meaningful, however.
523     return InvalidArgument("Expected non-negative mantissa_bits; got %d.",
524                            mantissa_bits);
525   }
526   return operand_shape;
527 }
528 
InferPadShape(const Shape & operand_shape,const Shape & padding_value_shape,const PaddingConfig & padding_config)529 /* static */ StatusOr<Shape> ShapeInference::InferPadShape(
530     const Shape& operand_shape, const Shape& padding_value_shape,
531     const PaddingConfig& padding_config) {
532   if (!operand_shape.IsArray()) {
533     return InvalidArgument(
534         "Pad operation does not support tuple-shape operands.");
535   }
536   if (!ShapeUtil::IsScalar(padding_value_shape)) {
537     return InvalidArgument(
538         "Pad operation does not support non-scalar padding values.");
539   }
540   if (operand_shape.rank() != padding_config.dimensions_size()) {
541     return InvalidArgument(
542         "The rank of the operand and the padding configuration do not match: "
543         "%s vs %s.",
544         ShapeUtil::HumanString(operand_shape),
545         padding_config.ShortDebugString());
546   }
547   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
548                                                      padding_value_shape)) {
549     return InvalidArgument(
550         "The element types of the operands to Pad do not match.");
551   }
552   if (absl::c_any_of(padding_config.dimensions(),
553                      [](const PaddingConfig::PaddingConfigDimension& p) {
554                        return p.interior_padding() < 0;
555                      })) {
556     return InvalidArgument("Interior padding cannot be negative: %s",
557                            padding_config.ShortDebugString());
558   }
559 
560   if (!padding_value_shape.is_static()) {
561     return InvalidArgument("Dynamic padding value is not supported");
562   }
563 
564   std::vector<int64> dimensions(operand_shape.rank());
565   std::vector<bool> is_dynamic(operand_shape.rank());
566   for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
567     const auto& p = padding_config.dimensions(i);
568     dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() +
569                     p.edge_padding_high() +
570                     std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
571                         p.interior_padding();
572     if (dimensions[i] < 0) {
573       return InvalidArgument("Padding result in negative size for dimension %d",
574                              i);
575     }
576     is_dynamic[i] = operand_shape.is_dynamic_dimension(i);
577   }
578 
579   return ShapeUtil::MakeShape(
580       ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
581       dimensions, is_dynamic);
582 }
583 
584 // Current DotDimensionNumbers Requirements:
585 //
586 // Contracting Dimensions:
587 // *) Same number of contracting dimensions on both lhs and rhs.
588 // *) Contracting dimension size must be the same on both lhs and rhs.
589 //
590 // Batch Dimensions:
591 // *) Same number of batch dimensions on both lhs and rhs.
592 // *) Same batch dimension sizes on both lhs and rhs.
593 //
594 
595 namespace {
596 
ValidateDotDimensionNumbers(const Shape & lhs,const Shape & rhs,const DotDimensionNumbers & dimension_numbers)597 Status ValidateDotDimensionNumbers(
598     const Shape& lhs, const Shape& rhs,
599     const DotDimensionNumbers& dimension_numbers) {
600   // Check that dimension numbers are in range.
601   auto dims_in_range = [](const int64 rank,
602                           absl::Span<const int64> contracting_dims,
603                           absl::Span<const int64> batch_dims) -> bool {
604     auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
605     return absl::c_all_of(contracting_dims, in_range) &&
606            absl::c_all_of(batch_dims, in_range);
607   };
608 
609   absl::Span<const int64> lhs_contracting_dimensions =
610       AsInt64Slice(dimension_numbers.lhs_contracting_dimensions());
611   absl::Span<const int64> rhs_contracting_dimensions =
612       AsInt64Slice(dimension_numbers.rhs_contracting_dimensions());
613   absl::Span<const int64> lhs_batch_dimensions =
614       AsInt64Slice(dimension_numbers.lhs_batch_dimensions());
615   absl::Span<const int64> rhs_batch_dimensions =
616       AsInt64Slice(dimension_numbers.rhs_batch_dimensions());
617 
618   if (!dims_in_range(lhs.rank(), lhs_contracting_dimensions,
619                      lhs_batch_dimensions) ||
620       !dims_in_range(rhs.rank(), rhs_contracting_dimensions,
621                      rhs_batch_dimensions)) {
622     return InvalidArgument("A dimension number is out of range in Dot: %s.",
623                            dimension_numbers.DebugString());
624   }
625 
626   // Check that dimension numbers are unique.
627   auto dims_unique = [](absl::Span<const int64> contracting_dims,
628                         absl::Span<const int64> batch_dims) -> bool {
629     absl::flat_hash_set<int64> dim_set;
630     auto is_unique = [&dim_set](int64 i) -> bool {
631       return dim_set.insert(i).second;
632     };
633     return absl::c_all_of(contracting_dims, is_unique) &&
634            absl::c_all_of(batch_dims, is_unique);
635   };
636 
637   if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
638       !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) {
639     return InvalidArgument("A dimension number is not unique in Dot: %s.",
640                            dimension_numbers.DebugString());
641   }
642 
643   return Status::OK();
644 }
645 
646 }  // namespace
647 
InferDotOpShape(const Shape & lhs,const Shape & rhs,const DotDimensionNumbers & dimension_numbers,absl::optional<PrimitiveType> preferred_element_type)648 /* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
649     const Shape& lhs, const Shape& rhs,
650     const DotDimensionNumbers& dimension_numbers,
651     absl::optional<PrimitiveType> preferred_element_type) {
652   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot"));
653   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
654 
655   auto fail = [lhs, rhs](const string& addendum) -> Status {
656     string message =
657         StrFormat("Cannot infer shape for dot operation: %s <dot> %s.",
658                   ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
659     if (!addendum.empty()) {
660       message += " " + addendum;
661     }
662     return InvalidArgument("%s", message);
663   };
664 
665   // Validate basic properties of dot dimension numbers.
666   TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers));
667 
668   // Check that number of contracting dimensions match.
669   if (dimension_numbers.lhs_contracting_dimensions_size() !=
670       dimension_numbers.rhs_contracting_dimensions_size()) {
671     return fail(
672         "Must specify the same number of contracting dimensions for lhs and "
673         "rhs.");
674   }
675   // Check that contracting dimension sizes match.
676   for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
677        ++i) {
678     const int64 lhs_contracting_dimension =
679         dimension_numbers.lhs_contracting_dimensions(i);
680     const int64 rhs_contracting_dimension =
681         dimension_numbers.rhs_contracting_dimensions(i);
682     if (lhs.dimensions(lhs_contracting_dimension) !=
683         rhs.dimensions(rhs_contracting_dimension)) {
684       return fail("Contracting dimension sizes do not match.");
685     }
686   }
687 
688   // Check that number of batch dimensions match.
689   if (dimension_numbers.lhs_batch_dimensions_size() !=
690       dimension_numbers.rhs_batch_dimensions_size()) {
691     return fail("Must the same number of batch dimensions for lhs and rhs.");
692   }
693 
694   // Check that batch dimension numbers and sizes match.
695   for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
696     if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
697         rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
698       return fail("Batch dimension sizes must match for lhs/rhs.");
699     }
700   }
701 
702   // The ranks of lhs and rhs are decremented by 1 respectively due to the
703   // contraction, and added for the rank of the result. When an input tensor is
704   // a scalar, its contribution to the rank of the result is 0.
705   // Generate the result dimensions in order, rhs dimensions followed by lhs
706   // dimensions except the contracted and batch dimensions.
707   std::vector<int64> dimensions;
708   std::vector<bool> is_dynamic;
709   for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
710     dimensions.push_back(lhs.dimensions(lhs_dim));
711     is_dynamic.push_back(lhs.is_dynamic_dimension(lhs_dim));
712   }
713   for (int64 i = 0; i < lhs.rank(); i++) {
714     if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
715                                i) &&
716         !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
717       dimensions.push_back(lhs.dimensions(i));
718       is_dynamic.push_back(lhs.is_dynamic_dimension(i));
719     }
720   }
721   for (int64 i = 0; i < rhs.rank(); i++) {
722     if (!absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
723                                i) &&
724         !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
725       dimensions.push_back(rhs.dimensions(i));
726       is_dynamic.push_back(rhs.is_dynamic_dimension(i));
727     }
728   }
729   TF_ASSIGN_OR_RETURN(
730       PrimitiveType type,
731       MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
732                   preferred_element_type));
733   Shape result = ShapeUtil::MakeShape(type, dimensions, is_dynamic);
734 
735   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
736   VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
737   return result;
738 }
739 
740 /* static */ StatusOr<Shape>
InferDegenerateDimensionBroadcastShape(HloOpcode operation,const Shape & lhs,const Shape & rhs)741 ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
742                                                        const Shape& lhs,
743                                                        const Shape& rhs) {
744   TF_RET_CHECK(lhs.rank() == rhs.rank());
745 
746   // The shapes have to be compatible. That is, if some dimension d has a
747   // different size in the two shapes, one of them has to be 1 (a "degenerate"
748   // dimension). In that case, the output shape has the non-1 dimension size
749   // from the lhs/rhs pair in every index.
750   std::vector<int64> output_dimensions(lhs.rank());
751   std::vector<bool> output_dimensions_is_dynamic(lhs.rank());
752   for (int64 i = 0; i < lhs.rank(); ++i) {
753     if (lhs.dimensions(i) == rhs.dimensions(i)) {
754       output_dimensions[i] = lhs.dimensions(i);
755     } else if (lhs.dimensions(i) == 1) {
756       output_dimensions[i] = rhs.dimensions(i);
757     } else if (rhs.dimensions(i) == 1) {
758       output_dimensions[i] = lhs.dimensions(i);
759     } else {
760       return InvalidArgument(
761           "Binary op %s with incompatible shapes: %s and %s.",
762           HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
763           ShapeUtil::HumanString(rhs));
764     }
765   }
766 
767   // Merge dynamic dimensions from two shapes.
768   for (int64 i = 0; i < rhs.rank(); ++i) {
769     if (rhs.is_dynamic_dimension(i) || lhs.is_dynamic_dimension(i)) {
770       output_dimensions_is_dynamic[i] = true;
771     }
772   }
773 
774   return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
775                               output_dimensions, output_dimensions_is_dynamic);
776 }
777 
InferInDimBroadcastShape(const Shape & smaller_shape,const Shape & larger_shape,absl::Span<const int64> broadcast_dimensions)778 /* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
779     const Shape& smaller_shape, const Shape& larger_shape,
780     absl::Span<const int64> broadcast_dimensions) {
781   if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
782     // Reject "magic" inference for binops on different shapes, requiring
783     // the user to provide an explicit broadcast dimension in this case.
784     // See b/25177275 for more details.
785     return InvalidArgument("Automatic shape inference not supported: %s and %s",
786                            ShapeUtil::HumanString(smaller_shape),
787                            ShapeUtil::HumanString(larger_shape));
788   } else if (broadcast_dimensions.size() != smaller_shape.rank()) {
789     return InvalidArgument(
790         "Size of broadcast_dimensions has to match lower-rank operand's "
791         "rank; "
792         " lower-rank operand's rank is %d, size of broadcast_dimensions is "
793         "%u.",
794         smaller_shape.rank(), broadcast_dimensions.size());
795   }
796 
797   // broadcast_dimensions is a sequence of dimensions; its length is equal to
798   // the rank of the lower-rank operand. The lower-rank operand's dimensions
799   // have to be compatible with the higher-rank operand's dimensions at indices
800   // specified by broadcast_dimensions. Here compatible means the dimension
801   // sizes are equal or in one of the shapes the dimension size is
802   // one. Examples:
803   //
804   // smaller_shape   larger_shape   broadcast_dimensions   output_shape
805   //   []              [2, 3]          {}                    [2, 3]
806   //   [3]             [4, 3]          {1}                   [4, 3]
807   //   [2, 3]          [2, 3, 4]       {0, 1}                [2, 3, 4]
808   //   [2, 1]          [2, 3, 4]       {0, 2}                [2, 3, 1]
809   //   [2, 3]          [2, 1, 4]       {0, 1}                [2, 3, 4]
810   //
811   // The column output_shape may not be the final shape of the XLA
812   // operation. After the "InDim" broadcasting implemented in this function
813   // expands the rank, degenerate-dimension broadcasting (implemented in
814   // InferDegenerateDimensionBroadcastShape) broadcasts dimensions of size one
815   // up to match the dimension size of the other operand. For example, consider
816   // the row in the table above with a smaller_shape of [2, 1]. The shape
817   // returned by this function is [2, 3, 1] (output_shape) however, the result
818   // shape of the XLA operation is [2, 3, 4] after degenerate-dimension
819   // broadcasting.
820   //
821   // Invalid broadcasts:
822   //
823   // smaller_shape=[3], larger_shape=[4, 3], broadcast_dimensions={0}
824   // Reason: Dimension zero** of larger_shape (size 4) is not compatible with
825   //   dimension zero of smaller_shape(size 3). **Zero here comes from the value
826   //   in broadcast_dimensions.
827   //
828   // smaller_shape=[2, 1], larger_shape=[2, 3, 4], broadcast_dimensions={1, 2}
829   // Reason: Dimension one of larger_shape (size 3) is not compatible with
830   //   dimension zero of smaller_shape(size 2)
831 
832   // The output shape is initially the larger_shape. Sizes of dimensions
833   // specified in broadcast_dimensions are then changed to match the
834   // corresponding dimension size in smaller_shape.
835   Shape output_shape(larger_shape);
836   output_shape.set_element_type(
837       ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape));
838 
839   for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
840     int64 dimension_to_match = broadcast_dimensions.at(i);
841     if (dimension_to_match < 0) {
842       return InvalidArgument(
843           "Broadcast dimension number (%d) cannot be negative.",
844           dimension_to_match);
845     }
846     if (dimension_to_match >= larger_shape.dimensions_size()) {
847       return InvalidArgument(
848           "Broadcast dimension number (%d) too large; higher-rank "
849           "operand has rank %d.",
850           dimension_to_match, larger_shape.dimensions_size());
851     }
852     int64 small_dimension_size = smaller_shape.dimensions(i);
853     int64 large_dimension_size = larger_shape.dimensions(dimension_to_match);
854     bool small_is_dynamic = smaller_shape.is_dynamic_dimension(i);
855     bool large_is_dynamic =
856         larger_shape.is_dynamic_dimension(dimension_to_match);
857     // Dimension sizes must be compatible: match or be degenerate (degenerate
858     // case is handled by degenerate dimension broadcasting which occurs after
859     // InDim broadcasting).
860     if (small_dimension_size != large_dimension_size &&
861         small_dimension_size != 1 && large_dimension_size != 1) {
862       return InvalidArgument(
863           "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i,
864           small_dimension_size, large_dimension_size,
865           ShapeUtil::HumanString(smaller_shape),
866           ShapeUtil::HumanString(larger_shape));
867     }
868     if (small_is_dynamic != large_is_dynamic) {
869       if (small_dimension_size == large_dimension_size ||
870           (small_dimension_size == 1 && !small_is_dynamic) ||
871           (large_dimension_size == 1 && !large_is_dynamic)) {
872         // Do nothing. It's OK when the size-1 dimension is not static.
873       } else {
874         return InvalidArgument(
875             "Broadcast dimension %d dynamism mismatch: %s and %s.", i,
876             ShapeUtil::HumanString(smaller_shape),
877             ShapeUtil::HumanString(larger_shape));
878       }
879     }
880     // Make sure the broadcast dimensions are listed in a strictly increasing
881     // order.
882     if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) {
883       return InvalidArgument(
884           "Broadcast dimensions order is wrong: %d comes after %d.",
885           dimension_to_match, broadcast_dimensions.at(i - 1));
886     }
887 
888     output_shape.set_dimensions(dimension_to_match, small_dimension_size);
889     output_shape.set_dynamic_dimension(dimension_to_match, small_is_dynamic);
890   }
891 
892   return output_shape;
893 }
894 
InferElementwiseBinaryOpShape(HloOpcode operation,const Shape & lhs,const Shape & rhs,absl::Span<const int64> broadcast_dimensions)895 /* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
896     HloOpcode operation, const Shape& lhs, const Shape& rhs,
897     absl::Span<const int64> broadcast_dimensions) {
898   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
899   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
900 
901   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
902     return InvalidArgument(
903         "Binary op %s with different element types: %s and %s.",
904         HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
905         ShapeUtil::HumanString(rhs));
906   }
907 
908   if (lhs.rank() == rhs.rank()) {
909     std::vector<int64> identity_dims(lhs.rank());
910     std::iota(identity_dims.begin(), identity_dims.end(), 0);
911     if (!broadcast_dimensions.empty() &&
912         broadcast_dimensions != identity_dims) {
913       return InvalidArgument(
914           "Broadcast dimensions field must either be not set or be the "
915           "identity on binary operations with operands of the same rank.");
916     }
917   }
918 
919   if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
920     // If the shapes are the same other than layout, the output shape is the
921     // same (elementwise op).
922     Shape result = ShapeUtil::ChangeElementType(
923         lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
924 
925     for (int64 i = 0; i < rhs.rank(); ++i) {
926       if (rhs.is_dynamic_dimension(i)) {
927         result.set_dynamic_dimension(i, true);
928       }
929     }
930 
931     return result;
932 
933   } else if (lhs.rank() == rhs.rank()) {
934     return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
935   } else {
936     // Ranks do not match, so perform InDim broadcasting using
937     // broadcast_dimensions. Scalar broadcasting is a special case of this.
938     const Shape& larger_shape = lhs.rank() > rhs.rank() ? lhs : rhs;
939     const Shape& smaller_shape = lhs.rank() > rhs.rank() ? rhs : lhs;
940 
941     // After InDim broadcasting, perform degenerate dimensions broadcasting.
942     TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape,
943                         InferInDimBroadcastShape(smaller_shape, larger_shape,
944                                                  broadcast_dimensions));
945 
946     return InferDegenerateDimensionBroadcastShape(
947         operation, indim_broadcast_shape, larger_shape);
948   }
949 }
950 
InferBinaryOpShape(HloOpcode opcode,const HloInstruction * lhs,const HloInstruction * rhs)951 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
952     HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) {
953   return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(),
954                             /*broadcast_dimensions=*/{});
955 }
956 
InferBinaryOpShape(HloOpcode opcode,const Shape & lhs,const Shape & rhs,absl::Span<const int64> broadcast_dimensions)957 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
958     HloOpcode opcode, const Shape& lhs, const Shape& rhs,
959     absl::Span<const int64> broadcast_dimensions) {
960   VLOG(2) << StrFormat(
961       "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
962       HloOpcodeString(opcode), ShapeUtil::HumanStringWithLayout(lhs),
963       ShapeUtil::HumanStringWithLayout(rhs),
964       StrJoin(broadcast_dimensions, ", "));
965 
966   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
967   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
968 
969   TF_RETURN_IF_ERROR(ExpectArray(
970       lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode))));
971   TF_RETURN_IF_ERROR(ExpectArray(
972       rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode))));
973   switch (opcode) {
974     case HloOpcode::kAdd:
975     case HloOpcode::kMaximum:
976     case HloOpcode::kMinimum:
977     case HloOpcode::kMultiply:
978       return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
979                                            broadcast_dimensions);
980 
981     case HloOpcode::kSubtract:
982     case HloOpcode::kAtan2:
983     case HloOpcode::kPower:
984     case HloOpcode::kDivide:
985     case HloOpcode::kRemainder:
986     case HloOpcode::kShiftLeft:
987     case HloOpcode::kShiftRightArithmetic:
988     case HloOpcode::kShiftRightLogical:
989       if (lhs.element_type() == PRED || rhs.element_type() == PRED) {
990         return InvalidArgument(
991             "Expected element type in shape to be arithmetic type for "
992             "operation %s; got PRED.",
993             HloOpcodeString(opcode));
994       }
995       return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
996                                            broadcast_dimensions);
997 
998     case HloOpcode::kComplex: {
999       if (!ShapeUtil::ElementIsFloating(lhs)) {
1000         return InvalidArgument(
1001             "Expected element type in shape to be floating for complex compose "
1002             "operation; got %s.",
1003             PrimitiveType_Name(lhs.element_type()));
1004       }
1005       TF_ASSIGN_OR_RETURN(const Shape& shape,
1006                           InferElementwiseBinaryOpShape(opcode, lhs, rhs,
1007                                                         broadcast_dimensions));
1008       if (lhs.element_type() == F32 && rhs.element_type() == F32) {
1009         return ShapeUtil::ChangeElementType(shape, C64);
1010       } else if (lhs.element_type() == F64 && rhs.element_type() == F64) {
1011         return ShapeUtil::ChangeElementType(shape, C128);
1012       } else {
1013         return Unimplemented("Complex component type is not implemented.");
1014       }
1015     }
1016     case HloOpcode::kAnd:
1017     case HloOpcode::kOr:
1018     case HloOpcode::kXor:
1019       if (lhs.element_type() != PRED &&
1020           !primitive_util::IsIntegralType(lhs.element_type())) {
1021         return InvalidArgument(
1022             "Expected pred or integral type in argument to and/or operation; "
1023             "got %s.",
1024             PrimitiveType_Name(lhs.element_type()));
1025       }
1026       return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
1027                                            broadcast_dimensions);
1028     case HloOpcode::kCompare: {
1029       TF_ASSIGN_OR_RETURN(const Shape& shape,
1030                           InferElementwiseBinaryOpShape(opcode, lhs, rhs,
1031                                                         broadcast_dimensions));
1032       return ShapeUtil::ChangeElementType(shape, PRED);
1033     }
1034     default:
1035       return Unimplemented(
1036           "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.",
1037           HloOpcodeString(opcode), lhs.ShortDebugString(),
1038           rhs.ShortDebugString());
1039   }
1040 }
1041 
InferTernaryOpShape(HloOpcode opcode,const HloInstruction * lhs,const HloInstruction * rhs,const HloInstruction * ehs)1042 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1043     HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs,
1044     const HloInstruction* ehs) {
1045   return InferTernaryOpShape(opcode, lhs->shape(), rhs->shape(), ehs->shape());
1046 }
1047 
InferTernaryOpShape(HloOpcode opcode,const Shape & lhs,const Shape & rhs,const Shape & ehs)1048 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1049     HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) {
1050   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1051   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1052   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs));
1053   switch (opcode) {
1054     case HloOpcode::kClamp:
1055       return InferClampShape(lhs, rhs, ehs);
1056     case HloOpcode::kSelect:
1057       return InferSelectShape(lhs, rhs, ehs);
1058     case HloOpcode::kTupleSelect:
1059       return InferTupleSelectShape(lhs, rhs, ehs);
1060     default:
1061       return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1062   }
1063 }
1064 
InferVariadicOpShape(HloOpcode opcode,absl::Span<const HloInstruction * const> operands)1065 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1066     HloOpcode opcode, absl::Span<const HloInstruction* const> operands) {
1067   std::vector<const Shape*> operand_shapes;
1068   operand_shapes.reserve(operands.size());
1069   for (const HloInstruction* operand : operands) {
1070     operand_shapes.push_back(&operand->shape());
1071   }
1072   return InferVariadicOpShape(opcode, operand_shapes);
1073 }
1074 
InferVariadicOpShape(HloOpcode opcode,absl::Span<const Shape * const> operand_shapes)1075 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1076     HloOpcode opcode, absl::Span<const Shape* const> operand_shapes) {
1077   for (const Shape* shape : operand_shapes) {
1078     TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
1079   }
1080   switch (opcode) {
1081     case HloOpcode::kTuple: {
1082       Shape result = ShapeUtil::MakeTupleShape({});
1083       result.mutable_tuple_shapes()->reserve(operand_shapes.size());
1084       for (const Shape* shape : operand_shapes) {
1085         ShapeUtil::AppendShapeToTuple(*shape, &result);
1086       }
1087       return result;
1088     }
1089     case HloOpcode::kSort: {
1090       if (operand_shapes.size() == 1) {
1091         return *operand_shapes[0];
1092       } else {
1093         for (int64 operand = 1; operand < operand_shapes.size(); ++operand) {
1094           if (!ShapeUtil::SameDimensions(*operand_shapes[0],
1095                                          *operand_shapes[operand])) {
1096             return InvalidArgument(
1097                 "Sort keys and values dimensions must match. "
1098                 "Keys shape is: %s\n, Values shape (operand index %lld) is: %s",
1099                 ShapeUtil::HumanString(*operand_shapes[0]), operand,
1100                 ShapeUtil::HumanString(*operand_shapes[operand]));
1101           }
1102         }
1103         std::vector<Shape> operand_shape_values;
1104         for (const Shape* operand_shape : operand_shapes) {
1105           operand_shape_values.push_back(*operand_shape);
1106         }
1107         return ShapeUtil::MakeTupleShape(operand_shape_values);
1108       }
1109       return InvalidArgument("Unexpected number of operands for sort");
1110     }
1111     default:
1112       return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1113   }
1114 }
1115 
InferMapShape(absl::Span<const Shape * const> arg_shapes,const ProgramShape & to_apply,absl::Span<const int64> dimensions)1116 /* static */ StatusOr<Shape> ShapeInference::InferMapShape(
1117     absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
1118     absl::Span<const int64> dimensions) {
1119   if (arg_shapes.empty()) {
1120     return InvalidArgument("Map expects at least one argument.");
1121   }
1122 
1123   // All arguments must have the same shape.
1124   const Shape* arg_shape = arg_shapes[0];
1125   for (size_t i = 1; i < arg_shapes.size(); ++i) {
1126     TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map"));
1127 
1128     if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
1129       continue;
1130     }
1131     if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
1132                                                       *arg_shape)) {
1133       if (ShapeUtil::IsScalar(*arg_shapes[i])) {
1134         continue;
1135       }
1136       if (ShapeUtil::IsScalar(*arg_shape)) {
1137         arg_shape = arg_shapes[i];
1138         continue;
1139       }
1140     }
1141 
1142     std::vector<string> pieces;
1143     for (const Shape* shape : arg_shapes) {
1144       pieces.push_back(ShapeUtil::HumanString(*shape));
1145     }
1146     return InvalidArgument(
1147         "Map operation requires all operands to have the same shape; got: "
1148         "%s.",
1149         StrJoin(pieces, ", "));
1150   }
1151 
1152   // Check that dimensions.size == arg_shape.dimensions_size() (we currently
1153   // only support mapping across all dimensions: i.e. scalar map functions).
1154   if (dimensions.size() != arg_shape->dimensions_size()) {
1155     return InvalidArgument(
1156         "Map applied to a subset of dimensions currently not supported: "
1157         "arg_dimension_size: %d, requested_map_dimensions_size: %u.",
1158         arg_shape->dimensions_size(), dimensions.size());
1159   }
1160 
1161   // Check that requested map dimensions numbers are monotonically increasing.
1162   for (int i = 0; i < dimensions.size(); ++i) {
1163     if (dimensions[i] != i) {
1164       return InvalidArgument(
1165           "Map requires monotonically increasing dimension numbers; got: %s.",
1166           StrJoin(dimensions, ", "));
1167     }
1168   }
1169 
1170   // The applied function's arity equals the number of arguments.
1171   if (arg_shapes.size() != to_apply.parameters_size()) {
1172     return InvalidArgument(
1173         "Map applied function arity must match number of arguments; got: "
1174         "arity: %d, arguments: %u.",
1175         to_apply.parameters_size(), arg_shapes.size());
1176   }
1177 
1178   // The parameters should all be scalars, and the output too.
1179   const Shape& output_shape = to_apply.result();
1180   if (!ShapeUtil::IsScalar(output_shape)) {
1181     return InvalidArgument(
1182         "Mapped computation's result has to be a scalar; got: %s.",
1183         ShapeUtil::HumanString(output_shape));
1184   }
1185 
1186   for (int i = 0; i < to_apply.parameters_size(); ++i) {
1187     const Shape& parameter_shape = to_apply.parameters(i);
1188 
1189     if (!ShapeUtil::IsScalar(parameter_shape)) {
1190       return InvalidArgument(
1191           "Mapped computation's parameter has to be a scalar; "
1192           "got parameter %d shape: %s.",
1193           i, ShapeUtil::HumanString(parameter_shape));
1194     }
1195 
1196     if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
1197                                                        *arg_shape)) {
1198       return InvalidArgument(
1199           "Mapped computation's parameter type has to match argument element "
1200           "type; got parameter %d shape: %s, argument shape: %s.",
1201           i, ShapeUtil::HumanString(parameter_shape),
1202           ShapeUtil::HumanString(*arg_shape));
1203     }
1204   }
1205 
1206   return ShapeUtil::MakeShape(output_shape.element_type(),
1207                               AsInt64Slice(arg_shape->dimensions()));
1208 }
1209 
InferBatchNormTrainingShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & offset_shape,int64 feature_index)1210 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormTrainingShape(
1211     const Shape& operand_shape, const Shape& scale_shape,
1212     const Shape& offset_shape, int64 feature_index) {
1213   TF_RETURN_IF_ERROR(
1214       ExpectArray(operand_shape, "operand of batch norm training"));
1215   TF_RETURN_IF_ERROR(
1216       ExpectArray(offset_shape, "offset input of batch norm training"));
1217   TF_RETURN_IF_ERROR(
1218       ExpectArray(scale_shape, "scale input of batch norm training"));
1219 
1220   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
1221                Status::OK());
1222   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) ==
1223                Status::OK());
1224   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
1225                Status::OK());
1226 
1227   if (feature_index >= operand_shape.rank()) {
1228     return InvalidArgument(
1229         "Expected feature_index of batch-norm-training to be "
1230         "smaller than the rank of operand_shape; "
1231         "got feature_index %d, and rank %d.",
1232         feature_index, operand_shape.rank());
1233   }
1234 
1235   if (feature_index < 0) {
1236     return InvalidArgument(
1237         "Expected feature_index of batch-norm-training to "
1238         "be a non-negative number, got %d.",
1239         feature_index);
1240   }
1241 
1242   if (operand_shape.rank() < 1) {
1243     return InvalidArgument(
1244         "Expected the rank of operand to "
1245         "batch-norm-training to be at least 1; got %d.",
1246         operand_shape.rank());
1247   }
1248 
1249   if (offset_shape.rank() != 1) {
1250     return InvalidArgument(
1251         "Offset input of batch-norm-training must have"
1252         " rank 1, but has rank %d.",
1253         offset_shape.rank());
1254   }
1255 
1256   if (scale_shape.rank() != 1) {
1257     return InvalidArgument(
1258         "Scale input of batch-norm-training must have"
1259         " rank 1, but has rank %d.",
1260         scale_shape.rank());
1261   }
1262 
1263   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1264     return InvalidArgument(
1265         "The operand to batch-norm-training must have a floating point "
1266         "element type, but the shape is %s.",
1267         PrimitiveType_Name(operand_shape.element_type()));
1268   }
1269 
1270   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1271                                                      operand_shape)) {
1272     return InvalidArgument(
1273         "The inputs should have the same element type for batch-norm-training, "
1274         "but the shape of offset factor is %s "
1275         "and the shape of operand is %s.",
1276         PrimitiveType_Name(offset_shape.element_type()),
1277         PrimitiveType_Name(operand_shape.element_type()));
1278   }
1279 
1280   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1281                                                      operand_shape)) {
1282     return InvalidArgument(
1283         "The inputs should have the same element type for batch-norm-training, "
1284         "but the shape of scale factor is %s "
1285         "and the shape of operand is %s.",
1286         PrimitiveType_Name(scale_shape.element_type()),
1287         PrimitiveType_Name(operand_shape.element_type()));
1288   }
1289 
1290   const int64 feature_count = operand_shape.dimensions(feature_index);
1291   Shape output_shape_for_mean_and_var =
1292       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1293 
1294   if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1295     return InvalidArgument(
1296         "The size of offset factor should be the same as feature count,"
1297         "but the size of offset factor is %d "
1298         "and the feature count is %d.",
1299         ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1300   }
1301 
1302   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1303     return InvalidArgument(
1304         "The size of scale factor should be the same as feature count,"
1305         "but the size of scale factor is %d "
1306         "and the feature count is %d.",
1307         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1308   }
1309 
1310   return ShapeUtil::MakeTupleShape({operand_shape,
1311                                     output_shape_for_mean_and_var,
1312                                     output_shape_for_mean_and_var});
1313 }
1314 
InferBatchNormInferenceShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & offset_shape,const Shape & mean_shape,const Shape & variance_shape,int64 feature_index)1315 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
1316     const Shape& operand_shape, const Shape& scale_shape,
1317     const Shape& offset_shape, const Shape& mean_shape,
1318     const Shape& variance_shape, int64 feature_index) {
1319   TF_RETURN_IF_ERROR(
1320       ExpectArray(operand_shape, "operand of batch norm inference"));
1321   TF_RETURN_IF_ERROR(
1322       ExpectArray(offset_shape, "offset input of batch norm inference"));
1323   TF_RETURN_IF_ERROR(
1324       ExpectArray(scale_shape, "scale input of batch norm inference"));
1325 
1326   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1327   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape));
1328   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1329   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1330   TF_RETURN_IF_ERROR(
1331       ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape));
1332 
1333   if (feature_index >= operand_shape.rank()) {
1334     return InvalidArgument(
1335         "Expected feature_index of batch-norm-inference to be "
1336         "smaller than the rank of operand_shape; "
1337         "got feature_index %d, and rank %d.",
1338         feature_index, operand_shape.rank());
1339   }
1340 
1341   if (feature_index < 0) {
1342     return InvalidArgument(
1343         "Expected feature_index of batch-norm-inference to "
1344         "be a non-negative number, got %d.",
1345         feature_index);
1346   }
1347 
1348   if (operand_shape.rank() < 1) {
1349     return InvalidArgument(
1350         "Expected the rank of operand to "
1351         "batch-norm-inference to be at least 1; got %d.",
1352         operand_shape.rank());
1353   }
1354 
1355   if (offset_shape.rank() != 1) {
1356     return InvalidArgument(
1357         "Offset input of batch-norm-inference must have"
1358         " rank 1, but has rank %d.",
1359         offset_shape.rank());
1360   }
1361 
1362   if (scale_shape.rank() != 1) {
1363     return InvalidArgument(
1364         "Scale input of batch-norm-inference must have"
1365         " rank 1, but has rank %d.",
1366         scale_shape.rank());
1367   }
1368 
1369   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1370     return InvalidArgument(
1371         "The operand to batch-norm-inference must have a floating point "
1372         "element type, but the shape is %s.",
1373         PrimitiveType_Name(operand_shape.element_type()));
1374   }
1375 
1376   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1377                                                      operand_shape)) {
1378     return InvalidArgument(
1379         "The inputs should have the same element type for "
1380         "batch-norm-inference, "
1381         "but the shape of offset factor is %s "
1382         "and the shape of operand is %s.",
1383         PrimitiveType_Name(offset_shape.element_type()),
1384         PrimitiveType_Name(operand_shape.element_type()));
1385   }
1386 
1387   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1388                                                      operand_shape)) {
1389     return InvalidArgument(
1390         "The inputs should have the same element type for "
1391         "batch-norm-inference, "
1392         "but the shape of scale factor is %s "
1393         "and the shape of operand is %s.",
1394         PrimitiveType_Name(scale_shape.element_type()),
1395         PrimitiveType_Name(operand_shape.element_type()));
1396   }
1397 
1398   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1399                                                      operand_shape)) {
1400     return InvalidArgument(
1401         "The inputs should have the same element type for "
1402         "batch-norm-inference, "
1403         "but the shape of mean is %s "
1404         "and the shape of operand is %s.",
1405         PrimitiveType_Name(mean_shape.element_type()),
1406         PrimitiveType_Name(operand_shape.element_type()));
1407   }
1408 
1409   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
1410                                                      operand_shape)) {
1411     return InvalidArgument(
1412         "The inputs should have the same element type for "
1413         "batch-norm-inference, "
1414         "but the shape of variance is %s "
1415         "and the shape of operand is %s.",
1416         PrimitiveType_Name(mean_shape.element_type()),
1417         PrimitiveType_Name(variance_shape.element_type()));
1418   }
1419 
1420   const int64 feature_count = operand_shape.dimensions(feature_index);
1421   Shape output_shape_for_mean_and_var =
1422       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1423 
1424   if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1425     return InvalidArgument(
1426         "The size of offset factor should be the same as feature count,"
1427         "but the size of offset factor is %d "
1428         "and the feature count is %d.",
1429         ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1430   }
1431 
1432   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1433     return InvalidArgument(
1434         "The size of scale factor should be the same as feature count,"
1435         "but the size of scale factor is %d "
1436         "and the feature count is %d.",
1437         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1438   }
1439 
1440   if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1441     return InvalidArgument(
1442         "The size of mean should be the same as feature count,"
1443         "but the size of mean is %d "
1444         "and the feature count is %d.",
1445         ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1446   }
1447 
1448   if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
1449     return InvalidArgument(
1450         "The size of variance should be the same as feature count,"
1451         "but the size of variance is %d "
1452         "and the feature count is %d.",
1453         ShapeUtil::GetDimension(variance_shape, 0), feature_count);
1454   }
1455 
1456   return operand_shape;
1457 }
1458 
InferBatchNormGradShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & mean_shape,const Shape & var_shape,const Shape & output_grad_shape,int64 feature_index)1459 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
1460     const Shape& operand_shape, const Shape& scale_shape,
1461     const Shape& mean_shape, const Shape& var_shape,
1462     const Shape& output_grad_shape, int64 feature_index) {
1463   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad"));
1464   TF_RETURN_IF_ERROR(
1465       ExpectArray(scale_shape, "scale input of batch norm grad"));
1466   TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad"));
1467   TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad"));
1468   TF_RETURN_IF_ERROR(
1469       ExpectArray(output_grad_shape, "output_grad input of batch norm grad"));
1470 
1471   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1472   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1473   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1474   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(var_shape));
1475   TF_RETURN_IF_ERROR(
1476       ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape));
1477 
1478   if (feature_index >= operand_shape.rank()) {
1479     return InvalidArgument(
1480         "Expected feature_index of batch-norm-grad to be "
1481         "smaller than the rank of operand_shape; "
1482         "got feature_index %d, and rank %d.",
1483         feature_index, operand_shape.rank());
1484   }
1485 
1486   if (operand_shape.rank() != output_grad_shape.rank()) {
1487     return InvalidArgument(
1488         "Expected operand_shape of batch-norm-grad to have the same rank as"
1489         " output_grad_shape; got rank(oprand_shape) %d, and"
1490         " rank(output_grad_shape) %d.",
1491         operand_shape.rank(), output_grad_shape.rank());
1492   }
1493 
1494   if (mean_shape.rank() != 1) {
1495     return InvalidArgument(
1496         "Mean input of batch-norm-grad must have"
1497         " rank 1, but has rank %d.",
1498         mean_shape.rank());
1499   }
1500 
1501   if (scale_shape.rank() != 1) {
1502     return InvalidArgument(
1503         "Scale input of batch-norm-grad must have"
1504         " rank 1, but has rank %d.",
1505         scale_shape.rank());
1506   }
1507 
1508   if (var_shape.rank() != 1) {
1509     return InvalidArgument(
1510         "Var input of batch-norm-grad must have"
1511         " rank 1, but has rank %d.",
1512         var_shape.rank());
1513   }
1514 
1515   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1516     return InvalidArgument(
1517         "The operand to batch-norm-grad must have a floating point "
1518         "element type, but the shape is %s.",
1519         PrimitiveType_Name(operand_shape.element_type()));
1520   }
1521 
1522   if (!ShapeUtil::ElementIsFloating(output_grad_shape)) {
1523     return InvalidArgument(
1524         "The output_grad to batch-norm-grad must have a floating point "
1525         "element type, but the shape is %s.",
1526         PrimitiveType_Name(output_grad_shape.element_type()));
1527   }
1528 
1529   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
1530                                                      operand_shape)) {
1531     return InvalidArgument(
1532         "The inputs should have the same element type for batch-norm-grad, "
1533         "but the element type of output_grad is %s "
1534         "and the element type of operand is %s.",
1535         PrimitiveType_Name(output_grad_shape.element_type()),
1536         PrimitiveType_Name(operand_shape.element_type()));
1537   }
1538 
1539   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1540                                                      operand_shape)) {
1541     return InvalidArgument(
1542         "The inputs should have the same element type for batch-norm-grad, "
1543         "but the element type of scale factor is %s "
1544         "and the element type of operand is %s.",
1545         PrimitiveType_Name(scale_shape.element_type()),
1546         PrimitiveType_Name(operand_shape.element_type()));
1547   }
1548 
1549   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1550                                                      operand_shape)) {
1551     return InvalidArgument(
1552         "The inputs should have the same element type for batch-norm-grad, "
1553         "but the element type of mean is %s "
1554         "and the element type of operand is %s.",
1555         PrimitiveType_Name(mean_shape.element_type()),
1556         PrimitiveType_Name(operand_shape.element_type()));
1557   }
1558 
1559   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
1560                                                      operand_shape)) {
1561     return InvalidArgument(
1562         "The inputs should have the same element type for batch-norm-grad, "
1563         "but the element type of mean is %s "
1564         "and the element type of operand is %s.",
1565         PrimitiveType_Name(mean_shape.element_type()),
1566         PrimitiveType_Name(operand_shape.element_type()));
1567   }
1568 
1569   const int64 feature_count = operand_shape.dimensions(feature_index);
1570 
1571   Shape feature_shape =
1572       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1573 
1574   if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1575     return InvalidArgument(
1576         "The size of mean should be the same as feature count,"
1577         "but the size of offset factor is %d "
1578         "and the feature count is %d.",
1579         ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1580   }
1581 
1582   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1583     return InvalidArgument(
1584         "The size of scale factor should be the same as feature count,"
1585         "but the size of scale factor is %d "
1586         "and the feature count is %d.",
1587         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1588   }
1589 
1590   if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) {
1591     return InvalidArgument(
1592         "The size of variance should be the same as feature count,"
1593         "but the size of variance is %d "
1594         "and the feature count is %d.",
1595         ShapeUtil::GetDimension(var_shape, 0), feature_count);
1596   }
1597 
1598   // Verify operand_shape and output_grad_shape have same bounds.
1599   for (int64 i = 0; i < operand_shape.rank(); ++i) {
1600     if (ShapeUtil::GetDimension(operand_shape, i) !=
1601         ShapeUtil::GetDimension(output_grad_shape, i)) {
1602       return InvalidArgument(
1603           "The bounds of operand shape should be the same as output_grad's,"
1604           "but the bound of operand_shape at dimension %d is %d "
1605           "and the bound of output_grad_shape is %d.",
1606           i, ShapeUtil::GetDimension(operand_shape, i),
1607           ShapeUtil::GetDimension(output_grad_shape, i));
1608     }
1609   }
1610 
1611   return ShapeUtil::MakeTupleShape(
1612       {operand_shape, feature_shape, feature_shape});
1613 }
1614 
InferConvolveShape(const Shape & lhs,const Shape & rhs,int64 feature_group_count,int64 batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dnums,absl::optional<PrimitiveType> preferred_element_type)1615 /* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
1616     const Shape& lhs, const Shape& rhs, int64 feature_group_count,
1617     int64 batch_group_count, const Window& window,
1618     const ConvolutionDimensionNumbers& dnums,
1619     absl::optional<PrimitiveType> preferred_element_type) {
1620   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
1621   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
1622 
1623   if (feature_group_count <= 0) {
1624     return InvalidArgument(
1625         "feature_group_count must be a positive number, got %d",
1626         feature_group_count);
1627   }
1628 
1629   if (batch_group_count <= 0) {
1630     return InvalidArgument(
1631         "batch_group_count must be a positive number, got %d",
1632         batch_group_count);
1633   }
1634 
1635   if (batch_group_count > 1 && feature_group_count > 1) {
1636     return InvalidArgument(
1637         "both batch_group_count %d and feature_group_count %d cannot be "
1638         "greater than 1",
1639         batch_group_count, feature_group_count);
1640   }
1641 
1642   if (dnums.input_spatial_dimensions_size() !=
1643       dnums.kernel_spatial_dimensions_size()) {
1644     return InvalidArgument(
1645         "Both arguments to convolution must have same number of dimensions.\n"
1646         "Numbers: %s",
1647         dnums.DebugString());
1648   }
1649 
1650   if (dnums.input_spatial_dimensions_size() !=
1651       dnums.output_spatial_dimensions_size()) {
1652     return InvalidArgument(
1653         "Both input and output of convolution must have same number of "
1654         "dimensions.\nNumbers: %s",
1655         dnums.DebugString());
1656   }
1657 
1658   const int num_spatial_dims = dnums.input_spatial_dimensions_size();
1659   if (window.dimensions_size() != num_spatial_dims) {
1660     return InvalidArgument(
1661         "Window must have same number of dimensions as dimension numbers.\n"
1662         "Window: %s\nDimension numbers: %s.",
1663         window.DebugString(), dnums.DebugString());
1664   }
1665 
1666   const int num_dims = num_spatial_dims + 2;
1667   if (lhs.rank() != num_dims) {
1668     return InvalidArgument(
1669         "The LHS argument to a convolution should have rank %d; lhs: %s.",
1670         num_dims, ShapeUtil::HumanString(lhs));
1671   }
1672   if (rhs.rank() != num_dims) {
1673     return InvalidArgument(
1674         "The RHS argument to a convolution should have rank %d; rhs: %s.",
1675         num_dims, ShapeUtil::HumanString(rhs));
1676   }
1677   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1678   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1679 
1680   // Verifies that the input and window dimensions are a permutation of
1681   // the dimension numbers.
1682   std::vector<int64> input_dnums(num_dims);
1683   input_dnums[0] = dnums.input_batch_dimension();
1684   input_dnums[1] = dnums.input_feature_dimension();
1685   std::copy(dnums.input_spatial_dimensions().begin(),
1686             dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
1687   absl::c_sort(input_dnums);
1688 
1689   std::vector<int64> window_dnums(num_dims);
1690   window_dnums[0] = dnums.kernel_input_feature_dimension();
1691   window_dnums[1] = dnums.kernel_output_feature_dimension();
1692   std::copy(dnums.kernel_spatial_dimensions().begin(),
1693             dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
1694   absl::c_sort(window_dnums);
1695 
1696   std::vector<int64> output_dnums(num_dims);
1697   output_dnums[0] = dnums.output_batch_dimension();
1698   output_dnums[1] = dnums.output_feature_dimension();
1699   std::copy(dnums.output_spatial_dimensions().begin(),
1700             dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
1701   absl::c_sort(output_dnums);
1702 
1703   std::vector<int64> expected_dnums(num_dims);
1704   std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
1705 
1706   const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
1707   if (!absl::c_all_of(input_dnums, in_range) ||
1708       !absl::c_all_of(window_dnums, in_range) ||
1709       !absl::c_all_of(output_dnums, in_range)) {
1710     return InvalidArgument(
1711         "A dimension number is out of range in convolution: %s.",
1712         dnums.DebugString());
1713   }
1714 
1715   if (input_dnums != expected_dnums) {
1716     return InvalidArgument(
1717         "Input dimensions of convolution must contain each dimension exactly "
1718         "once: %s.",
1719         dnums.DebugString());
1720   }
1721   if (window_dnums != expected_dnums) {
1722     return InvalidArgument(
1723         "Window dimensions of convolution must contain each dimension exactly "
1724         "once: %s.",
1725         dnums.DebugString());
1726   }
1727   if (output_dnums != expected_dnums) {
1728     return InvalidArgument(
1729         "Output dimensions of convolution must contain each dimension exactly "
1730         "once: %s.",
1731         dnums.DebugString());
1732   }
1733 
1734   std::vector<int64> input_spatial_dims(num_spatial_dims);
1735   for (int i = 0; i < num_spatial_dims; ++i) {
1736     input_spatial_dims[i] = lhs.dimensions(dnums.input_spatial_dimensions(i));
1737   }
1738   const int64 input_features = lhs.dimensions(dnums.input_feature_dimension());
1739   const int64 input_batch = lhs.dimensions(dnums.input_batch_dimension());
1740 
1741   std::vector<int64> kernel_spatial_dims(num_spatial_dims);
1742   for (int i = 0; i < num_spatial_dims; ++i) {
1743     kernel_spatial_dims[i] = rhs.dimensions(dnums.kernel_spatial_dimensions(i));
1744   }
1745   const int64 kernel_input_features =
1746       rhs.dimensions(dnums.kernel_input_feature_dimension());
1747   const int64 kernel_output_features =
1748       rhs.dimensions(dnums.kernel_output_feature_dimension());
1749 
1750   if (kernel_output_features % batch_group_count != 0) {
1751     return InvalidArgument(
1752         "Expected output feature dimension size (value %d) to be a multiple of "
1753         "batch group count %d; got <conv>(%s, %s)\n"
1754         "Dimension numbers: {%s}.",
1755         kernel_output_features, batch_group_count, ShapeUtil::HumanString(lhs),
1756         ShapeUtil::HumanString(rhs), dnums.DebugString());
1757   }
1758 
1759   if (input_features % feature_group_count != 0 ||
1760       input_features / feature_group_count != kernel_input_features) {
1761     return InvalidArgument(
1762         "Expected LHS feature dimension (value %d) to be a multiple of "
1763         "feature_group_count (value %d), and LHS feature dimension / "
1764         "feature_group_count = RHS feature dimension (value %d); "
1765         "got <conv>(%s, %s)\n"
1766         "Dimension numbers: {%s}.",
1767         input_features, feature_group_count, kernel_input_features,
1768         ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
1769         dnums.DebugString());
1770   }
1771 
1772   if (kernel_output_features % feature_group_count > 0) {
1773     // A depthwise/grouped filter has the shape
1774     // [space0, .. spaceN, GROUP_SIZE, NUM_OUTPUT_FEATURES]. When
1775     // [space0, .. spaceN, GROUP_SIZE] is convolved with the input, a shape
1776     // [space0, .. spaceN, feature_group_count] is formed. Therefore, the output
1777     // feature count (which is equal to kernel output features) has to be a
1778     // multiple of feature_group_count.
1779     return InvalidArgument(
1780         "Expected output feature dimension (value %d) to be divisible by "
1781         "feature_group_count (value %d); "
1782         "got <conv>(%s, %s)\n"
1783         "Dimension numbers: {%s}.",
1784         kernel_output_features, feature_group_count,
1785         ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
1786         dnums.DebugString());
1787   }
1788 
1789   if (input_batch % batch_group_count != 0) {
1790     return InvalidArgument(
1791         "Expected input batch dimension (value %d) to be divisible by "
1792         "batch_group_count (value %d); "
1793         "got <conv>(%s, %s)\n"
1794         "Dimension numbers: {%s}.",
1795         input_batch, batch_group_count, ShapeUtil::HumanString(lhs),
1796         ShapeUtil::HumanString(rhs), dnums.DebugString());
1797   }
1798 
1799   std::vector<int64> window_dims(num_spatial_dims);
1800   for (int i = 0; i < num_spatial_dims; ++i) {
1801     window_dims[i] = window.dimensions(i).size();
1802   }
1803   if (kernel_spatial_dims != window_dims) {
1804     return InvalidArgument(
1805         "Window dimensions do not match RHS shape:\n\t"
1806         "RHS shape: %s\n\t"
1807         "Window: {%s}\n\t"
1808         "Dimension numbers: {%s}.",
1809         ShapeUtil::HumanString(rhs), window.ShortDebugString(),
1810         dnums.ShortDebugString());
1811   }
1812 
1813   Shape base_shape =
1814       ShapeUtil::MakeShape(lhs.element_type(), input_spatial_dims);
1815   TF_ASSIGN_OR_RETURN(
1816       Shape window_output_shape,
1817       InferWindowOutputShape(base_shape, window, lhs.element_type(),
1818                              /*allow_negative_padding=*/true));
1819 
1820   std::vector<int64> dimensions(num_dims);
1821   dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count;
1822   dimensions[dnums.output_feature_dimension()] = kernel_output_features;
1823 
1824   for (int i = 0; i < num_spatial_dims; ++i) {
1825     dimensions[dnums.output_spatial_dimensions(i)] =
1826         window_output_shape.dimensions(i);
1827   }
1828   std::vector<bool> is_dynamic(num_dims);
1829   for (int i = 0; i < num_dims; i++) {
1830     if (lhs.is_dynamic_dimension(i)) {
1831       if (i == dnums.input_batch_dimension()) {
1832         is_dynamic[dnums.output_batch_dimension()] = true;
1833       } else if (i == dnums.input_feature_dimension()) {
1834         // Input feature dimension is a contracting dimension, which does not
1835         // affect the output dimension size. So we need to do nothing.
1836       } else {
1837         for (int64 j = 0; j < dnums.output_spatial_dimensions_size(); ++j) {
1838           if (i == dnums.input_spatial_dimensions(j)) {
1839             // i is a spatial dimension, find corresponding output spatial
1840             // dimension.
1841             is_dynamic[dnums.output_spatial_dimensions(j)] = true;
1842           }
1843         }
1844       }
1845     }
1846     if (rhs.is_dynamic_dimension(i)) {
1847       if (i == dnums.kernel_input_feature_dimension()) {
1848         // Kernel feature dimension does not affect the output dimension size.
1849         // So we need to do nothing.
1850       } else if (i == dnums.kernel_output_feature_dimension()) {
1851         return InvalidArgument(
1852             "Dynamic output feature dim on convolution kernel is not "
1853             "supported: rhs shape is %s ",
1854             rhs.ToString());
1855       } else {
1856         for (int64 j = 0; j < dnums.kernel_spatial_dimensions_size(); ++j) {
1857           if (i == dnums.kernel_spatial_dimensions(j)) {
1858             // i is a spatial dimension, find corresponding output spatial
1859             // dimension.
1860             is_dynamic[dnums.output_spatial_dimensions(j)] = true;
1861           }
1862         }
1863       }
1864     }
1865   }
1866   TF_ASSIGN_OR_RETURN(
1867       PrimitiveType type,
1868       MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
1869                   preferred_element_type));
1870   return ShapeUtil::MakeShape(type, dimensions, is_dynamic);
1871 }
1872 
InferFftShape(const Shape & in,const FftType fft_type,const absl::Span<const int64> fft_length)1873 /* static */ StatusOr<Shape> ShapeInference::InferFftShape(
1874     const Shape& in, const FftType fft_type,
1875     const absl::Span<const int64> fft_length) {
1876   const int64 fft_rank = fft_length.size();
1877   if (fft_rank < 1 || fft_rank > 3) {
1878     return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank);
1879   }
1880 #define RET_CHECK_RANK(x)                            \
1881   if (x.dimensions_size() < fft_rank) {              \
1882     return InvalidArgument(                          \
1883         "FFT of rank %d requires input of at least " \
1884         "same rank; got input of rank %d",           \
1885         fft_rank, x.dimensions_size());              \
1886   }
1887   switch (fft_type) {
1888     case FFT:
1889     case IFFT:
1890       if (!primitive_util::IsComplexType(in.element_type())) {
1891         return InvalidArgument("%s requires complex input type, found %s.",
1892                                FftType_Name(fft_type),
1893                                PrimitiveType_Name(in.element_type()));
1894       }
1895       RET_CHECK_RANK(in);
1896       return in;
1897     case RFFT: {
1898       if (in.element_type() != F32 && in.element_type() != F64) {
1899         return InvalidArgument("RFFT requires F32 or F64 input type, found %s.",
1900                                PrimitiveType_Name(in.element_type()));
1901       }
1902       RET_CHECK_RANK(in);
1903       for (int i = 0; i < fft_rank; i++) {
1904         if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1905             fft_length[i]) {
1906           return InvalidArgument(
1907               "RFFT requires innermost dimensions match fft_length but "
1908               "dimension %d is %d and should be %d.",
1909               in.dimensions_size() - fft_rank + i,
1910               in.dimensions(in.dimensions_size() - fft_rank + i),
1911               fft_length[i]);
1912         }
1913       }
1914       Shape result = ShapeUtil::ChangeElementType(
1915           in, in.element_type() == F32 ? C64 : C128);
1916       // Preserve the size of zero-sized dimensions.
1917       if (fft_length[fft_rank - 1] != 0) {
1918         result.set_dimensions(result.dimensions_size() - 1,
1919                               fft_length[fft_rank - 1] / 2 + 1);
1920       }
1921       return result;
1922     }
1923     case IRFFT: {
1924       if (!primitive_util::IsComplexType(in.element_type())) {
1925         return InvalidArgument("IRFFT requires complex input type, found %s.",
1926                                PrimitiveType_Name(in.element_type()));
1927       }
1928       RET_CHECK_RANK(in);
1929       Shape result = ShapeUtil::ComplexComponentShape(in);
1930       for (int i = 0; i < fft_rank - 1; i++) {
1931         if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1932             fft_length[i]) {
1933           return InvalidArgument(
1934               "IRFFT requires all but one innermost dimensions match "
1935               "fft_length, but dimension %d is %d and should be %d.",
1936               in.dimensions_size() - fft_rank + i,
1937               in.dimensions(in.dimensions_size() - fft_rank + i),
1938               fft_length[i]);
1939         }
1940       }
1941       // The size of zero-sized dimensions is preserved.
1942       if ((in.dimensions(in.dimensions_size() - 1) != 0 ||
1943            fft_length[fft_rank - 1] != 0) &&
1944           in.dimensions(in.dimensions_size() - 1) !=
1945               fft_length[fft_rank - 1] / 2 + 1) {
1946         return InvalidArgument(
1947             "IRFFT requires innermost dimension matches fft_length/2+1, but "
1948             "dimension %d is %d and should be %d.",
1949             in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1),
1950             fft_length[fft_rank - 1] / 2 + 1);
1951       }
1952       result.set_dimensions(result.dimensions_size() - 1,
1953                             fft_length[fft_rank - 1]);
1954       return result;
1955     }
1956     default:
1957       LOG(FATAL) << "Unexpected fft_type: " << fft_type;
1958   }
1959 #undef RET_CHECK_RANK
1960 }
1961 
InferTriangularSolveShape(const Shape & a,const Shape & b,const TriangularSolveOptions & options)1962 /* static */ StatusOr<Shape> ShapeInference::InferTriangularSolveShape(
1963     const Shape& a, const Shape& b, const TriangularSolveOptions& options) {
1964   if ((!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) ||
1965       a.element_type() != b.element_type()) {
1966     return InvalidArgument(
1967         "Expected element types in shape to be floating or complex and "
1968         "identical for TriangularSolve; got %s and %s.",
1969         PrimitiveType_Name(a.element_type()),
1970         PrimitiveType_Name(b.element_type()));
1971   }
1972   if (a.rank() < 2) {
1973     return InvalidArgument(
1974         "The 'a' argument to TriangularSolve must have rank >= 2, got shape %s",
1975         a.ToString());
1976   }
1977   if (b.rank() != a.rank()) {
1978     return InvalidArgument(
1979         "Arguments to triangular solve must have equal rank; got %s and %s.",
1980         b.ToString(), a.ToString());
1981   }
1982   if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
1983     return InvalidArgument(
1984         "The two minor dimensions of 'a' must have equal size, got %s.",
1985         a.ToString());
1986   }
1987   if (a.dimensions(a.rank() - 1) !=
1988       b.dimensions(b.rank() - (options.left_side() ? 2 : 1))) {
1989     return InvalidArgument(
1990         "The shared dimension of 'a' and 'b' does not match, got shapes %s and "
1991         "%s",
1992         a.ToString(), b.ToString());
1993   }
1994   absl::Span<const int64> a_batch_dims(a.dimensions());
1995   absl::Span<const int64> b_batch_dims(b.dimensions());
1996   a_batch_dims.remove_suffix(2);
1997   b_batch_dims.remove_suffix(2);
1998   if (a_batch_dims != b_batch_dims) {
1999     return InvalidArgument(
2000         "The leading batch dimensions of the arguments to triangular solve "
2001         "must be equal; got %s and %s.",
2002         b.ToString(), a.ToString());
2003   }
2004   if (!TriangularSolveOptions_Transpose_IsValid(options.transpose_a()) ||
2005       options.transpose_a() == TriangularSolveOptions::TRANSPOSE_INVALID) {
2006     return InvalidArgument(
2007         "Invalid transpose option value for triangular solve (%d).\n",
2008         options.transpose_a());
2009   }
2010   return b;
2011 }
2012 
InferCholeskyShape(const Shape & a)2013 /* static */ StatusOr<Shape> ShapeInference::InferCholeskyShape(
2014     const Shape& a) {
2015   if (!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) {
2016     return InvalidArgument(
2017         "Expected element type in shape to be floating or complex for "
2018         "Cholesky; got %s.",
2019         PrimitiveType_Name(a.element_type()));
2020   }
2021   if (a.rank() < 2) {
2022     return InvalidArgument(
2023         "The 'a' argument to Cholesky must have rank >= 2, got shape %s",
2024         a.ToString());
2025   }
2026   if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
2027     return InvalidArgument(
2028         "The two minor dimensions of 'a' must have equal size, got %s.",
2029         a.ToString());
2030   }
2031   return a;
2032 }
2033 
InferAllGatherShape(const Shape & operand_shape,int64 all_gather_dimension,int64 shard_count)2034 /* static */ StatusOr<Shape> ShapeInference::InferAllGatherShape(
2035     const Shape& operand_shape, int64 all_gather_dimension, int64 shard_count) {
2036   TF_RET_CHECK(all_gather_dimension >= 0);
2037   TF_RET_CHECK(all_gather_dimension < operand_shape.rank());
2038   TF_RET_CHECK(shard_count > 0);
2039   auto shape = operand_shape;
2040   shape.set_dimensions(all_gather_dimension,
2041                        shard_count * shape.dimensions(all_gather_dimension));
2042   return shape;
2043 }
2044 
InferAllReduceShape(absl::Span<const Shape * const> operand_shapes)2045 /* static */ StatusOr<Shape> ShapeInference::InferAllReduceShape(
2046     absl::Span<const Shape* const> operand_shapes) {
2047   for (const Shape* operand_shape : operand_shapes) {
2048     TF_RETURN_IF_ERROR(
2049         ExpectArray(*operand_shape, "operand of cross replica sum"));
2050   }
2051   if (operand_shapes.size() == 1) {
2052     return *operand_shapes[0];
2053   }
2054   std::vector<Shape> operand_shape_values;
2055   for (const Shape* operand_shape : operand_shapes) {
2056     operand_shape_values.push_back(*operand_shape);
2057   }
2058   return ShapeUtil::MakeTupleShape(operand_shape_values);
2059 }
2060 
InferAllToAllShape(const Shape & shape,int64 split_dimension,int64 concat_dimension,int64 split_count)2061 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllShape(
2062     const Shape& shape, int64 split_dimension, int64 concat_dimension,
2063     int64 split_count) {
2064   TF_RET_CHECK(split_count > 0);
2065   if (split_dimension >= shape.rank() || split_dimension < 0) {
2066     return InvalidArgument(
2067         "AllToAll split_dimension %d is out-of-bounds in shape %s.",
2068         split_dimension, ShapeUtil::HumanString(shape));
2069   }
2070   if (concat_dimension >= shape.rank() || concat_dimension < 0) {
2071     return InvalidArgument(
2072         "AllToAll concat_dimension %d is out-of-bounds in shape %s.",
2073         concat_dimension, ShapeUtil::HumanString(shape));
2074   }
2075   if (shape.dimensions(split_dimension) % split_count != 0) {
2076     return InvalidArgument(
2077         "AllToAll split dimension size %d must be dividable by split_count "
2078         "%d.",
2079         shape.dimensions(split_dimension), split_count);
2080   }
2081   std::vector<int64> new_dimensions(shape.dimensions().begin(),
2082                                     shape.dimensions().end());
2083   new_dimensions[split_dimension] /= split_count;
2084   new_dimensions[concat_dimension] *= split_count;
2085   return ShapeUtil::MakeShape(shape.element_type(), new_dimensions);
2086 }
2087 
InferAllToAllTupleShape(absl::Span<const Shape * const> operand_shapes)2088 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
2089     absl::Span<const Shape* const> operand_shapes) {
2090   // An Alltoall HLO instruction receives N operands (with the same shape) and
2091   // returns a tuple that contains N array shapes.
2092   TF_RET_CHECK(!operand_shapes.empty());
2093   for (int i = 0; i < operand_shapes.size(); i++) {
2094     if (!Shape::Equal().IgnoreMemorySpaceInLayout()(*operand_shapes[0],
2095                                                     *operand_shapes[i])) {
2096       return InvalidArgument(
2097           "HLO all-to-all has operands with different shapes: the 0th "
2098           "operand shape %s, but the %dth operand has shape %s.",
2099           ShapeUtil::HumanString(*operand_shapes[0]), i,
2100           ShapeUtil::HumanString(*operand_shapes[i]));
2101     }
2102   }
2103 
2104   return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes);
2105 }
2106 
InferCollectivePermuteShape(const Shape & shape)2107 /* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteShape(
2108     const Shape& shape) {
2109   TF_RET_CHECK(shape.IsArray());
2110   return shape;
2111 }
2112 
InferReduceShape(absl::Span<const Shape * const> arg_shapes,absl::Span<const int64> dimensions_to_reduce,const ProgramShape & to_apply)2113 /* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
2114     absl::Span<const Shape* const> arg_shapes,
2115     absl::Span<const int64> dimensions_to_reduce,
2116     const ProgramShape& to_apply) {
2117   if (arg_shapes.empty()) {
2118     return InvalidArgument("Reduce must have at least 2 arguments, has 0");
2119   }
2120   if (arg_shapes.size() % 2) {
2121     return InvalidArgument(
2122         "Reduce must have an even number of arguments, has %lu",
2123         arg_shapes.size());
2124   }
2125   int64 num_reduced_args = arg_shapes.size() / 2;
2126   auto reduced_args = arg_shapes.subspan(0, num_reduced_args);
2127   // Check that all of the reduced tensors have the same dimensions. The element
2128   // types may be different.
2129   for (int64 i = 1; i < num_reduced_args; ++i) {
2130     if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) {
2131       return InvalidArgument(
2132           "All reduced tensors must have the same dimension. Tensor 0 has "
2133           "shape %s, Tensor %d has shape %s",
2134           ShapeUtil::HumanString(*reduced_args[0]), i,
2135           ShapeUtil::HumanString(*reduced_args[i]));
2136     }
2137   }
2138   // Check that the dimensions to reduce are in-bounds for the given shape.
2139   // We've already verified all reduced tensors have the same dimensions, so it
2140   // doesn't matter which one we choose.
2141   const Shape& arg = *reduced_args[0];
2142   for (int64 dimension : dimensions_to_reduce) {
2143     if (dimension >= arg.rank() || dimension < 0) {
2144       return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.",
2145                              dimension, ShapeUtil::HumanString(arg));
2146     }
2147   }
2148 
2149   auto init_values = arg_shapes.subspan(num_reduced_args, arg_shapes.size());
2150   std::vector<PrimitiveType> element_types;
2151   for (const Shape* arg : reduced_args) {
2152     element_types.push_back(arg->element_type());
2153   }
2154   TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types,
2155                                         num_reduced_args));
2156 
2157   absl::flat_hash_set<int64> dimensions_to_reduce_set;
2158   for (int64 dim_to_reduce : dimensions_to_reduce) {
2159     if (!dimensions_to_reduce_set.insert(dim_to_reduce).second) {
2160       return InvalidArgument("Duplicate reduction dimension: %d",
2161                              dim_to_reduce);
2162     }
2163   }
2164 
2165   std::vector<int64> new_dimensions;
2166   std::vector<bool> new_is_dynamic;
2167   for (int i = 0; i < arg.rank(); ++i) {
2168     if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) {
2169       new_dimensions.push_back(arg.dimensions(i));
2170       new_is_dynamic.push_back(arg.is_dynamic_dimension(i));
2171     }
2172   }
2173 
2174   if (ShapeUtil::IsScalar(to_apply.result())) {
2175     return ShapeUtil::MakeShape(to_apply.result().element_type(),
2176                                 new_dimensions, new_is_dynamic);
2177   } else {
2178     std::vector<Shape> result_subshapes;
2179     for (const Shape& subshape : to_apply.result().tuple_shapes()) {
2180       result_subshapes.push_back(ShapeUtil::MakeShape(
2181           subshape.element_type(), new_dimensions, new_is_dynamic));
2182     }
2183     return ShapeUtil::MakeTupleShape(result_subshapes);
2184   }
2185 }
2186 
InferReduceWindowShape(const Shape & operand_shape,const Shape & init_value_shape,const Window & window,const ProgramShape & to_apply_shape)2187 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2188     const Shape& operand_shape, const Shape& init_value_shape,
2189     const Window& window, const ProgramShape& to_apply_shape) {
2190   TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
2191                                         {operand_shape.element_type()},
2192                                         /*inputs=*/1));
2193   return InferReduceWindowShape(operand_shape, init_value_shape, window);
2194 }
2195 
InferReduceWindowShape(absl::Span<const Shape * const> operands,absl::Span<const Shape * const> init_values,const Window & window,const ProgramShape & to_apply_shape)2196 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2197     absl::Span<const Shape* const> operands,
2198     absl::Span<const Shape* const> init_values, const Window& window,
2199     const ProgramShape& to_apply_shape) {
2200   auto number_of_input = operands.size();
2201   // Check that all of the reduced tensors have the same dimensions. The element
2202   // types may be different.
2203   for (int64 i = 1; i < number_of_input; ++i) {
2204     if (!ShapeUtil::SameDimensions(*operands[0], *operands[i])) {
2205       return InvalidArgument(
2206           "All reduced tensors must have the same dimension. Tensor 0 has "
2207           "shape %s, Tensor %d has shape %s",
2208           ShapeUtil::HumanString(*operands[0]), i,
2209           ShapeUtil::HumanString(*operands[i]));
2210     }
2211   }
2212   std::vector<PrimitiveType> operand_element_type_vec;
2213   for (const Shape* s : operands) {
2214     operand_element_type_vec.push_back(s->element_type());
2215   }
2216   TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_values,
2217                                         operand_element_type_vec,
2218                                         /*inputs=*/number_of_input));
2219   std::vector<Shape> output_shape_vec;
2220   for (int i = 0; i < operands.size(); ++i) {
2221     TF_ASSIGN_OR_RETURN(
2222         auto cur_output_shape,
2223         InferReduceWindowShape(*operands[i], *init_values[i], window));
2224     output_shape_vec.push_back(cur_output_shape);
2225   }
2226   if (ShapeUtil::IsScalar(to_apply_shape.result())) {
2227     CHECK_EQ(output_shape_vec.size(), 1);
2228     return output_shape_vec[0];
2229   } else {
2230     return ShapeUtil::MakeTupleShape(output_shape_vec);
2231   }
2232 }
2233 
InferReduceWindowShape(const Shape & operand_shape,const Shape & init_value_shape,const Window & window)2234 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2235     const Shape& operand_shape, const Shape& init_value_shape,
2236     const Window& window) {
2237   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window"));
2238   return InferWindowOutputShape(operand_shape, window,
2239                                 init_value_shape.element_type(),
2240                                 /*allow_negative_padding=*/false);
2241 }
2242 
InferSelectAndScatterShape(const Shape & operand_shape,const ProgramShape & select_shape,const Window & window,const Shape & source_shape,const Shape & init_value_shape,const ProgramShape & scatter_shape)2243 /* static */ StatusOr<Shape> ShapeInference::InferSelectAndScatterShape(
2244     const Shape& operand_shape, const ProgramShape& select_shape,
2245     const Window& window, const Shape& source_shape,
2246     const Shape& init_value_shape, const ProgramShape& scatter_shape) {
2247   TF_RETURN_IF_ERROR(
2248       ExpectArray(operand_shape, "operand of select-and-scatter"));
2249 
2250   // Check if the select function has a proper shape of (T,T) -> PRED.
2251   if (select_shape.parameters_size() != 2) {
2252     return InvalidArgument(
2253         "Select function must take 2 parameters, but "
2254         "takes %d parameter(s).",
2255         select_shape.parameters_size());
2256   }
2257   const Shape& select_result_shape = select_shape.result();
2258   if (!ShapeUtil::Compatible(select_result_shape,
2259                              ShapeUtil::MakeShape(PRED, {}))) {
2260     return InvalidArgument("Select function must have rank-0 PRED result.");
2261   }
2262   const Shape& operand_element_shape =
2263       ShapeUtil::MakeShape(operand_shape.element_type(), {});
2264   if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2265                                                 select_shape.parameters(0))) {
2266     return InvalidArgument(
2267         "Select function's first parameter shape currently must "
2268         "match the operand element shape, but got %s vs %s.",
2269         ShapeUtil::HumanString(select_shape.parameters(0)),
2270         ShapeUtil::HumanString(operand_element_shape));
2271   }
2272   if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2273                                                 select_shape.parameters(1))) {
2274     return InvalidArgument(
2275         "Select function's second parameter shape currently must "
2276         "match the operand element shape, but got %s vs %s.",
2277         ShapeUtil::HumanString(select_shape.parameters(1)),
2278         ShapeUtil::HumanString(operand_element_shape));
2279   }
2280 
2281   // Check if the scatter function has a proper shape as a reduction.
2282   TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape},
2283                                         {source_shape.element_type()},
2284                                         /*inputs=*/1));
2285 
2286   // Check if the result shape of window operation matches the source shape.
2287   TF_ASSIGN_OR_RETURN(const Shape& window_result_shape,
2288                       InferWindowOutputShape(operand_shape, window,
2289                                              operand_shape.element_type(),
2290                                              /*allow_negative_padding=*/false));
2291   if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape,
2292                                                 window_result_shape)) {
2293     return InvalidArgument(
2294         "Source shape does not match the shape of window-reduced operand: "
2295         "source(%s), window-reduced operand(%s).",
2296         ShapeUtil::HumanString(source_shape),
2297         ShapeUtil::HumanString(window_result_shape));
2298   }
2299 
2300   return operand_shape;
2301 }
2302 
InferGetDimensionSizeShape(const Shape & shape,int64 dimension)2303 /* static */ StatusOr<Shape> ShapeInference::InferGetDimensionSizeShape(
2304     const Shape& shape, int64 dimension) {
2305   if (dimension < 0 || dimension >= shape.rank()) {
2306     return InvalidArgument("GetDimensionSize dimension out of bounds: %d.",
2307                            dimension);
2308   }
2309 
2310   // TODO(b/119580730): Remove this restriction when very large dimension size
2311   // is needed.
2312   if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
2313     return InvalidArgument(
2314         "GetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2315         "INT_MAX limit.",
2316         ShapeUtil::HumanString(shape), dimension);
2317   }
2318 
2319   return ShapeUtil::MakeShape(S32, {});
2320 }
2321 
InferSetDimensionSizeShape(const Shape & shape,const Shape & val_shape,int64 dimension)2322 /* static */ StatusOr<Shape> ShapeInference::InferSetDimensionSizeShape(
2323     const Shape& shape, const Shape& val_shape, int64 dimension) {
2324   if (dimension < 0 || dimension >= shape.rank()) {
2325     return InvalidArgument("SetDimensionSize dimension out of bounds: %d.",
2326                            dimension);
2327   }
2328 
2329   if (val_shape.rank() != 0 || val_shape.element_type() != S32) {
2330     return InvalidArgument(
2331         "SetDimensionSize's value has to be S32 scalar, got %s",
2332         val_shape.ToString());
2333   }
2334   // TODO(b/119580730): Remove this restriction when very large dimension size
2335   // is needed.
2336   if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
2337     return InvalidArgument(
2338         "SetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2339         "INT_MAX limit.",
2340         ShapeUtil::HumanString(shape), dimension);
2341   }
2342 
2343   Shape result = shape;
2344   result.set_dynamic_dimension(dimension, true);
2345   return result;
2346 }
2347 
InferWindowFromDimensions(absl::Span<const int64> window_dimensions,absl::Span<const int64> window_strides,absl::Span<const std::pair<int64,int64>> padding,absl::Span<const int64> lhs_dilation,absl::Span<const int64> rhs_dilation)2348 /* static */ StatusOr<Window> ShapeInference::InferWindowFromDimensions(
2349     absl::Span<const int64> window_dimensions,
2350     absl::Span<const int64> window_strides,
2351     absl::Span<const std::pair<int64, int64>> padding,
2352     absl::Span<const int64> lhs_dilation,
2353     absl::Span<const int64> rhs_dilation) {
2354   const auto verify_size = [&](const size_t x, const char* x_name) {
2355     if (x == 0 || x == window_dimensions.size()) {
2356       return Status::OK();
2357     } else {
2358       return InvalidArgument(
2359           "%s", absl::StrCat(
2360                     "Window has different number of window dimensions than of ",
2361                     x_name,
2362                     "\nNumber of window dimensions: ", window_dimensions.size(),
2363                     "\nNumber of ", x_name, ": ", x, "\n"));
2364     }
2365   };
2366   TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides"));
2367   TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries"));
2368   TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors"));
2369   TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors"));
2370 
2371   Window window;
2372   for (size_t i = 0; i < window_dimensions.size(); i++) {
2373     auto dim = window.add_dimensions();
2374     dim->set_size(window_dimensions[i]);
2375     if (!window_strides.empty()) {
2376       dim->set_stride(window_strides[i]);
2377     } else {
2378       dim->set_stride(1);
2379     }
2380     if (!padding.empty()) {
2381       dim->set_padding_low(padding[i].first);
2382       dim->set_padding_high(padding[i].second);
2383     } else {
2384       dim->set_padding_low(0);
2385       dim->set_padding_high(0);
2386     }
2387     if (!lhs_dilation.empty()) {
2388       dim->set_base_dilation(lhs_dilation[i]);
2389     } else {
2390       dim->set_base_dilation(1);
2391     }
2392     if (!rhs_dilation.empty()) {
2393       dim->set_window_dilation(rhs_dilation[i]);
2394     } else {
2395       dim->set_window_dilation(1);
2396     }
2397     dim->set_window_reversal(false);
2398   }
2399   return window;
2400 }
2401 
InferSliceShape(const Shape & arg,absl::Span<const int64> starts,absl::Span<const int64> limits,absl::Span<const int64> strides)2402 /* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
2403     const Shape& arg, absl::Span<const int64> starts,
2404     absl::Span<const int64> limits, absl::Span<const int64> strides) {
2405   auto error = [&](const string& message) {
2406     return InvalidArgument(
2407         "%s in slice operation; argument shape: %s; starts: {%s}; limits: "
2408         "{%s}; strides: {%s}.",
2409         message, ShapeUtil::HumanString(arg), StrJoin(starts, ","),
2410         StrJoin(limits, ","), StrJoin(strides, ","));
2411   };
2412   TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice"));
2413   VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}",
2414                        ShapeUtil::HumanString(arg), StrJoin(starts, ", "),
2415                        StrJoin(limits, ", "));
2416 
2417   if (starts.size() != limits.size()) {
2418     return error(StrFormat("slice start and limit sizes differ: %u vs %u",
2419                            starts.size(), limits.size()));
2420   }
2421 
2422   if (starts.size() != strides.size()) {
2423     return error(StrFormat("slice start and strides sizes differ: %u vs %u",
2424                            starts.size(), strides.size()));
2425   }
2426 
2427   if (starts.size() != arg.rank()) {
2428     return InvalidArgument(
2429         "Slice index count does not match argument rank: %u vs %d.",
2430         starts.size(), arg.rank());
2431   }
2432 
2433   std::vector<int64> sizes;
2434   for (int64 dimension = 0; dimension < starts.size(); ++dimension) {
2435     int64 start_index = starts[dimension];
2436     int64 limit_index = limits[dimension];
2437     int64 stride = strides[dimension];
2438     if (start_index < 0) {
2439       return InvalidArgument("Negative start index to slice: %d.", start_index);
2440     }
2441     if (limit_index > arg.dimensions(dimension)) {
2442       return error(
2443           StrFormat("limit index (%d) must be less than or equal to dimension "
2444                     "size (%d)",
2445                     limit_index, arg.dimensions(dimension)));
2446     }
2447     VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index);
2448     VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index);
2449     if (start_index > limit_index) {
2450       return error(
2451           StrFormat("limit index (%d) must be greater or equal to "
2452                     "start index (%d) in slice with positive stride",
2453                     limit_index, start_index));
2454     }
2455     if (stride <= 0) {
2456       return InvalidArgument("Stride (%d) must be positive.", stride);
2457     }
2458     sizes.push_back((limit_index - start_index + stride - 1) / stride);
2459   }
2460 
2461   std::vector<bool> is_dynamic(arg.rank());
2462   for (int64 i = 0; i < arg.dimensions_size(); ++i) {
2463     // Slicing 1 out of a dynamic dimension eliminates the dynamic dimension.
2464     if (sizes[i] == 1) {
2465       continue;
2466     }
2467     is_dynamic[i] = arg.is_dynamic_dimension(i);
2468   }
2469 
2470   return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic);
2471 }
2472 
InferDynamicSliceShape(const Shape & operand_shape,absl::Span<const Shape> start_index_shapes,absl::Span<const int64> slice_sizes,bool allow_scalar_indices)2473 /* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
2474     const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
2475     absl::Span<const int64> slice_sizes, bool allow_scalar_indices) {
2476   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
2477   auto number_of_indices = start_index_shapes.size();
2478   // TODO(b/118437727): Remove this path.
2479   if (!allow_scalar_indices ||
2480       (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
2481     if (number_of_indices != 1) {
2482       return InvalidArgument(
2483           "Dynamic slice should have exactly 1 index operand, has %d.",
2484           number_of_indices);
2485     }
2486 
2487     const Shape& start_indices_shape = start_index_shapes[0];
2488     VLOG(2) << StrFormat(
2489         "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
2490         ShapeUtil::HumanString(operand_shape),
2491         ShapeUtil::HumanString(start_indices_shape),
2492         StrJoin(slice_sizes, ", "));
2493 
2494     TF_RETURN_IF_ERROR(
2495         ExpectArray(start_indices_shape, "start indices of dynamic slice"));
2496 
2497     if (start_indices_shape.rank() != 1) {
2498       return InvalidArgument(
2499           "Dynamic slice start indices of rank %d must be rank1.",
2500           start_indices_shape.rank());
2501     }
2502 
2503     if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2504       return InvalidArgument(
2505           "Dynamic slice start indices must be of integral type.");
2506     }
2507 
2508     const int64 start_num_dims = start_indices_shape.dimensions(0);
2509     if (operand_shape.rank() != start_num_dims) {
2510       return InvalidArgument(
2511           "Dynamic slice start number of dimensions %d (%s) must match rank "
2512           "%d of slice input (%s).",
2513           start_num_dims, ShapeUtil::HumanString(start_indices_shape),
2514           operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
2515     }
2516   } else {
2517     VLOG(2) << StrFormat("slicing shape %s a with slice_sizes={%s}",
2518                          ShapeUtil::HumanString(operand_shape),
2519                          StrJoin(slice_sizes, ", "));
2520 
2521     if (operand_shape.rank() != number_of_indices) {
2522       return InvalidArgument(
2523           "Dynamic slice start number of dimensions %d must match rank "
2524           "%d of slice input (%s).",
2525           number_of_indices, operand_shape.rank(),
2526           ShapeUtil::HumanString(operand_shape));
2527     }
2528 
2529     if (number_of_indices > 0) {
2530       const Shape& first_index_shape = start_index_shapes[0];
2531       if (!ShapeUtil::IsScalar(first_index_shape)) {
2532         return InvalidArgument("Dynamic slice indices must be scalar, not %s.",
2533                                ShapeUtil::HumanString(first_index_shape));
2534       }
2535       if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
2536         return InvalidArgument(
2537             "Dynamic slice start indices must be of integral type.");
2538       }
2539       for (const Shape& index_shape : start_index_shapes) {
2540         if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
2541           return InvalidArgument(
2542               "Dynamic slice start indices must all have the same shape, got "
2543               "mismatching indices with shapes %s and %s.",
2544               ShapeUtil::HumanString(first_index_shape),
2545               ShapeUtil::HumanString(index_shape));
2546         }
2547       }
2548     }
2549   }
2550 
2551   if (slice_sizes.size() != operand_shape.rank()) {
2552     return InvalidArgument(
2553         "Dynamic slice index count does not match argument rank: %u vs %d.",
2554         slice_sizes.size(), operand_shape.rank());
2555   }
2556 
2557   for (int64 dim = 0; dim < slice_sizes.size(); ++dim) {
2558     const int64 input_dim_size = operand_shape.dimensions(dim);
2559     const int64 slice_dim_size = slice_sizes[dim];
2560     if (slice_dim_size < 0) {
2561       return InvalidArgument("Negative size index to dynamic slice: %d.",
2562                              slice_dim_size);
2563     }
2564     if (slice_dim_size > input_dim_size) {
2565       return InvalidArgument(
2566           "Slice dim size %d greater than dynamic slice dimension: %d.",
2567           slice_dim_size, input_dim_size);
2568     }
2569     VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size);
2570   }
2571 
2572   return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
2573 }
2574 
InferDynamicUpdateSliceShape(const Shape & operand_shape,const Shape & update_shape,absl::Span<const Shape> start_index_shapes,bool allow_scalar_indices)2575 /* static */ StatusOr<Shape> ShapeInference::InferDynamicUpdateSliceShape(
2576     const Shape& operand_shape, const Shape& update_shape,
2577     absl::Span<const Shape> start_index_shapes, bool allow_scalar_indices) {
2578   TF_RETURN_IF_ERROR(
2579       ExpectArray(operand_shape, "operand of dynamic update slice"));
2580   TF_RETURN_IF_ERROR(
2581       ExpectArray(update_shape, "update of dynamic update slice"));
2582 
2583   auto number_of_indices = start_index_shapes.size();
2584   // TODO(b/118437727): Remove this path.
2585   if (!allow_scalar_indices ||
2586       (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
2587     if (number_of_indices != 1) {
2588       return InvalidArgument(
2589           "Dynamic update slice should have exactly 1 index operand, has %d.",
2590           number_of_indices);
2591     }
2592     const Shape& start_indices_shape = start_index_shapes[0];
2593     TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape,
2594                                    "start indices of dynamic update slice"));
2595 
2596     VLOG(2) << StrFormat(
2597         "updating slice of shape %s at dynamic start_indices %s with update "
2598         "shape %s",
2599         ShapeUtil::HumanString(operand_shape),
2600         ShapeUtil::HumanString(start_indices_shape),
2601         ShapeUtil::HumanString(update_shape));
2602 
2603     if (start_indices_shape.rank() != 1) {
2604       return InvalidArgument(
2605           "Dynamic update slice start indices of rank %d must be rank1.",
2606           start_indices_shape.rank());
2607     }
2608 
2609     if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2610       return InvalidArgument(
2611           "Dynamic update slice start indices must be of integral type.");
2612     }
2613 
2614     const int64 start_num_dims = start_indices_shape.dimensions(0);
2615     if (operand_shape.rank() != start_num_dims) {
2616       return InvalidArgument(
2617           "Dynamic update slice start number of dimensions %d (%s) must match "
2618           "rank %d of slice input (%s).",
2619           start_num_dims, ShapeUtil::HumanString(start_indices_shape),
2620           operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
2621     }
2622   } else {
2623     VLOG(2) << StrFormat("updating slice of shape %s with update shape %s",
2624                          ShapeUtil::HumanString(operand_shape),
2625                          ShapeUtil::HumanString(update_shape));
2626 
2627     if (operand_shape.rank() != number_of_indices) {
2628       return InvalidArgument(
2629           "Dynamic update slice start number of dimensions %d must match "
2630           "rank %d of slice input (%s).",
2631           number_of_indices, operand_shape.rank(),
2632           ShapeUtil::HumanString(operand_shape));
2633     }
2634 
2635     if (number_of_indices > 0) {
2636       const Shape& first_index_shape = start_index_shapes[0];
2637       if (!ShapeUtil::IsScalar(first_index_shape)) {
2638         return InvalidArgument(
2639             "Dynamic update slice indices must be scalar, not %s.",
2640             ShapeUtil::HumanString(first_index_shape));
2641       }
2642       if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
2643         return InvalidArgument(
2644             "Dynamic update slice start indices must be of integral type.");
2645       }
2646       for (const Shape& index_shape : start_index_shapes) {
2647         if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
2648           return InvalidArgument(
2649               "Dynamic update slice start indices must all have the same "
2650               "shape, got mismatching indices with shapes %s and %s.",
2651               ShapeUtil::HumanString(first_index_shape),
2652               ShapeUtil::HumanString(index_shape));
2653         }
2654       }
2655     }
2656   }
2657 
2658   if (update_shape.rank() != operand_shape.rank()) {
2659     return InvalidArgument(
2660         "Dynamic update slice update rank does not match argument rank: "
2661         "%d vs %d.",
2662         update_shape.rank(), operand_shape.rank());
2663   }
2664 
2665   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
2666                                                      update_shape)) {
2667     return InvalidArgument(
2668         "Dynamic update slice update element type does not match argument. "
2669         "operand.element_type: %s vs update.element_type: %s.",
2670         PrimitiveType_Name(operand_shape.element_type()),
2671         PrimitiveType_Name(update_shape.element_type()));
2672   }
2673 
2674   for (int64 dim = 0; dim < operand_shape.rank(); ++dim) {
2675     const int64 input_dim_size = operand_shape.dimensions(dim);
2676     const int64 update_dim_size = update_shape.dimensions(dim);
2677     if (update_dim_size < 0) {
2678       return InvalidArgument(
2679           "Size index %d to dynamic update slice must be >= 0.",
2680           update_dim_size);
2681     }
2682     if (update_dim_size > input_dim_size) {
2683       return InvalidArgument(
2684           "Update dim size %d greater than dynamic slice dimension: %d.",
2685           update_dim_size, input_dim_size);
2686     }
2687     VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size);
2688   }
2689 
2690   auto result_shape = operand_shape;
2691 
2692   // If any of the operand shape is dynamic, the result dimension is also
2693   // dynamic.
2694   // If update shape is dynamic, only propagate dynamic dimension to result if
2695   // the update is a full update (update_shape[i] == operand_shape[i]).
2696   for (int64 i = 0; i < update_shape.rank(); ++i) {
2697     if (operand_shape.is_dynamic_dimension(i)) {
2698       result_shape.set_dynamic_dimension(i, true);
2699     }
2700 
2701     if (update_shape.is_dynamic_dimension(i) &&
2702         update_shape.dimensions(i) == operand_shape.dimensions(i)) {
2703       // When update/replace a full dimension, propagate dynamic dimension to
2704       // the result.
2705       result_shape.set_dynamic_dimension(i, true);
2706     }
2707   }
2708 
2709   return result_shape;
2710 }
2711 
InferReverseShape(const Shape & operand_shape,absl::Span<const int64> dimensions)2712 /*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
2713     const Shape& operand_shape, absl::Span<const int64> dimensions) {
2714   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
2715   if (!AllUnique(dimensions)) {
2716     return InvalidArgument("a dimension number is duplicated in reverse");
2717   }
2718   for (int64 dimension : dimensions) {
2719     if (dimension >= operand_shape.rank() || dimension < 0) {
2720       return InvalidArgument(
2721           "One of the reverse dimensions (%d) is out-of-bounds in shape %s.",
2722           dimension, ShapeUtil::HumanString(operand_shape));
2723     }
2724   }
2725   return operand_shape;
2726 }
2727 
InferGetTupleElementShape(const Shape & arg,int64 index)2728 /* static */ StatusOr<Shape> ShapeInference::InferGetTupleElementShape(
2729     const Shape& arg, int64 index) {
2730   if (!arg.IsTuple()) {
2731     return InvalidArgument(
2732         "Cannot infer shape: attempting to index into non-tuple: %s.",
2733         ShapeUtil::HumanString(arg));
2734   }
2735 
2736   if (index < 0 || index >= arg.tuple_shapes_size()) {
2737     return InvalidArgument(
2738         "Cannot infer shape: attempt to index out of tuple bounds: %d "
2739         ">= %d in shape %s.",
2740         index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg));
2741   }
2742 
2743   return arg.tuple_shapes(index);
2744 }
2745 
InferWhileShape(const ProgramShape & condition,const ProgramShape & body,const Shape & init)2746 /* static */ StatusOr<Shape> ShapeInference::InferWhileShape(
2747     const ProgramShape& condition, const ProgramShape& body,
2748     const Shape& init) {
2749   // Check the number of parameters for given computations.
2750   if (condition.parameters_size() != 1) {
2751     return InvalidArgument("Condition must take 1 arguments; got %d.",
2752                            condition.parameters_size());
2753   }
2754   if (body.parameters_size() != 1) {
2755     return InvalidArgument("Body must take 1 arguments; got %d.",
2756                            body.parameters_size());
2757   }
2758 
2759   auto shape_string = [&]() {
2760     return StrFormat(
2761         "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition),
2762         ShapeUtil::HumanString(body), ShapeUtil::HumanString(init));
2763   };
2764 
2765   // Check the shapes of computation parameters and return types.
2766   if (!ShapeUtil::Compatible(condition.result(),
2767                              ShapeUtil::MakeShape(PRED, {}))) {
2768     return InvalidArgument("Condition must return a boolean; got %s.",
2769                            shape_string());
2770   }
2771   if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) ||
2772       !ShapeUtil::Compatible(body.result(), body.parameters(0)) ||
2773       !ShapeUtil::Compatible(body.result(), init)) {
2774     return InvalidArgument(
2775         "The parameter of condition and body, the result of the body, and init "
2776         "must all have the same shape; got %s.",
2777         shape_string());
2778   }
2779 
2780   return init;
2781 }
2782 
InferConditionalShape(const Shape & branch_index,absl::Span<const ProgramShape> branch_computations,absl::Span<const Shape> branch_operands)2783 /* static */ StatusOr<Shape> ShapeInference::InferConditionalShape(
2784     const Shape& branch_index,
2785     absl::Span<const ProgramShape> branch_computations,
2786     absl::Span<const Shape> branch_operands) {
2787   if (!ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(PRED, {})) &&
2788       !ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(S32, {}))) {
2789     return InvalidArgument("branch_index must be bool or int32; got %s.",
2790                            ShapeUtil::HumanString(branch_index));
2791   }
2792   if (branch_index.element_type() == PRED) {
2793     TF_RET_CHECK(2 == branch_computations.size());
2794   } else {
2795     TF_RET_CHECK(!branch_computations.empty());
2796   }
2797   TF_RET_CHECK(branch_computations.size() == branch_operands.size());
2798 
2799   for (int j = 0; j < branch_computations.size(); ++j) {
2800     if (branch_computations[j].parameters_size() != 1) {
2801       return InvalidArgument(
2802           "branch computation %d must take 1 argument; got %d.", j,
2803           branch_computations[j].parameters_size());
2804     }
2805     if (!ShapeUtil::Compatible(branch_computations[j].parameters(0),
2806                                branch_operands[j])) {
2807       auto shape_string = [&]() {
2808         return StrFormat("operand: %s; computation: %s",
2809                          ShapeUtil::HumanString(branch_operands[j]),
2810                          ShapeUtil::HumanString(branch_computations[j]));
2811       };
2812       return InvalidArgument(
2813           "branch operand %d must match the shape of the only parameter of "
2814           "branch computation %d: got %s.",
2815           j, j, shape_string());
2816     }
2817 
2818     if (!ShapeUtil::Compatible(branch_computations[0].result(),
2819                                branch_computations[j].result())) {
2820       auto shape_string = [&]() {
2821         return StrFormat(
2822             "branch 0 computation result: %s; branch %d computation result: %s",
2823             ShapeUtil::HumanString(branch_computations[0].result()), j,
2824             ShapeUtil::HumanString(branch_computations[j].result()));
2825       };
2826       return InvalidArgument(
2827           "the result of branch 0 computation and branch %d computation must "
2828           "have the same shape: got %s.",
2829           j, shape_string());
2830     }
2831   }
2832   return branch_computations[0].result();
2833 }
2834 
InferBroadcastShape(const Shape & operand,absl::Span<const int64> broadcast_sizes)2835 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2836     const Shape& operand, absl::Span<const int64> broadcast_sizes) {
2837   TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
2838   for (int64 size : broadcast_sizes) {
2839     if (size < 0) {
2840       return InvalidArgument("Broadcast with negative dimension size %d.",
2841                              size);
2842     }
2843   }
2844 
2845   std::vector<int64> dimensions(operand.dimensions_size() +
2846                                 broadcast_sizes.size());
2847   std::copy(broadcast_sizes.begin(), broadcast_sizes.end(), dimensions.begin());
2848   std::copy(operand.dimensions().begin(), operand.dimensions().end(),
2849             dimensions.begin() + broadcast_sizes.size());
2850 
2851   Shape result = ShapeUtil::MakeShape(operand.element_type(), dimensions);
2852   for (int64 i = 0; i < operand.dimensions_size(); ++i) {
2853     result.set_dynamic_dimension(broadcast_sizes.size() + i,
2854                                  operand.is_dynamic_dimension(i));
2855   }
2856   return result;
2857 }
2858 
InferBroadcastShape(const Shape & operand_shape,const Shape & output_shape,absl::Span<const int64> broadcast_dimensions)2859 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2860     const Shape& operand_shape, const Shape& output_shape,
2861     absl::Span<const int64> broadcast_dimensions) {
2862   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast"));
2863   TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast"));
2864   const int64 operand_rank = operand_shape.rank();
2865   const int64 output_rank = output_shape.rank();
2866   if (operand_rank > output_rank) {
2867     return InvalidArgument(
2868         "InDim style broadcast must be to an equal or higher ranked shape; "
2869         "operand rank: %lld; output rank: %lld",
2870         operand_rank, output_rank);
2871   }
2872   if (operand_rank != broadcast_dimensions.size()) {
2873     return InvalidArgument(
2874         "Size of broadcast_dimensions has to match operand's rank; operand "
2875         "rank: %lld, size of broadcast_dimensions %u.",
2876         operand_rank, broadcast_dimensions.size());
2877   }
2878   for (int64 i = 0; i < operand_rank; i++) {
2879     if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] >= output_rank) {
2880       return InvalidArgument("Broadcast dimension %lld is out of bound",
2881                              broadcast_dimensions[i]);
2882     }
2883     if (operand_shape.dimensions(i) !=
2884             output_shape.dimensions(broadcast_dimensions[i]) &&
2885         operand_shape.dimensions(i) != 1) {
2886       return InvalidArgument(
2887           "Input dimension should be either 1 or equal to the output dimension "
2888           "it is broadcasting into; the %lldth operand dimension is %lld, the "
2889           "%lldth output dimension is %lld.",
2890           i, operand_shape.dimensions(i), broadcast_dimensions[i],
2891           output_shape.dimensions(broadcast_dimensions[i]));
2892     }
2893     if (operand_shape.is_dynamic_dimension(i) !=
2894         output_shape.is_dynamic_dimension(broadcast_dimensions[i])) {
2895       return InvalidArgument(
2896           "Broadcast input and output dynamism mismatch: %s and %s",
2897           operand_shape.ToString(), output_shape.ToString());
2898     }
2899     // Make sure the broadcast dimensions are listed in a strictly increasing
2900     // order.
2901     if (i > 0 && broadcast_dimensions[i - 1] >= broadcast_dimensions[i]) {
2902       return InvalidArgument(
2903           "Broadcast dimensions order is wrong: %d comes after %d.",
2904           broadcast_dimensions[i], broadcast_dimensions.at(i - 1));
2905     }
2906   }
2907 
2908   return output_shape;
2909 }
2910 
InferDynamicReshapeShape(const Shape & operand,absl::Span<const Shape * const> dim_size_shapes,absl::Span<const int64> new_size_bounds,const std::vector<bool> & dims_are_dynamic)2911 /* static */ StatusOr<Shape> ShapeInference::InferDynamicReshapeShape(
2912     const Shape& operand, absl::Span<const Shape* const> dim_size_shapes,
2913     absl::Span<const int64> new_size_bounds,
2914     const std::vector<bool>& dims_are_dynamic) {
2915   if (new_size_bounds.size() != dims_are_dynamic.size()) {
2916     return InvalidArgument(
2917         "DynamicReshape has to have the same number of elements in new_sizes "
2918         "(%d) and dims_are_dynamic (%d)",
2919         new_size_bounds.size(), dims_are_dynamic.size());
2920   }
2921 
2922   for (const Shape* dim_size_shape : dim_size_shapes) {
2923     if (dim_size_shape->element_type() != S32 && dim_size_shape->rank() != 0) {
2924       return InvalidArgument(
2925           "DynamicReshape's dim size has to be scalar S32, got (%s): ",
2926           dim_size_shape->ToString());
2927     }
2928   }
2929 
2930   Shape inferred_shape = ShapeUtil::MakeShape(
2931       operand.element_type(), new_size_bounds, dims_are_dynamic);
2932   if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
2933     return InvalidArgument(
2934         "Reshape operation has mismatched element counts: from=%d (%s) "
2935         "to=%d (%s).",
2936         ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
2937         ShapeUtil::ElementsIn(inferred_shape),
2938         ShapeUtil::HumanString(inferred_shape));
2939   }
2940   return inferred_shape;
2941 }
2942 
InferReshapeShape(const Shape & operand,absl::Span<const int64> dimensions,absl::Span<const int64> new_sizes,int64 inferred_dimension)2943 /* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
2944     const Shape& operand, absl::Span<const int64> dimensions,
2945     absl::Span<const int64> new_sizes, int64 inferred_dimension) {
2946   TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));
2947 
2948   Shape inferred_shape =
2949       ShapeUtil::MakeShape(operand.element_type(), new_sizes);
2950   VLOG(3) << "Reshape inferred shape: "
2951           << ShapeUtil::HumanString(inferred_shape);
2952 
2953   if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
2954     return InvalidArgument(
2955         "Reshape operation has mismatched element counts: from=%d (%s) "
2956         "to=%d (%s).",
2957         ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
2958         ShapeUtil::ElementsIn(inferred_shape),
2959         ShapeUtil::HumanString(inferred_shape));
2960   }
2961 
2962   std::vector<int64> indices(operand.rank());
2963   std::iota(indices.begin(), indices.end(), 0);
2964   if (dimensions.size() != operand.rank() ||
2965       !std::is_permutation(dimensions.begin(), dimensions.end(),
2966                            indices.begin())) {
2967     return InvalidArgument(
2968         "Reshape dimensions [%s] are not a permutation of the operand "
2969         "dimensions (operand shape is %s).",
2970         StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
2971   }
2972 
2973   // Propagate dynamic dimension.
2974   auto common_factors = CommonFactors(operand.dimensions(), new_sizes);
2975   for (int64 input_dim = 0; input_dim < operand.rank(); ++input_dim) {
2976     if (!operand.is_dynamic_dimension(input_dim)) {
2977       continue;
2978     }
2979 
2980     string reshape_debug_str = absl::StrFormat(
2981         "output: %s, input: %s, input_dim: "
2982         "%lld",
2983         ShapeUtil::HumanString(inferred_shape), ShapeUtil::HumanString(operand),
2984         input_dim);
2985 
2986     int64 input_dim_start = -1;
2987     int64 input_dim_end = -1;
2988     int64 output_dim_start = -1;
2989     int64 output_dim_end = -1;
2990     // Find common_factors that the input_dim belongs to.
2991     for (int64 i = 0; i < common_factors.size() - 1; ++i) {
2992       auto start = common_factors[i];
2993       auto end = common_factors[i + 1];
2994       if (input_dim >= start.first && input_dim < end.first) {
2995         input_dim_start = start.first;
2996         input_dim_end = end.first;
2997         output_dim_start = start.second;
2998         output_dim_end = end.second;
2999         break;
3000       }
3001     }
3002     if ((input_dim_end - input_dim_start) > 1 &&
3003         (output_dim_end - output_dim_start) > 1) {
3004       // We don't support the case when a dynamic dimension is both combined
3005       // with and splitted into other dimensions:
3006       //
3007       //  [x, yz]
3008       //     | Reshape
3009       //  [xy, z]
3010       return Unimplemented(
3011           "Dynamic input dimension to reshape that is both splitted and "
3012           "combined is not supported: %s",
3013           reshape_debug_str);
3014     }
3015 
3016     for (auto common_factor : common_factors) {
3017       //
3018       // For reshapes like:
3019       //  [<=5]
3020       //    | Reshape
3021       //  [1, 5]
3022       //
3023       //  The input dynamic dimension can go into either dynamic dimensions.
3024       //  However, the return value of common factors only returns
3025       //  input: 5
3026       //  output: 5
3027       //
3028       //  We need to expand common factor to include degenerated output
3029       //  dimensions:
3030       //  input: 5
3031       //  output: 1, 5
3032       //
3033       //  such that in the logic later on we can consider both dimensions as
3034       //  candidate.
3035       if (common_factor.first == input_dim_start) {
3036         output_dim_start = std::min(output_dim_start, common_factor.second);
3037       }
3038       if (common_factor.first == input_dim_end) {
3039         output_dim_end = std::max(output_dim_end, common_factor.second);
3040       }
3041     }
3042 
3043     // Calculate output dynamic reshape dimension.
3044     int64 output_dynamic_dimension = -1;
3045 
3046     if (operand.dimensions(input_dim) == 1 && !new_sizes.empty()) {
3047       // If dynamic dimension is size 1, it can only be most-major or
3048       // most-minor.
3049       if (input_dim == 0) {
3050         output_dynamic_dimension = 0;
3051       }
3052       if (input_dim == operand.rank() - 1) {
3053         output_dynamic_dimension = new_sizes.size() - 1;
3054       }
3055 
3056       if (output_dynamic_dimension == -1) {
3057         return Unimplemented(
3058             "Dynamic degenerated dimension that's not most-minor nor "
3059             "most-major is not supported: %s",
3060             reshape_debug_str);
3061       }
3062     }
3063 
3064     if (output_dynamic_dimension == -1 &&
3065         output_dim_end - output_dim_start == 1) {
3066       // Only one possible output dimension.
3067       output_dynamic_dimension = output_dim_start;
3068     }
3069     if (output_dynamic_dimension == -1 &&
3070         output_dim_end - output_dim_start > 1) {
3071       // Multiple outputs can be dynamic, use "inferred_dimension" to tie-break.
3072       output_dynamic_dimension = inferred_dimension;
3073     }
3074 
3075     if (output_dynamic_dimension != -1) {
3076       // TODO(yunxing): Turn this into a CHECK.
3077       inferred_shape.set_dynamic_dimension(output_dynamic_dimension, true);
3078     } else {
3079       std::vector<int64> output_non_degenerated;
3080       for (int64 i = output_dim_start; i < output_dim_end; ++i) {
3081         if (new_sizes[i] != 1) {
3082           output_non_degenerated.push_back(i);
3083         }
3084       }
3085       if (output_non_degenerated.size() == 1) {
3086         inferred_shape.set_dynamic_dimension(output_non_degenerated[0], true);
3087       }
3088     }
3089   }
3090 
3091   return inferred_shape;
3092 }
3093 
InferTransposeShape(const Shape & operand,absl::Span<const int64> dimensions)3094 /* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
3095     const Shape& operand, absl::Span<const int64> dimensions) {
3096   TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
3097 
3098   if (dimensions.size() != operand.rank() || !IsPermutation(dimensions)) {
3099     return InvalidArgument(
3100         "Transpose dimensions [%s] are not a permutation of the operand "
3101         "dimensions (operand shape is %s).",
3102         StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
3103   }
3104 
3105   // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
3106   // we need output[i]=input[dimensions[i]] which is
3107   // Permute(Inverse(dimensions),input).
3108   return ShapeUtil::PermuteDimensions(dimensions, operand);
3109 }
3110 
InferClampShape(const Shape & min,const Shape & operand,const Shape & max)3111 /* static */ StatusOr<Shape> ShapeInference::InferClampShape(
3112     const Shape& min, const Shape& operand, const Shape& max) {
3113   TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min"));
3114   TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand"));
3115   TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max"));
3116 
3117   if (!ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
3118       !ShapeUtil::CompatibleIgnoringFpPrecision(max, operand)) {
3119     return InvalidArgument(
3120         "Clamp with different shapes: %s, %s, %s.", ShapeUtil::HumanString(min),
3121         ShapeUtil::HumanString(operand), ShapeUtil::HumanString(max));
3122   }
3123   return operand;
3124 }
3125 
InferSelectShape(const Shape & pred,const Shape & on_true,const Shape & on_false)3126 /* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
3127     const Shape& pred, const Shape& on_true, const Shape& on_false) {
3128   TF_RETURN_IF_ERROR(ExpectArray(pred, "select pred"));
3129   TF_RETURN_IF_ERROR(ExpectArray(on_true, "select on-true"));
3130   TF_RETURN_IF_ERROR(ExpectArray(on_false, "select on-false"));
3131 
3132   if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) {
3133     return InvalidArgument(
3134         "Operands to select must be the same shape; got %s and %s.",
3135         ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
3136   }
3137   if (pred.element_type() != PRED) {
3138     return InvalidArgument(
3139         "Select's pred operand must have PRED element type; got %s.",
3140         ShapeUtil::HumanString(pred));
3141   }
3142   if (!Shape::Equal()
3143            .IgnoreElementType()
3144            .IgnoreLayout()
3145            .IgnoreDynamicDimension()(pred, on_true)) {
3146     return InvalidArgument(
3147         "Operands to select and predicate must be the same shape; got %s and "
3148         "%s.",
3149         ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(pred));
3150   }
3151 
3152   return ShapeUtil::ChangeElementType(
3153       pred, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
3154 }
3155 
InferTupleSelectShape(const Shape & pred,const Shape & on_true,const Shape & on_false)3156 /* static */ StatusOr<Shape> ShapeInference::InferTupleSelectShape(
3157     const Shape& pred, const Shape& on_true, const Shape& on_false) {
3158   // Select only defines the top-level buffer, so if it's a tuple, the two
3159   // input must match exactly.
3160   if (!ShapeUtil::Compatible(on_true, on_false)) {
3161     return InvalidArgument(
3162         "Operands to tuple-select must be the same shape; got %s and %s.",
3163         ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
3164   }
3165   if (pred.element_type() != PRED) {
3166     return InvalidArgument(
3167         "TupleSelect's pred operand must have PRED element type; got %s.",
3168         ShapeUtil::HumanString(pred));
3169   }
3170   if (!ShapeUtil::IsScalar(pred)) {
3171     return InvalidArgument(
3172         "TupleSelect operation with non-scalar predicate: %s.",
3173         ShapeUtil::HumanString(pred));
3174   }
3175   return on_true;
3176 }
3177 
InferCallShape(absl::Span<const Shape * const> arg_shapes,const ProgramShape & to_apply)3178 /* static */ StatusOr<Shape> ShapeInference::InferCallShape(
3179     absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply) {
3180   // The applied function's arity equals the number of arguments.
3181   if (arg_shapes.size() != to_apply.parameters_size()) {
3182     string computation_signature = ShapeUtil::HumanString(to_apply);
3183     string argument_shapes =
3184         StrJoin(arg_shapes, ", ", [](string* out, const Shape* shape) {
3185           absl::StrAppend(out, ShapeUtil::HumanString(*shape));
3186         });
3187     return InvalidArgument(
3188         "Call applied function arity must match number of arguments; got: "
3189         "arity: %d, arguments: %u; computation signature: %s; argument "
3190         "shapes: [%s].",
3191         to_apply.parameters_size(), arg_shapes.size(), computation_signature,
3192         argument_shapes);
3193   }
3194 
3195   // All arguments must be compatible with the program shape.
3196   for (int i = 0; i < arg_shapes.size(); ++i) {
3197     const Shape& arg_shape = *arg_shapes[i];
3198     const Shape& param_shape = to_apply.parameters(i);
3199     if (!ShapeUtil::Compatible(arg_shape, param_shape)) {
3200       return InvalidArgument(
3201           "Call parameter must match argument; got parameter %d shape: %s, "
3202           "argument shape: %s.",
3203           i, ShapeUtil::HumanString(param_shape),
3204           ShapeUtil::HumanString(arg_shape));
3205     }
3206   }
3207 
3208   return to_apply.result();
3209 }
3210 
ValidateGatherDimensionNumbers(const Shape & input_shape,absl::Span<const int64> start_indices_shape,const GatherDimensionNumbers & dim_numbers)3211 static Status ValidateGatherDimensionNumbers(
3212     const Shape& input_shape, absl::Span<const int64> start_indices_shape,
3213     const GatherDimensionNumbers& dim_numbers) {
3214   if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
3215     return InvalidArgument(
3216         "Output window dimensions in gather op must be ascending; got: %s.",
3217         StrJoin(dim_numbers.offset_dims(), ", "));
3218   }
3219 
3220   if (absl::c_adjacent_find(dim_numbers.offset_dims()) !=
3221       dim_numbers.offset_dims().end()) {
3222     return InvalidArgument(
3223         "Output window dimensions in gather op must not repeat; got: %s.",
3224         StrJoin(dim_numbers.offset_dims(), ", "));
3225   }
3226 
3227   const int64 output_offset_dim_count = dim_numbers.offset_dims_size();
3228   const int64 output_shape_rank =
3229       output_offset_dim_count + start_indices_shape.size() - 1;
3230 
3231   for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) {
3232     int64 offset_dim = dim_numbers.offset_dims(i);
3233     if (offset_dim < 0 || offset_dim >= output_shape_rank) {
3234       return InvalidArgument(
3235           "Offset dimension %d in gather op is out of bounds; got %d, but "
3236           "should "
3237           "have been in [0,%d).",
3238           i, offset_dim, output_shape_rank);
3239     }
3240   }
3241 
3242   if (dim_numbers.start_index_map_size() !=
3243       start_indices_shape[dim_numbers.index_vector_dim()]) {
3244     return InvalidArgument(
3245         "Gather op has %d elements in start_index_map and the "
3246         "bound of dimension index_vector_dim=%d of start_indices is "
3247         "%d. These two numbers must be equal.",
3248         dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(),
3249         start_indices_shape[dim_numbers.index_vector_dim()]);
3250   }
3251 
3252   for (int i = 0; i < dim_numbers.start_index_map_size(); i++) {
3253     int64 operand_dim_for_start_index_i = dim_numbers.start_index_map(i);
3254     if (operand_dim_for_start_index_i < 0 ||
3255         operand_dim_for_start_index_i >= input_shape.dimensions_size()) {
3256       return InvalidArgument(
3257           "Invalid start_index_map; domain is [0, %d), got: %d->%d.",
3258           input_shape.dimensions_size(), i, operand_dim_for_start_index_i);
3259     }
3260   }
3261 
3262   std::vector<int64> sorted_start_index_map(
3263       dim_numbers.start_index_map().begin(),
3264       dim_numbers.start_index_map().end());
3265 
3266   absl::c_sort(sorted_start_index_map);
3267 
3268   if (absl::c_adjacent_find(sorted_start_index_map) !=
3269       sorted_start_index_map.end()) {
3270     return InvalidArgument(
3271         "Repeated dimensions are not allowed in start_index_map; "
3272         "got: %s.",
3273         StrJoin(dim_numbers.start_index_map(), ", "));
3274   }
3275 
3276   for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) {
3277     if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
3278       return InvalidArgument(
3279           "Invalid collapsed_slice_dims set in gather op; valid range is [0, "
3280           "%d), got: %d.",
3281           input_shape.dimensions_size(), collapsed_dim);
3282     }
3283   }
3284 
3285   if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) {
3286     return InvalidArgument(
3287         "collapsed_slice_dims in gather op must be sorted; got: %s",
3288         StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
3289   }
3290 
3291   if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
3292       dim_numbers.collapsed_slice_dims().end()) {
3293     return InvalidArgument(
3294         "Repeated dimensions not allowed in collapsed_slice_dims in gather op; "
3295         "got: %s.",
3296         StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
3297   }
3298 
3299   return Status::OK();
3300 }
3301 
InferGatherShape(const Shape & input_shape,const Shape & start_indices_shape,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64> slice_sizes)3302 /*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
3303     const Shape& input_shape, const Shape& start_indices_shape,
3304     const GatherDimensionNumbers& gather_dim_numbers,
3305     absl::Span<const int64> slice_sizes) {
3306   TF_RETURN_IF_ERROR(
3307       ExpectArray(input_shape, "input tensor operand of gather op"));
3308   TF_RETURN_IF_ERROR(
3309       ExpectArray(start_indices_shape, "gather indices operand of gather op"));
3310 
3311   if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
3312     return InvalidArgument(
3313         "Gather indices parameter must be an integral tensor; got %s.",
3314         ShapeUtil::HumanString(start_indices_shape));
3315   }
3316 
3317   // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
3318   // index_vector_dim is rank(P).  The bounds of this expanded shape is
3319   // stored in expanded_start_indices_shape.
3320 
3321   if (start_indices_shape.dimensions_size() <
3322           gather_dim_numbers.index_vector_dim() ||
3323       gather_dim_numbers.index_vector_dim() < 0) {
3324     return InvalidArgument(
3325         "Gather index leaf dimension must be within [0, rank(start_indices) + "
3326         "1). rank(start_indices) is %d and gather index leaf dimension is "
3327         "%d.",
3328         start_indices_shape.dimensions_size(),
3329         gather_dim_numbers.index_vector_dim());
3330   }
3331 
3332   std::vector<int64> expanded_start_indices_shape;
3333   // Also tracks if an output dimension is dynamic.
3334   std::vector<bool> expanded_start_indices_shape_dynamic_dimensions;
3335   expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size());
3336   expanded_start_indices_shape_dynamic_dimensions.reserve(
3337       start_indices_shape.dimensions_size());
3338   absl::c_copy(start_indices_shape.dimensions(),
3339                std::back_inserter(expanded_start_indices_shape));
3340   absl::c_copy(
3341       start_indices_shape.dynamic_dimensions(),
3342       std::back_inserter(expanded_start_indices_shape_dynamic_dimensions));
3343   if (expanded_start_indices_shape.size() ==
3344       gather_dim_numbers.index_vector_dim()) {
3345     expanded_start_indices_shape.push_back(1);
3346     expanded_start_indices_shape_dynamic_dimensions.push_back(false);
3347   }
3348 
3349   TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
3350       input_shape, expanded_start_indices_shape, gather_dim_numbers));
3351 
3352   if (slice_sizes.size() != input_shape.dimensions_size()) {
3353     return InvalidArgument(
3354         "Gather op must have one slice size for every input dimension; got: "
3355         "len(slice_sizes)=%lu, input_shape.rank=%d.",
3356         slice_sizes.size(), input_shape.dimensions_size());
3357   }
3358 
3359   if (slice_sizes.size() !=
3360       gather_dim_numbers.offset_dims_size() +
3361           gather_dim_numbers.collapsed_slice_dims_size()) {
3362     return InvalidArgument(
3363         "All components of the offset index in a gather op must either be a "
3364         "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
3365         "output_slice_sizes=%s, collapsed_slice_dims=%s.",
3366         slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","),
3367         StrJoin(gather_dim_numbers.collapsed_slice_dims(), ","));
3368   }
3369 
3370   for (int i = 0; i < slice_sizes.size(); i++) {
3371     int64 slice_size = slice_sizes[i];
3372     int64 corresponding_input_size = input_shape.dimensions(i);
3373     if (slice_size < 0 || slice_size > corresponding_input_size) {
3374       return InvalidArgument(
3375           "Slice size at index %d in gather op is out of range, must be "
3376           "within [0, %d), got %d.",
3377           i, corresponding_input_size + 1, slice_size);
3378     }
3379   }
3380 
3381   for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) {
3382     if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] > 1) {
3383       return InvalidArgument(
3384           "Gather op can only collapse slice dims with bound 1 or 0, but bound "
3385           "is %d for index %d at position %d.",
3386           slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)],
3387           gather_dim_numbers.collapsed_slice_dims(i), i);
3388     }
3389   }
3390 
3391   int64 result_rank = gather_dim_numbers.offset_dims_size() +
3392                       (expanded_start_indices_shape.size() - 1);
3393   int64 offset_dims_seen = 0;
3394   int64 gather_dims_seen = 0;
3395   std::vector<int64> output_dim_bounds;
3396   output_dim_bounds.reserve(result_rank);
3397 
3398   std::vector<bool> output_dim_is_dynamic;
3399   output_dim_is_dynamic.reserve(result_rank);
3400   for (int64 i = 0; i < result_rank; i++) {
3401     int64 current_bound;
3402     bool dim_dynamic = false;
3403     bool is_window_index =
3404         absl::c_binary_search(gather_dim_numbers.offset_dims(), i);
3405     if (is_window_index) {
3406       while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(),
3407                                    offset_dims_seen)) {
3408         offset_dims_seen++;
3409       }
3410       // Gathering an entire dynamic dimension creates dynamic dimension.
3411       //
3412       // e.g.,:
3413       //
3414       // gather(input: [1,<=2,1], slice_sizes={1,2,1})
3415       //
3416       // creates
3417       //
3418       // [<=2, 1]
3419       if (slice_sizes[offset_dims_seen] ==
3420           input_shape.dimensions(offset_dims_seen)) {
3421         dim_dynamic = input_shape.is_dynamic_dimension(offset_dims_seen);
3422       }
3423       current_bound = slice_sizes[offset_dims_seen++];
3424     } else {
3425       if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) {
3426         gather_dims_seen++;
3427       }
3428       // Forward dynamic dimensions from indices.
3429       dim_dynamic =
3430           expanded_start_indices_shape_dynamic_dimensions[gather_dims_seen];
3431 
3432       current_bound = expanded_start_indices_shape[gather_dims_seen++];
3433     }
3434     output_dim_is_dynamic.push_back(dim_dynamic);
3435     output_dim_bounds.push_back(current_bound);
3436   }
3437 
3438   return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds,
3439                               output_dim_is_dynamic);
3440 }
3441 
3442 namespace {
3443 
ValidateScatterDimensionNumbers(const Shape & operand_shape,absl::Span<const int64> scatter_indices_shape,const Shape & updates_shape,const ScatterDimensionNumbers & dim_numbers)3444 Status ValidateScatterDimensionNumbers(
3445     const Shape& operand_shape, absl::Span<const int64> scatter_indices_shape,
3446     const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
3447   // Validate update_window_dims in ScatterDimensionNumbers.
3448   if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {
3449     return InvalidArgument(
3450         "update_window_dims in scatter op must be sorted; got: %s.",
3451         StrJoin(dim_numbers.update_window_dims(), ", "));
3452   }
3453   if (absl::c_adjacent_find(dim_numbers.update_window_dims()) !=
3454       dim_numbers.update_window_dims().end()) {
3455     return InvalidArgument(
3456         "update_window_dims in scatter op must not repeat; got: %s.",
3457         StrJoin(dim_numbers.update_window_dims(), ", "));
3458   }
3459   const int64 updates_rank = updates_shape.rank();
3460   for (int64 window_dim : dim_numbers.update_window_dims()) {
3461     if (window_dim < 0 || window_dim >= updates_rank) {
3462       return InvalidArgument(
3463           "Invalid update_window_dims set in scatter op; valid range is [0, "
3464           "%d). got: %d.",
3465           updates_rank, window_dim);
3466     }
3467   }
3468 
3469   // Validate inserted_window_dims in ScatterDimensionNumbers.
3470   if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) {
3471     return InvalidArgument(
3472         "inserted_window_dims in scatter op must be sorted; got: %s.",
3473         StrJoin(dim_numbers.inserted_window_dims(), ", "));
3474   }
3475   if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) !=
3476       dim_numbers.inserted_window_dims().end()) {
3477     return InvalidArgument(
3478         "inserted_window_dims in scatter op must not repeat; got: %s.",
3479         StrJoin(dim_numbers.inserted_window_dims(), ", "));
3480   }
3481   for (int64 inserted_dim : dim_numbers.inserted_window_dims()) {
3482     if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) {
3483       return InvalidArgument(
3484           "Invalid inserted_window_dims set in scatter op; valid range is [0, "
3485           "%d), got: %d.",
3486           operand_shape.dimensions_size(), inserted_dim);
3487     }
3488   }
3489 
3490   // Validate window size.
3491   auto window_size = dim_numbers.update_window_dims_size() +
3492                      dim_numbers.inserted_window_dims_size();
3493   if (window_size != operand_shape.rank()) {
3494     return InvalidArgument(
3495         "Scatter op has window of size %d; doesn't match operand of rank %d.",
3496         window_size, operand_shape.rank());
3497   }
3498 
3499   // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers.
3500   if (dim_numbers.scatter_dims_to_operand_dims_size() !=
3501       scatter_indices_shape[dim_numbers.index_vector_dim()]) {
3502     return InvalidArgument(
3503         "Scatter op has %d elements in scatter_dims_to_operand_dims and the "
3504         "bound of dimension index_vector_dim=%d of scatter_indices is %d. "
3505         "These two numbers must be equal.",
3506         dim_numbers.scatter_dims_to_operand_dims_size(),
3507         dim_numbers.index_vector_dim(),
3508         scatter_indices_shape[dim_numbers.index_vector_dim()]);
3509   }
3510   for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) {
3511     int64 scatter_dim_to_operand_dim =
3512         dim_numbers.scatter_dims_to_operand_dims(i);
3513     if (scatter_dim_to_operand_dim < 0 ||
3514         scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) {
3515       return InvalidArgument(
3516           "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), "
3517           "got: %d->%d.",
3518           operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim);
3519     }
3520   }
3521   std::vector<int64> sorted_scatter_dims_to_operand_dims(
3522       dim_numbers.scatter_dims_to_operand_dims().begin(),
3523       dim_numbers.scatter_dims_to_operand_dims().end());
3524   absl::c_sort(sorted_scatter_dims_to_operand_dims);
3525   if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
3526       sorted_scatter_dims_to_operand_dims.end()) {
3527     return InvalidArgument(
3528         "Repeated dimensions not allowed in scatter_dims_to_operand_dims; "
3529         "got: %s.",
3530         StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", "));
3531   }
3532 
3533   return Status::OK();
3534 }
3535 
3536 }  // namespace
3537 
InferScatterShape(const Shape & operand_shape,const Shape & scatter_indices_shape,const Shape & updates_shape,const ProgramShape & to_apply_shape,const ScatterDimensionNumbers & scatter_dim_numbers)3538 /*static*/ StatusOr<Shape> ShapeInference::InferScatterShape(
3539     const Shape& operand_shape, const Shape& scatter_indices_shape,
3540     const Shape& updates_shape, const ProgramShape& to_apply_shape,
3541     const ScatterDimensionNumbers& scatter_dim_numbers) {
3542   TF_RETURN_IF_ERROR(
3543       ExpectArray(operand_shape, "operand tensor of scatter op"));
3544   TF_RETURN_IF_ERROR(
3545       ExpectArray(scatter_indices_shape, "scatter indices of scatter op"));
3546   TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op"));
3547 
3548   if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) {
3549     return InvalidArgument(
3550         "Scatter indices parameter must be an integral tensor; got %s.",
3551         ShapeUtil::HumanString(scatter_indices_shape));
3552   }
3553 
3554   if (scatter_indices_shape.dimensions_size() <
3555           scatter_dim_numbers.index_vector_dim() ||
3556       scatter_dim_numbers.index_vector_dim() < 0) {
3557     return InvalidArgument(
3558         "Scatter index leaf dimension must be within [0, rank(scatter_indices)"
3559         " + 1). rank(scatter_indices) is %d and scatter index leaf dimension "
3560         "is %d.",
3561         scatter_indices_shape.dimensions_size(),
3562         scatter_dim_numbers.index_vector_dim());
3563   }
3564 
3565   // Check if the update computation has a proper shape as a reduction.
3566   const Shape init_value_shape =
3567       ShapeUtil::MakeShape(operand_shape.element_type(), {});
3568   TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
3569                                         {updates_shape.element_type()},
3570                                         /*inputs=*/1));
3571 
3572   std::vector<int64> expanded_scatter_indices_shape =
3573       SpanToVector(scatter_indices_shape.dimensions());
3574   if (expanded_scatter_indices_shape.size() ==
3575       scatter_dim_numbers.index_vector_dim()) {
3576     expanded_scatter_indices_shape.push_back(1);
3577   }
3578 
3579   int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 +
3580                                 scatter_dim_numbers.update_window_dims_size();
3581   if (updates_shape.rank() != expected_updates_rank) {
3582     return InvalidArgument("Updates tensor must be of rank %d; got %d.",
3583                            expected_updates_rank, updates_shape.rank());
3584   }
3585 
3586   TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers(
3587       operand_shape, expanded_scatter_indices_shape, updates_shape,
3588       scatter_dim_numbers));
3589 
3590   int64 inserted_dims_seen = 0;
3591   std::vector<int64> max_update_slice_sizes;
3592   for (int i = 0; i < operand_shape.dimensions_size(); ++i) {
3593     if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() &&
3594         scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) {
3595       ++inserted_dims_seen;
3596     } else {
3597       max_update_slice_sizes.push_back(operand_shape.dimensions(i));
3598     }
3599   }
3600   for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) {
3601     auto update_window_dim = scatter_dim_numbers.update_window_dims(i);
3602     if (updates_shape.dimensions(update_window_dim) >
3603         max_update_slice_sizes[i]) {
3604       return InvalidArgument(
3605           "Bounds of the window dimensions of updates must not exceed the "
3606           "bounds of the corresponding dimensions of operand. For dimension "
3607           "%d, updates bound is %d, operand bound is %d.",
3608           update_window_dim, updates_shape.dimensions(update_window_dim),
3609           max_update_slice_sizes[i]);
3610     }
3611   }
3612 
3613   int64 scatter_dims_seen = 0;
3614   for (int64 i = 0; i < updates_shape.rank(); ++i) {
3615     bool is_update_window_dim =
3616         absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i);
3617     if (is_update_window_dim) {
3618       continue;
3619     }
3620     if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) {
3621       ++scatter_dims_seen;
3622     }
3623     if (updates_shape.dimensions(i) !=
3624         expanded_scatter_indices_shape[scatter_dims_seen]) {
3625       return InvalidArgument(
3626           "Bounds of the scatter dimensions of updates must be same as the "
3627           "bounds of the corresponding dimensions of scatter indices. For "
3628           "scatter dimension %d, updates bound is %d, scatter_indices "
3629           "bound is %d.",
3630           i, updates_shape.dimensions(i),
3631           expanded_scatter_indices_shape[scatter_dims_seen]);
3632     }
3633     ++scatter_dims_seen;
3634   }
3635 
3636   return operand_shape;
3637 }
3638 
3639 }  // namespace xla
3640