1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <numeric>
21 #include <set>
22 #include <string>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 #include "absl/strings/str_join.h"
29 #include "absl/strings/string_view.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/compiler/xla/window_util.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/lib/math/math_util.h"
38 #include "tensorflow/core/platform/logging.h"
39 #include "tensorflow/core/platform/protobuf.h"
40 
41 namespace xla {
42 namespace {
43 
44 using absl::StrFormat;
45 using absl::StrJoin;
46 
47 // Returns true if no element is present in slice more than once.
48 bool AllUnique(absl::Span<const int64> slice) {
49   return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
50 }
51 
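// Returns an InvalidArgument error if 'shape' is not an array shape; used to
// validate operands of operations that only accept arrays.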
52 Status ExpectArray(const Shape& shape, absl::string_view op_type) {
53   if (!shape.IsArray()) {
54     return InvalidArgument("Expected array argument for %s, but got %s.",
55                            string(op_type), ShapeUtil::HumanString(shape));
56   }
57   return Status::OK();
58 }
59 
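// Checks that 'reducer_shape' is a valid reducer computation for a (variadic)
// reduction over 'inputs' operands: it must take 2 * inputs scalar parameters
// and produce 'inputs' scalar results whose types are compatible with the
// accumulator, init value, and input element types.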
60 Status VerifyReducerShape(const ProgramShape& reducer_shape,
61                           absl::Span<const Shape* const> init_value_shapes,
62                           absl::Span<const PrimitiveType> input_element_types,
63                           int64 inputs) {
64   if (reducer_shape.parameters_size() != inputs * 2) {
65     return InvalidArgument(
66         "Reduction function must take %d parameters, but "
67         "takes %d parameter(s).",
68         inputs * 2, reducer_shape.parameters_size());
69   }
70 
71   const Shape& accumulator_shape = reducer_shape.result();
72   std::vector<const Shape*> accumulator_subshapes;
73   if (accumulator_shape.IsArray()) {
74     if (inputs != 1) {
75       return InvalidArgument(
76           "Reduction function must produce a tuple with %d elements, but "
77           "produces a scalar",
78           inputs);
79     }
80     accumulator_subshapes.push_back(&accumulator_shape);
81   } else if (accumulator_shape.IsTuple()) {
82     if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) {
83       return InvalidArgument(
84           "Reduction function must produce a tuple with %d elements, but has "
85           "%d elements",
86           inputs, ShapeUtil::TupleElementCount(accumulator_shape));
87     }
88     for (const Shape& element_shape : accumulator_shape.tuple_shapes()) {
89       accumulator_subshapes.push_back(&element_shape);
90     }
91   } else {
92     return InvalidArgument(
93         "Reduction function must produce a scalar or tuple of scalars, but has "
94         "shape: %s",
95         ShapeUtil::HumanString(accumulator_shape));
96   }
97 
98   for (const Shape* element_shape : accumulator_subshapes) {
99     if (element_shape->rank() != 0) {
100       return InvalidArgument(
101           "Reduction function must return a scalar or tuple of scalars but "
102           "returns shape: %s",
103           ShapeUtil::HumanString(accumulator_shape));
104     }
105   }
106 
107   for (int64 i = 0; i < inputs; ++i) {
108     // Check that the accumulator can be passed in as the first argument.
109     // Note: comparing here and below with Compatible since we don't care about
110     // layout in scalars - see b/26668201 for a longer-term vision.
111     if (!ShapeUtil::Compatible(*accumulator_subshapes[i],
112                                reducer_shape.parameters(i))) {
113       return InvalidArgument(
114           "Reduction function's %d-th parameter shape differs from the "
115           "result shape: %s vs %s",
116           i, ShapeUtil::HumanString(reducer_shape.parameters(i)),
117           ShapeUtil::HumanString(*accumulator_subshapes[i]));
118     }
119     // Check that init_value's shapes are suitable for reducer_shape.
120     if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i],
121                                                   *init_value_shapes[i])) {
122       return InvalidArgument(
123           "Reduction function's accumulator shape at index %d differs from "
124           "the init_value shape: %s vs %s",
125           i, ShapeUtil::HumanString(*accumulator_subshapes[i]),
126           ShapeUtil::HumanString(*init_value_shapes[i]));
127     }
128     // Check that the inputs can be passed in as the non-accumulator arguments.
129     const Shape input_element_shape =
130         ShapeUtil::MakeShape(input_element_types[i], {});
131     if (!ShapeUtil::CompatibleIgnoringFpPrecision(
132             input_element_shape, reducer_shape.parameters(inputs + i))) {
133       return InvalidArgument(
134           "Reduction function's %d-th parameter shape differs from the "
135           "corresponding input element type: %s vs %s",
136           inputs + i,
137           ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
138           ShapeUtil::HumanString(input_element_shape));
139     }
140     // Check that the accumulator and inputs to the reducer function match.
141     // If the accumulator is scalar, it must have the same type as the inputs
142     // (up to fp precision). If it is a tuple, then the k-th element of the
143     // tuple must have the same type as the k-th input (again, up to fp
144     // precision).
145     if (!ShapeUtil::CompatibleIgnoringFpPrecision(
146             *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) {
147       return InvalidArgument(
148           "Reduction function's %d-th parameter shape must "
149           "match the result shape, but got %s vs %s.",
150           inputs + i,
151           ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
152           ShapeUtil::HumanString(*accumulator_subshapes[i]));
153     }
154   }
155 
156   return Status::OK();
157 }
158 
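// Infers the shape produced by applying 'window' to 'base_shape' (as in
// reduce-window and convolution), validating the window sizes, strides,
// padding, and dilation factors.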
159 StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
160                                        const Window& window,
161                                        PrimitiveType element_type,
162                                        bool allow_negative_padding) {
163   if (window.dimensions_size() != base_shape.rank()) {
164     return InvalidArgument(
165         "Window has %d dimensions, but base shape has rank %d.",
166         window.dimensions_size(), base_shape.rank());
167   }
168 
169   std::vector<int64> output_dimensions(window.dimensions_size());
170   std::vector<bool> output_is_dynamic(window.dimensions_size());
171   for (int64 i = 0; i < window.dimensions_size(); ++i) {
172     const auto& dim = window.dimensions(i);
173     if (dim.size() <= 0) {
174       return InvalidArgument("Window %s has a non-positive dimension.",
175                              window.DebugString());
176     }
177     if (dim.stride() <= 0) {
178       return InvalidArgument("Window %s has a non-positive stride.",
179                              window.DebugString());
180     }
181     if (!allow_negative_padding && dim.padding_low() < 0) {
182       return InvalidArgument("Window %s has a negative low padding.",
183                              window.DebugString());
184     }
185     if (!allow_negative_padding && dim.padding_high() < 0) {
186       return InvalidArgument("Window %s has a negative high padding.",
187                              window.DebugString());
188     }
189     if (dim.base_dilation() < 1) {
190       return InvalidArgument(
191           "Window %s has a non-positive base area dilation factor.",
192           window.DebugString());
193     }
194     if (dim.window_dilation() < 1) {
195       return InvalidArgument(
196           "Window %s has a non-positive window dilation factor.",
197           window.DebugString());
198     }
199 
200     if (base_shape.is_dynamic_dimension(i) &&
201         !window_util::IsTrivialWindowDimension(dim)) {
202       return Unimplemented(
203           "Dynamic shape is not supported for a non-trivial window: %s",
204           window_util::ToString(window));
205     }
206 
207     const int64 dilated_base = window_util::DilatedBound(
208         ShapeUtil::GetDimension(base_shape, i), dim.base_dilation());
209     const int64 padded_dilated_base =
210         dim.padding_low() + dilated_base + dim.padding_high();
211     const int64 dilated_window =
212         window_util::DilatedBound(dim.size(), dim.window_dilation());
213 
214     output_dimensions[i] = window_util::StridedBound(
215         padded_dilated_base, dilated_window, dim.stride());
216     output_is_dynamic[i] = base_shape.is_dynamic_dimension(i);
217   }
218 
219   return ShapeUtil::MakeValidatedShape(element_type, output_dimensions,
220                                        output_is_dynamic);
221 }
222 
223 }  // namespace
224 
225 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
226     HloOpcode opcode, const HloInstruction* operand) {
227   return InferUnaryOpShape(opcode, operand->shape());
228 }
229 
230 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
231     HloOpcode opcode, const Shape& shape) {
232   // There is no copy operation at the proto level, so handle copy explicitly.
233   // A domain shape is the same as the input one.
234   if (opcode == HloOpcode::kCopy || opcode == HloOpcode::kDomain) {
235     return shape;
236   }
237 
238   TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation"));
239 
240   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
241   switch (opcode) {
242     case HloOpcode::kFloor:
243     case HloOpcode::kCeil:
244     case HloOpcode::kRoundNearestAfz:
245       if (!ShapeUtil::ElementIsFloating(shape)) {
246         return InvalidArgument(
247             "Expected element type in shape to be floating for %s operation; "
248             "got %s.",
249             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
250       }
251       return shape;
252     case HloOpcode::kCos:
253     case HloOpcode::kSin:
254     case HloOpcode::kExp:
255     case HloOpcode::kExpm1:
256     case HloOpcode::kLog:
257     case HloOpcode::kLog1p:
258     case HloOpcode::kRsqrt:
259     case HloOpcode::kSqrt:
260     case HloOpcode::kTanh:
261       if (!ShapeUtil::ElementIsFloating(shape) &&
262           !ShapeUtil::ElementIsComplex(shape)) {
263         return InvalidArgument(
264             "Expected element type in shape to be floating or complex for %s "
265             "operation; got %s.",
266             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
267       }
268       return shape;
269     case HloOpcode::kReal:
270     case HloOpcode::kImag:
271       if (ShapeUtil::ElementIsComplex(shape)) {
272         return ShapeUtil::ComplexComponentShape(shape);
273       } else if (ShapeUtil::ElementIsFloating(shape)) {
274         return shape;
275       } else {
276         return InvalidArgument(
277             "Expected element type in shape to be floating or complex for "
278             "%s operation; got %s.",
279             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
280       }
281     case HloOpcode::kAbs:
282       if (ShapeUtil::ElementIsComplex(shape)) {
283         return ShapeUtil::ChangeElementType(
284             shape, primitive_util::ComplexComponentType(shape.element_type()));
285       } else if (ShapeUtil::ElementIsSigned(shape)) {
286         return shape;
287       } else {
288         return InvalidArgument(
289             "Expected element type in shape to be signed or complex for "
290             "%s operation; got %s.",
291             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
292       }
293     case HloOpcode::kClz:
294       if (!ShapeUtil::ElementIsIntegral(shape)) {
295         return InvalidArgument(
296             "Expected an integral element type in argument to Clz "
297             "operation; got %s.",
298             PrimitiveType_Name(shape.element_type()));
299       }
300       return shape;
301     case HloOpcode::kNegate:
302       if (!ShapeUtil::ElementIsIntegral(shape) &&
303           !ShapeUtil::ElementIsFloating(shape) &&
304           !ShapeUtil::ElementIsComplex(shape)) {
305         return InvalidArgument(
306             "Expected element type in shape to be integral, floating or "
307             "complex for %s operation; got %s.",
308             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
309       }
310       return shape;
311     case HloOpcode::kPopulationCount:
312       if (!ShapeUtil::ElementIsIntegral(shape)) {
313         return InvalidArgument(
314             "Expected an integral element type in argument to PopulationCount "
315             "operation; got %s.",
316             PrimitiveType_Name(shape.element_type()));
317       }
318       return shape;
319     case HloOpcode::kSign:
320       if (!ShapeUtil::ElementIsSigned(shape) &&
321           !ShapeUtil::ElementIsComplex(shape)) {
322         return InvalidArgument(
323             "Expected element type in shape to be signed or complex for "
324             "%s operation; got %s.",
325             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
326       }
327       return shape;
328 
329     case HloOpcode::kNot:
330       if (shape.element_type() != PRED &&
331           !primitive_util::IsIntegralType(shape.element_type())) {
332         return InvalidArgument(
333             "Expected pred or an integral element type in argument to Not "
334             "operation; got %s.",
335             PrimitiveType_Name(shape.element_type()));
336       }
337       return shape;
338 
339     case HloOpcode::kIsFinite:
340       if (!ShapeUtil::ElementIsFloating(shape)) {
341         return InvalidArgument(
342             "Expected element type in shape to be floating "
343             "point for IsFinite "
344             "operation; got %s.",
345             PrimitiveType_Name(shape.element_type()));
346       }
347       return ShapeUtil::ChangeElementType(shape, PRED);
348 
349     default:
350       return InvalidArgument(
351           "Unknown operation for unary shape inference: \"%s\".",
352           HloOpcodeString(opcode));
353   }
354 }
355 
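// Infers the shape of a concatenation of 'arg_shapes' along 'dimension'. All
// operands must be arrays of the same rank and element type (ignoring fp
// precision), differing only in the concatenated dimension.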
356 /* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
357     absl::Span<const Shape* const> arg_shapes, const int64 dimension) {
358   if (arg_shapes.empty()) {
359     return InvalidArgument("Concatenate expects at least one argument.");
360   }
361   if (dimension < 0 || dimension >= arg_shapes[0]->rank()) {
362     return InvalidArgument("Concatenate dimension out of bounds: %d.",
363                            dimension);
364   }
365   const Shape* arg_shape = nullptr;
366   PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
367   for (const Shape* shape : arg_shapes) {
368     TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation"));
369     if (!arg_shape) {
370       arg_shape = shape;
371       element_type = arg_shape->element_type();
372       continue;
373     }
374     if (arg_shape->rank() != shape->rank()) {
375       return InvalidArgument(
376           "Cannot concatenate arrays with different ranks: %d (%s) vs %d "
377           "(%s).",
378           arg_shape->rank(), ShapeUtil::HumanString(*arg_shape), shape->rank(),
379           ShapeUtil::HumanString(*shape));
380     }
381     if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
382       return InvalidArgument(
383           "Cannot concatenate arrays with different element types: %s vs %s.",
384           PrimitiveType_Name(arg_shape->element_type()),
385           PrimitiveType_Name(shape->element_type()));
386     }
387     for (int64 dimension_number = 0; dimension_number < arg_shape->rank();
388          ++dimension_number) {
389       if (arg_shape->dimensions(dimension_number) !=
390           shape->dimensions(dimension_number)) {
391         if (dimension_number == dimension) {
392           continue;  // It's okay to differ in the dimension we're
393                      // concatenating.
394         }
395         return InvalidArgument(
396             "Cannot concatenate arrays that differ in dimensions other than "
397             "the one being concatenated (the other array dimensions must be "
398             "the same): %s vs %s in dimension %d.",
399             ShapeUtil::HumanString(*arg_shape), ShapeUtil::HumanString(*shape),
400             dimension);
401       }
402     }
403     element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
404   }
405 
406   std::vector<int64> new_dimensions(arg_shape->dimensions().begin(),
407                                     arg_shape->dimensions().end());
408   for (size_t i = 1; i < arg_shapes.size(); ++i) {
409     new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
410   }
411 
412   Shape result = ShapeUtil::MakeShape(element_type, new_dimensions);
413 
414   // Set dynamic dimensions if any input has dynamic dimension.
415   for (const Shape* shape : arg_shapes) {
416     for (int64 i = 0; i < shape->dimensions_size(); ++i) {
417       if (shape->is_dynamic_dimension(i)) {
418         result.set_dynamic_dimension(i, true);
419       }
420     }
421   }
422   return result;
423 }
424 
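// Infers the shape of an element-wise Convert from 'operand_shape' to
// 'new_element_type'; only array-to-array conversions are supported.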
425 /* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
426     const Shape& operand_shape, PrimitiveType new_element_type) {
427   auto old_element_type = operand_shape.element_type();
428   if (primitive_util::IsComplexType(old_element_type) &&
429       !primitive_util::IsComplexType(new_element_type)) {
430     return Unimplemented(
431         "Conversion from complex to real type %s => %s is not implemented.",
432         ShapeUtil::HumanString(operand_shape),
433         PrimitiveType_Name(new_element_type));
434   }
435   if (!operand_shape.IsArray() ||
436       !primitive_util::IsArrayType(new_element_type)) {
437     // Note: we may want to support tuple conversions via this operation in the
438     // future, by recursing into the tuple elements to check all sub-conversions
439     // are valid. For now we just reject them, though.
440     return InvalidArgument(
441         "Convert does not allow non-arrays, so cannot convert from %s to %s.",
442         ShapeUtil::HumanString(operand_shape),
443         PrimitiveType_Name(new_element_type));
444   }
445 
446   return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
447 }
448 
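// Infers the shape of a BitcastConvert; the old and new element types must
// have the same bit width and must agree on real-vs-complex.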
449 /* static */ StatusOr<Shape> ShapeInference::InferBitcastConvertShape(
450     const Shape& operand_shape, PrimitiveType new_element_type) {
451   auto old_element_type = operand_shape.element_type();
452   if (primitive_util::IsComplexType(old_element_type) !=
453       primitive_util::IsComplexType(new_element_type)) {
454     return InvalidArgument("Cannot bitcast between real and complex types: %s => %s.",
455                            ShapeUtil::HumanString(operand_shape),
456                            PrimitiveType_Name(new_element_type));
457   }
458   if (!operand_shape.IsArray() ||
459       !primitive_util::IsArrayType(new_element_type)) {
460     // Note: we may want to support tuple conversions via this operation in the
461     // future, by recursing into the tuple elements to check all sub-conversions
462     // are valid. For now we just reject them, though.
463     return InvalidArgument(
464         "Cannot convert from or to tuple type; requested conversion: %s => %s.",
465         ShapeUtil::HumanString(operand_shape),
466         PrimitiveType_Name(new_element_type));
467   }
468   if (primitive_util::BitWidth(old_element_type) !=
469       primitive_util::BitWidth(new_element_type)) {
470     return InvalidArgument(
471         "Cannot bitcast types with different bit-widths: %s => %s.",
472         PrimitiveType_Name(old_element_type),
473         PrimitiveType_Name(new_element_type));
474   }
475 
476   return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
477 }
478 
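// Infers the shape of a ReducePrecision op; the operand must be floating
// point, with at least one exponent bit and a non-negative mantissa bit count.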
479 /* static */ StatusOr<Shape> ShapeInference::InferReducePrecisionShape(
480     const Shape& operand_shape, const int exponent_bits,
481     const int mantissa_bits) {
482   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
483     return InvalidArgument(
484         "Expected element type in shape to be floating point for "
485         "ReducePrecision operation; got %s.",
486         PrimitiveType_Name(operand_shape.element_type()));
487   }
488   if (exponent_bits < 1) {
489     // One exponent bit is necessary to distinguish 0 from infinity.  Having
490     // no exponent bits doesn't produce a sensible number, so we require at
491     // least one.
492     return InvalidArgument("Expected exponent_bits >= 1; got %d.",
493                            exponent_bits);
494   }
495   if (mantissa_bits < 0) {
496     // A number with no mantissa bits is still meaningful, however.
497     return InvalidArgument("Expected non-negative mantissa_bits; got %d.",
498                            mantissa_bits);
499   }
500   return operand_shape;
501 }
502 
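// Infers the shape of a Pad op: each output dimension is the operand dimension
// plus low/high edge padding plus interior padding between original elements.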
503 /* static */ StatusOr<Shape> ShapeInference::InferPadShape(
504     const Shape& operand_shape, const Shape& padding_value_shape,
505     const PaddingConfig& padding_config) {
506   if (!operand_shape.IsArray()) {
507     return InvalidArgument(
508         "Pad operation does not support tuple-shape operands.");
509   }
510   if (!ShapeUtil::IsScalar(padding_value_shape)) {
511     return InvalidArgument(
512         "Pad operation does not support non-scalar padding values.");
513   }
514   if (operand_shape.rank() != padding_config.dimensions_size()) {
515     return InvalidArgument(
516         "The rank of the operand and the padding configuration do not match: "
517         "%s vs %s.",
518         ShapeUtil::HumanString(operand_shape),
519         padding_config.ShortDebugString());
520   }
521   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
522                                                      padding_value_shape)) {
523     return InvalidArgument(
524         "The element types of the operands to Pad do not match.");
525   }
526   if (absl::c_any_of(padding_config.dimensions(),
527                      [](const PaddingConfig::PaddingConfigDimension& p) {
528                        return p.interior_padding() < 0;
529                      })) {
530     return InvalidArgument("Interior padding cannot be negative: %s",
531                            padding_config.ShortDebugString());
532   }
533 
534   if (!padding_value_shape.is_static()) {
535     return InvalidArgument("Dynamic padding value is not supported");
536   }
537 
538   std::vector<int64> dimensions(operand_shape.rank());
539   std::vector<bool> is_dynamic(operand_shape.rank());
540   for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
541     const auto& p = padding_config.dimensions(i);
542     if (operand_shape.is_dynamic_dimension(i) && p.edge_padding_high() != 0 &&
543         p.edge_padding_low() != 0 && p.interior_padding() != 0) {
544       return InvalidArgument(
545           "Dynamic dimension on padding dimension is not supported.");
546     }
547     dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() +
548                     p.edge_padding_high() +
549                     std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
550                         p.interior_padding();
551     if (dimensions[i] < 0) {
552       return InvalidArgument("Padding results in negative size for dimension %d",
553                              i);
554     }
555     is_dynamic[i] = operand_shape.is_dynamic_dimension(i);
556   }
557 
558   return ShapeUtil::MakeShape(
559       ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
560       dimensions, is_dynamic);
561 }
562 
563 // Current DotDimensionNumbers Requirements:
564 //
565 // Contracting Dimensions:
566 // *) Same number of contracting dimensions on both lhs and rhs.
567 // *) Contracting dimension size must be the same on both lhs and rhs.
568 //
569 // Batch Dimensions:
570 // *) Same number of batch dimensions on both lhs and rhs.
571 // *) Same batch dimension sizes on both lhs and rhs.
572 //
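// Example: a batched matmul lhs [B, M, K] <dot> rhs [B, K, N] uses
// lhs_contracting_dimensions={2}, rhs_contracting_dimensions={1},
// lhs_batch_dimensions={0}, rhs_batch_dimensions={0} and yields [B, M, N].
//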
573 
574 namespace {
575 
576 Status ValidateDotDimensionNumbers(
577     const Shape& lhs, const Shape& rhs,
578     const DotDimensionNumbers& dimension_numbers) {
579   // Check that dimension numbers are in range.
580   auto dims_in_range = [](const int64 rank,
581                           absl::Span<const int64> contracting_dims,
582                           absl::Span<const int64> batch_dims) -> bool {
583     auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
584     return absl::c_all_of(contracting_dims, in_range) &&
585            absl::c_all_of(batch_dims, in_range);
586   };
587 
588   absl::Span<const int64> lhs_contracting_dimensions =
589       AsInt64Slice(dimension_numbers.lhs_contracting_dimensions());
590   absl::Span<const int64> rhs_contracting_dimensions =
591       AsInt64Slice(dimension_numbers.rhs_contracting_dimensions());
592   absl::Span<const int64> lhs_batch_dimensions =
593       AsInt64Slice(dimension_numbers.lhs_batch_dimensions());
594   absl::Span<const int64> rhs_batch_dimensions =
595       AsInt64Slice(dimension_numbers.rhs_batch_dimensions());
596 
597   if (!dims_in_range(lhs.rank(), lhs_contracting_dimensions,
598                      lhs_batch_dimensions) ||
599       !dims_in_range(rhs.rank(), rhs_contracting_dimensions,
600                      rhs_batch_dimensions)) {
601     return InvalidArgument("A dimension number is out of range in Dot: %s.",
602                            dimension_numbers.DebugString());
603   }
604 
605   // Check that dimension numbers are unique.
606   auto dims_unique = [](absl::Span<const int64> contracting_dims,
607                         absl::Span<const int64> batch_dims) -> bool {
608     absl::flat_hash_set<int64> dim_set;
609     auto is_unique = [&dim_set](int64 i) -> bool {
610       return dim_set.insert(i).second;
611     };
612     return absl::c_all_of(contracting_dims, is_unique) &&
613            absl::c_all_of(batch_dims, is_unique);
614   };
615 
616   if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
617       !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) {
618     return InvalidArgument("A dimension number is not unique in Dot: %s.",
619                            dimension_numbers.DebugString());
620   }
621 
622   return Status::OK();
623 }
624 
625 }  // namespace
626 
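// Infers the result shape of a Dot: batch dimensions are carried through, the
// contracting dimensions are summed away, and the remaining lhs and rhs
// dimensions are concatenated (lhs first).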
627 /* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
628     const Shape& lhs, const Shape& rhs,
629     const DotDimensionNumbers& dimension_numbers) {
630   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot"));
631   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
632 
633   auto fail = [lhs, rhs](const string& addendum) -> Status {
634     string message =
635         StrFormat("Cannot infer shape for dot operation: %s <dot> %s.",
636                   ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
637     if (!addendum.empty()) {
638       message += " " + addendum;
639     }
640     return InvalidArgument("%s", message);
641   };
642 
643   // Check if both element types are the same.
644   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
645     return fail("Element types do not match.");
646   }
647 
648   // Validate basic properties of dot dimension numbers.
649   TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers));
650 
651   // Check that number of contracting dimensions match.
652   if (dimension_numbers.lhs_contracting_dimensions_size() !=
653       dimension_numbers.rhs_contracting_dimensions_size()) {
654     return fail(
655         "Must specify the same number of contracting dimensions for lhs and "
656         "rhs.");
657   }
658   // Check that contracting dimension sizes match.
659   for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
660        ++i) {
661     const int64 lhs_contracting_dimension =
662         dimension_numbers.lhs_contracting_dimensions(i);
663     const int64 rhs_contracting_dimension =
664         dimension_numbers.rhs_contracting_dimensions(i);
665     if (lhs.dimensions(lhs_contracting_dimension) !=
666         rhs.dimensions(rhs_contracting_dimension)) {
667       return fail("Contracting dimension sizes do not match.");
668     }
669   }
670 
671   // Check that number of batch dimensions match.
672   if (dimension_numbers.lhs_batch_dimensions_size() !=
673       dimension_numbers.rhs_batch_dimensions_size()) {
674     return fail("Must specify the same number of batch dimensions for lhs and rhs.");
675   }
676 
677   // Check that batch dimension numbers and sizes match.
678   for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
679     if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
680         rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
681       return fail("Batch dimension sizes must match for lhs/rhs.");
682     }
683   }
684 
685   // Each pair of contracting dimensions is summed away and so does not
686   // contribute to the rank of the result; a scalar input contributes rank 0.
687   // Generate the result dimensions in order: batch dimensions, then the lhs
688   // non-contracting/non-batch dimensions, then the rhs non-contracting/
689   // non-batch dimensions.
690   std::vector<int64> dimensions;
691   std::vector<bool> is_dynamic;
692   for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
693     dimensions.push_back(lhs.dimensions(lhs_dim));
694     is_dynamic.push_back(lhs.is_dynamic_dimension(lhs_dim));
695   }
696   for (int64 i = 0; i < lhs.rank(); i++) {
697     if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
698                                i) &&
699         !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
700       dimensions.push_back(lhs.dimensions(i));
701       is_dynamic.push_back(lhs.is_dynamic_dimension(i));
702     }
703   }
704   for (int64 i = 0; i < rhs.rank(); i++) {
705     if (!absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
706                                i) &&
707         !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
708       dimensions.push_back(rhs.dimensions(i));
709       is_dynamic.push_back(rhs.is_dynamic_dimension(i));
710     }
711   }
712   Shape result = ShapeUtil::MakeShape(
713       ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions, is_dynamic);
714 
715   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
716   VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
717   return result;
718 }
719 
720 /* static */ StatusOr<Shape>
721 ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
722                                                        const Shape& lhs,
723                                                        const Shape& rhs) {
724   TF_RET_CHECK(lhs.rank() == rhs.rank());
725 
726   // The shapes have to be compatible. That is, if some dimension d has a
727   // different size in the two shapes, one of them has to be 1 (a "degenerate"
728   // dimension). In that case, the output shape has the non-1 dimension size
729   // from the lhs/rhs pair in every index.
730   std::vector<int64> output_dimensions(lhs.rank());
731   std::vector<bool> output_dimensions_is_dynamic(lhs.rank());
732   for (int64 i = 0; i < lhs.rank(); ++i) {
733     if (lhs.dimensions(i) == rhs.dimensions(i)) {
734       output_dimensions[i] = lhs.dimensions(i);
735     } else if (lhs.dimensions(i) == 1) {
736       output_dimensions[i] = rhs.dimensions(i);
737     } else if (rhs.dimensions(i) == 1) {
738       output_dimensions[i] = lhs.dimensions(i);
739     } else {
740       return InvalidArgument(
741           "Binary op %s with incompatible shapes: %s and %s.",
742           HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
743           ShapeUtil::HumanString(rhs));
744     }
745   }
746 
747   // Merge dynamic dimensions from two shapes.
748   for (int64 i = 0; i < rhs.rank(); ++i) {
749     if (rhs.is_dynamic_dimension(i) || lhs.is_dynamic_dimension(i)) {
750       output_dimensions_is_dynamic[i] = true;
751     }
752   }
753 
754   return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
755                               output_dimensions, output_dimensions_is_dynamic);
756 }
757 
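// Infers the shape of an "InDim" broadcast, which maps each dimension of the
// lower-rank operand onto a dimension of the higher-rank operand via
// 'broadcast_dimensions'; see the detailed comment in the function body.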
758 /* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
759     const Shape& smaller_shape, const Shape& larger_shape,
760     absl::Span<const int64> broadcast_dimensions) {
761   if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
762     // Reject "magic" inference for binops on different shapes, requiring
763     // the user to provide an explicit broadcast dimension in this case.
764     // See b/25177275 for more details.
765     return InvalidArgument("Automatic shape inference not supported: %s and %s",
766                            ShapeUtil::HumanString(smaller_shape),
767                            ShapeUtil::HumanString(larger_shape));
768   } else if (broadcast_dimensions.size() != smaller_shape.rank()) {
769     return InvalidArgument(
770         "Size of broadcast_dimensions has to match lower-rank operand's "
771         "rank; "
772         "lower-rank operand's rank is %d, size of broadcast_dimensions is "
773         "%u.",
774         smaller_shape.rank(), broadcast_dimensions.size());
775   }
776 
777   // broadcast_dimensions is a sequence of dimensions; its length is equal to
778   // the rank of the lower-rank operand. The lower-rank operand's dimensions
779   // have to be compatible with the higher-rank operand's dimensions at indices
780   // specified by broadcast_dimensions. Here compatible means the dimension
781   // sizes are equal or in one of the shapes the dimension size is
782   // one. Examples:
783   //
784   // smaller_shape   larger_shape   broadcast_dimensions   output_shape
785   //   []              [2, 3]          {}                    [2, 3]
786   //   [3]             [4, 3]          {1}                   [4, 3]
787   //   [2, 3]          [2, 3, 4]       {0, 1}                [2, 3, 4]
788   //   [2, 1]          [2, 3, 4]       {0, 2}                [2, 3, 1]
789   //   [2, 3]          [2, 1, 4]       {0, 1}                [2, 3, 4]
790   //
791   // The column output_shape may not be the final shape of the XLA
792   // operation. After the "InDim" broadcasting implemented in this function
793   // expands the rank, degenerate-dimension broadcasting (implemented in
794   // InferDegenerateDimensionBroadcastShape) broadcasts dimensions of size one
795   // up to match the dimension size of the other operand. For example, consider
796   // the row in the table above with a smaller_shape of [2, 1]. The shape
797   // returned by this function is [2, 3, 1] (output_shape) however, the result
798   // shape of the XLA operation is [2, 3, 4] after degenerate-dimension
799   // broadcasting.
800   //
801   // Invalid broadcasts:
802   //
803   // smaller_shape=[3], larger_shape=[4, 3], broadcast_dimensions={0}
804   // Reason: Dimension zero** of larger_shape (size 4) is not compatible with
805 //   dimension zero of smaller_shape (size 3). **Zero here comes from the value
806   //   in broadcast_dimensions.
807   //
808   // smaller_shape=[2, 1], larger_shape=[2, 3, 4], broadcast_dimensions={1, 2}
809   // Reason: Dimension one of larger_shape (size 3) is not compatible with
810 //   dimension zero of smaller_shape (size 2).
811 
812   // The output shape is initially the larger_shape. Sizes of dimensions
813   // specified in broadcast_dimensions are then changed to match the
814   // corresponding dimension size in smaller_shape.
815   Shape output_shape(larger_shape);
816   output_shape.set_element_type(
817       ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape));
818 
819   for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
820     int64 dimension_to_match = broadcast_dimensions.at(i);
821     if (dimension_to_match < 0) {
822       return InvalidArgument(
823           "Broadcast dimension number (%d) cannot be negative.",
824           dimension_to_match);
825     }
826     if (dimension_to_match >= larger_shape.dimensions_size()) {
827       return InvalidArgument(
828           "Broadcast dimension number (%d) too large; higher-rank "
829           "operand has rank %d.",
830           dimension_to_match, larger_shape.dimensions_size());
831     }
832     int64 small_dimension_size = smaller_shape.dimensions(i);
833     int64 large_dimension_size = larger_shape.dimensions(dimension_to_match);
834     bool small_is_dynamic = smaller_shape.is_dynamic_dimension(i);
835     bool large_is_dynamic =
836         larger_shape.is_dynamic_dimension(dimension_to_match);
837     // Dimension sizes must be compatible: match or be degenerate (degenerate
838     // case is handled by degenerate dimension broadcasting which occurs after
839     // InDim broadcasting).
840     if (small_dimension_size != large_dimension_size &&
841         small_dimension_size != 1 && large_dimension_size != 1) {
842       return InvalidArgument(
843           "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i,
844           small_dimension_size, large_dimension_size,
845           ShapeUtil::HumanString(smaller_shape),
846           ShapeUtil::HumanString(larger_shape));
847     }
848     if (small_is_dynamic != large_is_dynamic) {
849       if (small_dimension_size == large_dimension_size ||
850           (small_dimension_size == 1 && !small_is_dynamic) ||
851           (large_dimension_size == 1 && !large_is_dynamic)) {
852         // Do nothing. It's OK when the size-1 dimension is not static.
853       } else {
854         return InvalidArgument(
855             "Broadcast dimension %d dynamism mismatch: %s and %s.", i,
856             ShapeUtil::HumanString(smaller_shape),
857             ShapeUtil::HumanString(larger_shape));
858       }
859     }
860     // Make sure the broadcast dimensions are listed in a strictly increasing
861     // order.
862     if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) {
863       return InvalidArgument(
864           "Broadcast dimensions order is wrong: %d comes after %d.",
865           dimension_to_match, broadcast_dimensions.at(i - 1));
866     }
867 
868     output_shape.set_dimensions(dimension_to_match, small_dimension_size);
869     output_shape.set_dynamic_dimension(dimension_to_match, small_is_dynamic);
870   }
871 
872   return output_shape;
873 }
874 
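// Infers the shape of an elementwise binary op, applying InDim broadcasting
// when the operand ranks differ and degenerate-dimension broadcasting when
// only size-1 dimensions differ.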
875 /* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
876     HloOpcode operation, const Shape& lhs, const Shape& rhs,
877     absl::Span<const int64> broadcast_dimensions) {
878   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
879   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
880 
881   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
882     return InvalidArgument(
883         "Binary op %s with different element types: %s and %s.",
884         HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
885         ShapeUtil::HumanString(rhs));
886   }
887 
888   if (lhs.rank() == rhs.rank()) {
889     std::vector<int64> identity_dims(lhs.rank());
890     std::iota(identity_dims.begin(), identity_dims.end(), 0);
891     if (!broadcast_dimensions.empty() &&
892         broadcast_dimensions != identity_dims) {
893       return InvalidArgument(
894           "Broadcast dimensions field must either be not set or be the "
895           "identity on binary operations with operands of the same rank.");
896     }
897   }
898 
899   if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
900     // If the shapes are the same other than layout, the output shape is the
901     // same (elementwise op).
902     Shape result = ShapeUtil::ChangeElementType(
903         lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
904 
905     for (int64 i = 0; i < rhs.rank(); ++i) {
906       if (rhs.is_dynamic_dimension(i)) {
907         result.set_dynamic_dimension(i, true);
908       }
909     }
910 
911     return result;
912 
913   } else if (lhs.rank() == rhs.rank()) {
914     return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
915   } else {
916     // Ranks do not match, so perform InDim broadcasting using
917     // broadcast_dimensions. Scalar broadcasting is a special case of this.
918     const Shape& larger_shape = lhs.rank() > rhs.rank() ? lhs : rhs;
919     const Shape& smaller_shape = lhs.rank() > rhs.rank() ? rhs : lhs;
920 
921     // After InDim broadcasting, perform degenerate dimensions broadcasting.
922     TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape,
923                         InferInDimBroadcastShape(smaller_shape, larger_shape,
924                                                  broadcast_dimensions));
925 
926     return InferDegenerateDimensionBroadcastShape(
927         operation, indim_broadcast_shape, larger_shape);
928   }
929 }
930 
931 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
932     HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) {
933   return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(),
934                             /*broadcast_dimensions=*/{});
935 }
936 
937 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
938     HloOpcode opcode, const Shape& lhs, const Shape& rhs,
939     absl::Span<const int64> broadcast_dimensions) {
940   VLOG(2) << StrFormat(
941       "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
942       HloOpcodeString(opcode), ShapeUtil::HumanStringWithLayout(lhs),
943       ShapeUtil::HumanStringWithLayout(rhs),
944       StrJoin(broadcast_dimensions, ", "));
945 
946   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
947   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
948 
949   TF_RETURN_IF_ERROR(ExpectArray(
950       lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode))));
951   TF_RETURN_IF_ERROR(ExpectArray(
952       rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode))));
953   switch (opcode) {
954     case HloOpcode::kMaximum:
955     case HloOpcode::kMinimum:
956       return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
957                                            broadcast_dimensions);
958 
959     case HloOpcode::kSubtract:
960     case HloOpcode::kAdd:
961     case HloOpcode::kAtan2:
962     case HloOpcode::kPower:
963     case HloOpcode::kDivide:
964     case HloOpcode::kRemainder:
965     case HloOpcode::kMultiply:
966     case HloOpcode::kShiftLeft:
967     case HloOpcode::kShiftRightArithmetic:
968     case HloOpcode::kShiftRightLogical:
969       if (lhs.element_type() == PRED || rhs.element_type() == PRED) {
970         return InvalidArgument(
971             "Expected element type in shape to be arithmetic type for "
972             "operation %s; got PRED.",
973             HloOpcodeString(opcode));
974       }
975       return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
976                                            broadcast_dimensions);
977 
978     case HloOpcode::kComplex: {
979       if (!ShapeUtil::ElementIsFloating(lhs)) {
980         return InvalidArgument(
981             "Expected element type in shape to be floating for complex compose "
982             "operation; got %s.",
983             PrimitiveType_Name(lhs.element_type()));
984       }
985       TF_ASSIGN_OR_RETURN(const Shape& shape,
986                           InferElementwiseBinaryOpShape(opcode, lhs, rhs,
987                                                         broadcast_dimensions));
988       if (lhs.element_type() == F32 && rhs.element_type() == F32) {
989         return ShapeUtil::ChangeElementType(shape, C64);
990       } else if (lhs.element_type() == F64 && rhs.element_type() == F64) {
991         return ShapeUtil::ChangeElementType(shape, C128);
992       } else {
993         return Unimplemented("Complex component type is not implemented.");
994       }
995     }
996     case HloOpcode::kAnd:
997     case HloOpcode::kOr:
998     case HloOpcode::kXor:
999       if (lhs.element_type() != PRED &&
1000           !primitive_util::IsIntegralType(lhs.element_type())) {
1001         return InvalidArgument(
1002             "Expected pred or integral type in argument to and/or operation; "
1003             "got %s.",
1004             PrimitiveType_Name(lhs.element_type()));
1005       }
1006       return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
1007                                            broadcast_dimensions);
1008     case HloOpcode::kCompare: {
1009       TF_ASSIGN_OR_RETURN(const Shape& shape,
1010                           InferElementwiseBinaryOpShape(opcode, lhs, rhs,
1011                                                         broadcast_dimensions));
1012       return ShapeUtil::ChangeElementType(shape, PRED);
1013     }
1014     default:
1015       return Unimplemented(
1016           "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.",
1017           HloOpcodeString(opcode), lhs.ShortDebugString(),
1018           rhs.ShortDebugString());
1019   }
1020 }
1021 
1022 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1023     HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs,
1024     const HloInstruction* ehs) {
1025   return InferTernaryOpShape(opcode, lhs->shape(), rhs->shape(), ehs->shape());
1026 }
1027 
1028 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1029     HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) {
1030   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1031   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1032   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs));
1033   switch (opcode) {
1034     case HloOpcode::kClamp:
1035       return InferClampShape(lhs, rhs, ehs);
1036     case HloOpcode::kSelect:
1037       return InferSelectShape(lhs, rhs, ehs);
1038     case HloOpcode::kTupleSelect:
1039       return InferTupleSelectShape(lhs, rhs, ehs);
1040     default:
1041       return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1042   }
1043 }
1044 
1045 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1046     HloOpcode opcode, absl::Span<const HloInstruction* const> operands) {
1047   std::vector<const Shape*> operand_shapes;
1048   operand_shapes.reserve(operands.size());
1049   for (const HloInstruction* operand : operands) {
1050     operand_shapes.push_back(&operand->shape());
1051   }
1052   return InferVariadicOpShape(opcode, operand_shapes);
1053 }
1054 
1055 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1056     HloOpcode opcode, absl::Span<const Shape* const> operand_shapes) {
1057   for (const Shape* shape : operand_shapes) {
1058     TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
1059   }
1060   switch (opcode) {
1061     case HloOpcode::kTuple: {
1062       Shape result = ShapeUtil::MakeTupleShape({});
1063       result.mutable_tuple_shapes()->reserve(operand_shapes.size());
1064       for (const Shape* shape : operand_shapes) {
1065         ShapeUtil::AppendShapeToTuple(*shape, &result);
1066       }
1067       return result;
1068     }
1069     case HloOpcode::kSort: {
1070       if (operand_shapes.size() == 1) {
1071         return *operand_shapes[0];
1072       } else {
1073         for (int64 operand = 1; operand < operand_shapes.size(); ++operand) {
1074           if (!ShapeUtil::SameDimensions(*operand_shapes[0],
1075                                          *operand_shapes[operand])) {
1076             return InvalidArgument(
1077                 "Sort keys and values dimensions must match. "
1078                 "Keys shape is: %s\n, Values shape (operand index %lld) is: %s",
1079                 ShapeUtil::HumanString(*operand_shapes[0]), operand,
1080                 ShapeUtil::HumanString(*operand_shapes[operand]));
1081           }
1082         }
1083         std::vector<Shape> operand_shape_values;
1084         for (const Shape* operand_shape : operand_shapes) {
1085           operand_shape_values.push_back(*operand_shape);
1086         }
1087         return ShapeUtil::MakeTupleShape(operand_shape_values);
1088       }
1089       return InvalidArgument("Unexpected number of operands for sort");
1090     }
1091     default:
1092       return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1093   }
1094 }
1095 
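// Infers the shape of a Map op: all operands must share the same array shape,
// the mapped computation must take one scalar parameter per operand and return
// a scalar, and the result has the operand dimensions with the computation's
// result element type.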
1096 /* static */ StatusOr<Shape> ShapeInference::InferMapShape(
1097     absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
1098     absl::Span<const int64> dimensions) {
1099   if (arg_shapes.empty()) {
1100     return InvalidArgument("Map expects at least one argument.");
1101   }
1102 
1103   // All arguments must have the same shape.
1104   const Shape* arg_shape = arg_shapes[0];
1105   for (size_t i = 1; i < arg_shapes.size(); ++i) {
1106     TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map"));
1107 
1108     if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
1109       continue;
1110     }
1111     if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
1112                                                       *arg_shape)) {
1113       if (ShapeUtil::IsScalar(*arg_shapes[i])) {
1114         continue;
1115       }
1116       if (ShapeUtil::IsScalar(*arg_shape)) {
1117         arg_shape = arg_shapes[i];
1118         continue;
1119       }
1120     }
1121 
1122     std::vector<string> pieces;
1123     for (const Shape* shape : arg_shapes) {
1124       pieces.push_back(ShapeUtil::HumanString(*shape));
1125     }
1126     return InvalidArgument(
1127         "Map operation requires all operands to have the same shape; got: "
1128         "%s.",
1129         StrJoin(pieces, ", "));
1130   }
1131 
1132   // Check that dimensions.size == arg_shape.dimensions_size() (we currently
1133   // only support mapping across all dimensions: i.e. scalar map functions).
1134   if (dimensions.size() != arg_shape->dimensions_size()) {
1135     return InvalidArgument(
1136         "Map applied to a subset of dimensions currently not supported: "
1137         "arg_dimension_size: %d, requested_map_dimensions_size: %u.",
1138         arg_shape->dimensions_size(), dimensions.size());
1139   }
1140 
1141   // Check that requested map dimensions numbers are monotonically increasing.
1142   for (int i = 0; i < dimensions.size(); ++i) {
1143     if (dimensions[i] != i) {
1144       return InvalidArgument(
1145           "Map requires monotonically increasing dimension numbers; got: %s.",
1146           StrJoin(dimensions, ", "));
1147     }
1148   }
1149 
1150   // The applied function's arity equals the number of arguments.
1151   if (arg_shapes.size() != to_apply.parameters_size()) {
1152     return InvalidArgument(
1153         "Map applied function arity must match number of arguments; got: "
1154         "arity: %d, arguments: %u.",
1155         to_apply.parameters_size(), arg_shapes.size());
1156   }
1157 
1158   // The parameters should all be scalars, and the output too.
1159   const Shape& output_shape = to_apply.result();
1160   if (!ShapeUtil::IsScalar(output_shape)) {
1161     return InvalidArgument(
1162         "Mapped computation's result has to be a scalar; got: %s.",
1163         ShapeUtil::HumanString(output_shape));
1164   }
1165 
1166   for (int i = 0; i < to_apply.parameters_size(); ++i) {
1167     const Shape& parameter_shape = to_apply.parameters(i);
1168 
1169     if (!ShapeUtil::IsScalar(parameter_shape)) {
1170       return InvalidArgument(
1171           "Mapped computation's parameter has to be a scalar; "
1172           "got parameter %d shape: %s.",
1173           i, ShapeUtil::HumanString(parameter_shape));
1174     }
1175 
1176     if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
1177                                                        *arg_shape)) {
1178       return InvalidArgument(
1179           "Mapped computation's parameter type has to match argument element "
1180           "type; got parameter %d shape: %s, argument shape: %s.",
1181           i, ShapeUtil::HumanString(parameter_shape),
1182           ShapeUtil::HumanString(*arg_shape));
1183     }
1184   }
1185 
1186   return ShapeUtil::MakeShape(output_shape.element_type(),
1187                               AsInt64Slice(arg_shape->dimensions()));
1188 }
1189 
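// Infers the shape of batch-norm-training: validates that operand, scale and
// offset are floating-point arrays of matching element type and that the
// scale/offset vectors match the size of the feature dimension.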
1190 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormTrainingShape(
1191     const Shape& operand_shape, const Shape& scale_shape,
1192     const Shape& offset_shape, int64 feature_index) {
1193   TF_RETURN_IF_ERROR(
1194       ExpectArray(operand_shape, "operand of batch norm training"));
1195   TF_RETURN_IF_ERROR(
1196       ExpectArray(offset_shape, "offset input of batch norm training"));
1197   TF_RETURN_IF_ERROR(
1198       ExpectArray(scale_shape, "scale input of batch norm training"));
1199 
1200   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
1201                Status::OK());
1202   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) ==
1203                Status::OK());
1204   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
1205                Status::OK());
1206 
1207   if (feature_index >= operand_shape.rank()) {
1208     return InvalidArgument(
1209         "Expected feature_index of batch-norm-training to be "
1210         "smaller than the rank of operand_shape; "
1211         "got feature_index %d, and rank %d.",
1212         feature_index, operand_shape.rank());
1213   }
1214 
1215   if (feature_index < 0) {
1216     return InvalidArgument(
1217         "Expected feature_index of batch-norm-training to "
1218         "be a non-negative number, got %d.",
1219         feature_index);
1220   }
1221 
1222   if (operand_shape.rank() < 1) {
1223     return InvalidArgument(
1224         "Expected the rank of operand to "
1225         "batch-norm-training to be at least 1; got %d.",
1226         operand_shape.rank());
1227   }
1228 
1229   if (offset_shape.rank() != 1) {
1230     return InvalidArgument(
1231         "Offset input of batch-norm-training must have"
1232         " rank 1, but has rank %d.",
1233         offset_shape.rank());
1234   }
1235 
1236   if (scale_shape.rank() != 1) {
1237     return InvalidArgument(
1238         "Scale input of batch-norm-training must have"
1239         " rank 1, but has rank %d.",
1240         scale_shape.rank());
1241   }
1242 
1243   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1244     return InvalidArgument(
1245         "The operand to batch-norm-training must have a floating point "
1246         "element type, but the shape is %s.",
1247         PrimitiveType_Name(operand_shape.element_type()));
1248   }
1249 
1250   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1251                                                      operand_shape)) {
1252     return InvalidArgument(
1253         "The inputs should have the same element type for batch-norm-training, "
1254         "but the shape of offset factor is %s "
1255         "and the shape of operand is %s.",
1256         PrimitiveType_Name(offset_shape.element_type()),
1257         PrimitiveType_Name(operand_shape.element_type()));
1258   }
1259 
1260   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1261                                                      operand_shape)) {
1262     return InvalidArgument(
1263         "The inputs should have the same element type for batch-norm-training, "
1264         "but the shape of scale factor is %s "
1265         "and the shape of operand is %s.",
1266         PrimitiveType_Name(scale_shape.element_type()),
1267         PrimitiveType_Name(operand_shape.element_type()));
1268   }
1269 
1270   const int64 feature_count = operand_shape.dimensions(feature_index);
1271   Shape output_shape_for_mean_and_var =
1272       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1273 
1274   if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1275     return InvalidArgument(
1276         "The size of offset factor should be the same as feature count, "
1277         "but the size of offset factor is %d "
1278         "and the feature count is %d.",
1279         ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1280   }
1281 
1282   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1283     return InvalidArgument(
1284         "The size of scale factor should be the same as feature count, "
1285         "but the size of scale factor is %d "
1286         "and the feature count is %d.",
1287         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1288   }
1289 
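  // For example, an f32[N,H,W,C] operand with feature_index = 3 produces the
  // tuple (f32[N,H,W,C], f32[C], f32[C]) for the normalized output and the
  // per-feature mean and variance.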
1290   return ShapeUtil::MakeTupleShape({operand_shape,
1291                                     output_shape_for_mean_and_var,
1292                                     output_shape_for_mean_and_var});
1293 }
1294 
1295 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
1296     const Shape& operand_shape, const Shape& scale_shape,
1297     const Shape& offset_shape, const Shape& mean_shape,
1298     const Shape& variance_shape, int64 feature_index) {
1299   TF_RETURN_IF_ERROR(
1300       ExpectArray(operand_shape, "operand of batch norm inference"));
1301   TF_RETURN_IF_ERROR(
1302       ExpectArray(offset_shape, "offset input of batch norm inference"));
1303   TF_RETURN_IF_ERROR(
1304       ExpectArray(scale_shape, "scale input of batch norm inference"));
1305 
1306   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1307   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape));
1308   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1309   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1310   TF_RETURN_IF_ERROR(
1311       ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape));
1312 
1313   if (feature_index >= operand_shape.rank()) {
1314     return InvalidArgument(
1315         "Expected feature_index of batch-norm-inference to be "
1316         "smaller than the rank of operand_shape; "
1317         "got feature_index %d, and rank %d.",
1318         feature_index, operand_shape.rank());
1319   }
1320 
1321   if (feature_index < 0) {
1322     return InvalidArgument(
1323         "Expected feature_index of batch-norm-inference to "
1324         "be a non-negative number, got %d.",
1325         feature_index);
1326   }
1327 
1328   if (operand_shape.rank() < 1) {
1329     return InvalidArgument(
1330         "Expected the rank of operand to "
1331         "batch-norm-inference to be at least 1; got %d.",
1332         operand_shape.rank());
1333   }
1334 
1335   if (offset_shape.rank() != 1) {
1336     return InvalidArgument(
1337         "Offset input of batch-norm-inference must have"
1338         " rank 1, but has rank %d.",
1339         offset_shape.rank());
1340   }
1341 
1342   if (scale_shape.rank() != 1) {
1343     return InvalidArgument(
1344         "Scale input of batch-norm-inference must have"
1345         " rank 1, but has rank %d.",
1346         scale_shape.rank());
1347   }
1348 
1349   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1350     return InvalidArgument(
1351         "The operand to batch-norm-inference must have a floating point "
1352         "element type, but the shape is %s.",
1353         PrimitiveType_Name(operand_shape.element_type()));
1354   }
1355 
1356   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1357                                                      operand_shape)) {
1358     return InvalidArgument(
1359         "The inputs should have the same element type for "
1360         "batch-norm-inference, "
1361         "but the shape of offset factor is %s "
1362         "and the shape of operand is %s.",
1363         PrimitiveType_Name(offset_shape.element_type()),
1364         PrimitiveType_Name(operand_shape.element_type()));
1365   }
1366 
1367   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1368                                                      operand_shape)) {
1369     return InvalidArgument(
1370         "The inputs should have the same element type for "
1371         "batch-norm-inference, "
1372         "but the shape of scale factor is %s "
1373         "and the shape of operand is %s.",
1374         PrimitiveType_Name(scale_shape.element_type()),
1375         PrimitiveType_Name(operand_shape.element_type()));
1376   }
1377 
1378   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1379                                                      operand_shape)) {
1380     return InvalidArgument(
1381         "The inputs should have the same element type for "
1382         "batch-norm-inference, "
1383         "but the shape of mean is %s "
1384         "and the shape of operand is %s.",
1385         PrimitiveType_Name(mean_shape.element_type()),
1386         PrimitiveType_Name(operand_shape.element_type()));
1387   }
1388 
1389   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
1390                                                      operand_shape)) {
1391     return InvalidArgument(
1392         "The inputs should have the same element type for "
1393         "batch-norm-inference, "
1394         "but the shape of variance is %s "
1395         "and the shape of operand is %s.",
1396         PrimitiveType_Name(variance_shape.element_type()),
1397         PrimitiveType_Name(operand_shape.element_type()));
1398   }
1399 
1400   const int64 feature_count = operand_shape.dimensions(feature_index);
1401   Shape output_shape_for_mean_and_var =
1402       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1403 
1404   if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1405     return InvalidArgument(
1406         "The size of offset factor should be the same as feature count, "
1407         "but the size of offset factor is %d "
1408         "and the feature count is %d.",
1409         ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1410   }
1411 
1412   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1413     return InvalidArgument(
1414         "The size of scale factor should be the same as feature count, "
1415         "but the size of scale factor is %d "
1416         "and the feature count is %d.",
1417         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1418   }
1419 
1420   if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1421     return InvalidArgument(
1422         "The size of mean should be the same as feature count, "
1423         "but the size of mean is %d "
1424         "and the feature count is %d.",
1425         ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1426   }
1427 
1428   if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
1429     return InvalidArgument(
1430         "The size of variance should be the same as feature count, "
1431         "but the size of variance is %d "
1432         "and the feature count is %d.",
1433         ShapeUtil::GetDimension(variance_shape, 0), feature_count);
1434   }
1435 
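  // Batch norm inference normalizes the operand elementwise with the supplied
  // per-feature statistics, so the inferred shape is just the operand shape.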
1436   return operand_shape;
1437 }
1438 
1439 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
1440     const Shape& operand_shape, const Shape& scale_shape,
1441     const Shape& mean_shape, const Shape& var_shape,
1442     const Shape& output_grad_shape, int64 feature_index) {
1443   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad"));
1444   TF_RETURN_IF_ERROR(
1445       ExpectArray(scale_shape, "scale input of batch norm grad"));
1446   TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad"));
1447   TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad"));
1448   TF_RETURN_IF_ERROR(
1449       ExpectArray(output_grad_shape, "output_grad input of batch norm grad"));
1450 
1451   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1452   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1453   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1454   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(var_shape));
1455   TF_RETURN_IF_ERROR(
1456       ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape));
1457 
1458   if (feature_index >= operand_shape.rank()) {
1459     return InvalidArgument(
1460         "Expected feature_index of batch-norm-grad to be "
1461         "smaller than the rank of operand_shape; "
1462         "got feature_index %d, and rank %d.",
1463         feature_index, operand_shape.rank());
1464   }
1465 
1466   if (operand_shape.rank() != output_grad_shape.rank()) {
1467     return InvalidArgument(
1468         "Expected operand_shape of batch-norm-grad to have the same rank as"
1469         " output_grad_shape; got rank(operand_shape) %d, and"
1470         " rank(output_grad_shape) %d.",
1471         operand_shape.rank(), output_grad_shape.rank());
1472   }
1473 
1474   if (mean_shape.rank() != 1) {
1475     return InvalidArgument(
1476         "Mean input of batch-norm-grad must have"
1477         " rank 1, but has rank %d.",
1478         mean_shape.rank());
1479   }
1480 
1481   if (scale_shape.rank() != 1) {
1482     return InvalidArgument(
1483         "Scale input of batch-norm-grad must have"
1484         " rank 1, but has rank %d.",
1485         scale_shape.rank());
1486   }
1487 
1488   if (var_shape.rank() != 1) {
1489     return InvalidArgument(
1490         "Var input of batch-norm-grad must have"
1491         " rank 1, but has rank %d.",
1492         var_shape.rank());
1493   }
1494 
1495   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1496     return InvalidArgument(
1497         "The operand to batch-norm-grad must have a floating point "
1498         "element type, but the shape is %s.",
1499         PrimitiveType_Name(operand_shape.element_type()));
1500   }
1501 
1502   if (!ShapeUtil::ElementIsFloating(output_grad_shape)) {
1503     return InvalidArgument(
1504         "The output_grad to batch-norm-grad must have a floating point "
1505         "element type, but the shape is %s.",
1506         PrimitiveType_Name(output_grad_shape.element_type()));
1507   }
1508 
1509   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
1510                                                      operand_shape)) {
1511     return InvalidArgument(
1512         "The inputs should have the same element type for batch-norm-grad, "
1513         "but the element type of output_grad is %s "
1514         "and the element type of operand is %s.",
1515         PrimitiveType_Name(output_grad_shape.element_type()),
1516         PrimitiveType_Name(operand_shape.element_type()));
1517   }
1518 
1519   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1520                                                      operand_shape)) {
1521     return InvalidArgument(
1522         "The inputs should have the same element type for batch-norm-grad, "
1523         "but the element type of scale factor is %s "
1524         "and the element type of operand is %s.",
1525         PrimitiveType_Name(scale_shape.element_type()),
1526         PrimitiveType_Name(operand_shape.element_type()));
1527   }
1528 
1529   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1530                                                      operand_shape)) {
1531     return InvalidArgument(
1532         "The inputs should have the same element type for batch-norm-grad, "
1533         "but the element type of mean is %s "
1534         "and the element type of operand is %s.",
1535         PrimitiveType_Name(mean_shape.element_type()),
1536         PrimitiveType_Name(operand_shape.element_type()));
1537   }
1538 
1539   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
1540                                                      operand_shape)) {
1541     return InvalidArgument(
1542         "The inputs should have the same element type for batch-norm-grad, "
1543         "but the element type of var is %s "
1544         "and the element type of operand is %s.",
1545         PrimitiveType_Name(var_shape.element_type()),
1546         PrimitiveType_Name(operand_shape.element_type()));
1547   }
1548 
1549   const int64 feature_count = operand_shape.dimensions(feature_index);
1550 
1551   Shape feature_shape =
1552       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1553 
1554   if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1555     return InvalidArgument(
1556         "The size of mean should be the same as feature count, "
1557         "but the size of mean is %d "
1558         "and the feature count is %d.",
1559         ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1560   }
1561 
1562   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1563     return InvalidArgument(
1564         "The size of scale factor should be the same as feature count, "
1565         "but the size of scale factor is %d "
1566         "and the feature count is %d.",
1567         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1568   }
1569 
1570   if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) {
1571     return InvalidArgument(
1572         "The size of variance should be the same as feature count, "
1573         "but the size of variance is %d "
1574         "and the feature count is %d.",
1575         ShapeUtil::GetDimension(var_shape, 0), feature_count);
1576   }
1577 
1578   // Verify operand_shape and output_grad_shape have same bounds.
1579   for (int64 i = 0; i < operand_shape.rank(); ++i) {
1580     if (ShapeUtil::GetDimension(operand_shape, i) !=
1581         ShapeUtil::GetDimension(output_grad_shape, i)) {
1582       return InvalidArgument(
1583           "The bounds of operand shape should be the same as output_grad's, "
1584           "but the bound of operand_shape at dimension %d is %d "
1585           "and the bound of output_grad_shape is %d.",
1586           i, ShapeUtil::GetDimension(operand_shape, i),
1587           ShapeUtil::GetDimension(output_grad_shape, i));
1588     }
1589   }
1590 
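  // The result is a tuple of the gradient with respect to the operand (same
  // shape as the operand) and two per-feature gradients of shape
  // {feature_count} (for scale and offset).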
1591   return ShapeUtil::MakeTupleShape(
1592       {operand_shape, feature_shape, feature_shape});
1593 }
1594 
1595 /* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
1596     const Shape& lhs, const Shape& rhs, int64 feature_group_count,
1597     int64 batch_group_count, const Window& window,
1598     const ConvolutionDimensionNumbers& dnums) {
1599   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
1600   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
1601 
1602   if (feature_group_count <= 0) {
1603     return InvalidArgument(
1604         "feature_group_count must be a positive number, got %d",
1605         feature_group_count);
1606   }
1607 
1608   if (batch_group_count <= 0) {
1609     return InvalidArgument(
1610         "batch_group_count must be a positive number, got %d",
1611         batch_group_count);
1612   }
1613 
1614   if (batch_group_count > 1 && feature_group_count > 1) {
1615     return InvalidArgument(
1616         "batch_group_count %d and feature_group_count %d cannot both be "
1617         "greater than 1.",
1618         batch_group_count, feature_group_count);
1619   }
1620 
1621   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
1622     return InvalidArgument(
1623         "Convolution with different element types: %s and %s.",
1624         ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
1625   }
1626   if (dnums.input_spatial_dimensions_size() !=
1627       dnums.kernel_spatial_dimensions_size()) {
1628     return InvalidArgument(
1629         "Both arguments to convolution must have the same number of dimensions.\n"
1630         "Numbers: %s",
1631         dnums.DebugString());
1632   }
1633 
1634   if (dnums.input_spatial_dimensions_size() !=
1635       dnums.output_spatial_dimensions_size()) {
1636     return InvalidArgument(
1637         "Both input and output of convolution must have the same number of "
1638         "dimensions.\nNumbers: %s",
1639         dnums.DebugString());
1640   }
1641 
1642   const int num_spatial_dims = dnums.input_spatial_dimensions_size();
1643   if (window.dimensions_size() != num_spatial_dims) {
1644     return InvalidArgument(
1645         "Window must have same number of dimensions as dimension numbers.\n"
1646         "Window: %s\nDimension numbers: %s.",
1647         window.DebugString(), dnums.DebugString());
1648   }
1649 
1650   const int num_dims = num_spatial_dims + 2;
1651   if (lhs.rank() != num_dims) {
1652     return InvalidArgument(
1653         "The LHS argument to a convolution should have rank %d; lhs: %s.",
1654         num_dims, ShapeUtil::HumanString(lhs));
1655   }
1656   if (rhs.rank() != num_dims) {
1657     return InvalidArgument(
1658         "The RHS argument to a convolution should have rank %d; rhs: %s.",
1659         num_dims, ShapeUtil::HumanString(rhs));
1660   }
1661   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1662   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1663 
1664   // Verifies that the input and window dimensions are a permutation of
1665   // the dimension numbers.
1666   std::vector<int64> input_dnums(num_dims);
1667   input_dnums[0] = dnums.input_batch_dimension();
1668   input_dnums[1] = dnums.input_feature_dimension();
1669   std::copy(dnums.input_spatial_dimensions().begin(),
1670             dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
1671   absl::c_sort(input_dnums);
1672 
1673   std::vector<int64> window_dnums(num_dims);
1674   window_dnums[0] = dnums.kernel_input_feature_dimension();
1675   window_dnums[1] = dnums.kernel_output_feature_dimension();
1676   std::copy(dnums.kernel_spatial_dimensions().begin(),
1677             dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
1678   absl::c_sort(window_dnums);
1679 
1680   std::vector<int64> output_dnums(num_dims);
1681   output_dnums[0] = dnums.output_batch_dimension();
1682   output_dnums[1] = dnums.output_feature_dimension();
1683   std::copy(dnums.output_spatial_dimensions().begin(),
1684             dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
1685   absl::c_sort(output_dnums);
1686 
1687   std::vector<int64> expected_dnums(num_dims);
1688   std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
1689 
1690   const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
1691   if (!absl::c_all_of(input_dnums, in_range) ||
1692       !absl::c_all_of(window_dnums, in_range) ||
1693       !absl::c_all_of(output_dnums, in_range)) {
1694     return InvalidArgument(
1695         "A dimension number is out of range in convolution: %s.",
1696         dnums.DebugString());
1697   }
1698 
1699   if (input_dnums != expected_dnums) {
1700     return InvalidArgument(
1701         "Input dimensions of convolution must contain each dimension exactly "
1702         "once: %s.",
1703         dnums.DebugString());
1704   }
1705   if (window_dnums != expected_dnums) {
1706     return InvalidArgument(
1707         "Window dimensions of convolution must contain each dimension exactly "
1708         "once: %s.",
1709         dnums.DebugString());
1710   }
1711   if (output_dnums != expected_dnums) {
1712     return InvalidArgument(
1713         "Output dimensions of convolution must contain each dimension exactly "
1714         "once: %s.",
1715         dnums.DebugString());
1716   }
1717 
1718   std::vector<int64> input_spatial_dims(num_spatial_dims);
1719   for (int i = 0; i < num_spatial_dims; ++i) {
1720     input_spatial_dims[i] = lhs.dimensions(dnums.input_spatial_dimensions(i));
1721   }
1722   const int64 input_features = lhs.dimensions(dnums.input_feature_dimension());
1723   const int64 input_batch = lhs.dimensions(dnums.input_batch_dimension());
1724 
1725   std::vector<int64> kernel_spatial_dims(num_spatial_dims);
1726   for (int i = 0; i < num_spatial_dims; ++i) {
1727     kernel_spatial_dims[i] = rhs.dimensions(dnums.kernel_spatial_dimensions(i));
1728   }
1729   const int64 kernel_input_features =
1730       rhs.dimensions(dnums.kernel_input_feature_dimension());
1731   const int64 kernel_output_features =
1732       rhs.dimensions(dnums.kernel_output_feature_dimension());
1733 
1734   if (kernel_output_features % batch_group_count != 0) {
1735     return InvalidArgument(
1736         "Expected output feature dimension size (value %d) to be a multiple of "
1737         "batch group count %d; got <conv>(%s, %s)\n"
1738         "Dimension numbers: {%s}.",
1739         kernel_output_features, batch_group_count, ShapeUtil::HumanString(lhs),
1740         ShapeUtil::HumanString(rhs), dnums.DebugString());
1741   }
1742 
1743   if (input_features % feature_group_count != 0 ||
1744       input_features / feature_group_count != kernel_input_features) {
1745     return InvalidArgument(
1746         "Expected LHS feature dimension (value %d) to be a multiple of "
1747         "feature_group_count (value %d), and LHS feature dimension / "
1748         "feature_group_count = RHS feature dimension (value %d); "
1749         "got <conv>(%s, %s)\n"
1750         "Dimension numbers: {%s}.",
1751         input_features, feature_group_count, kernel_input_features,
1752         ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
1753         dnums.DebugString());
1754   }
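  // For example, a grouped convolution with input_features = 32 and
  // feature_group_count = 4 expects kernel_input_features = 8.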
1755 
1756   if (kernel_output_features % feature_group_count > 0) {
1757     // A depthwise/grouped filter has the shape
1758     // [space0, .. spaceN, GROUP_SIZE, NUM_OUTPUT_FEATURES]. When
1759     // [space0, .. spaceN, GROUP_SIZE] is convolved with the input, a shape
1760     // [space0, .. spaceN, feature_group_count] is formed. Therefore, the output
1761     // feature count (which is equal to kernel output features) has to be a
1762     // multiple of feature_group_count.
1763     return InvalidArgument(
1764         "Expected output feature dimension (value %d) to be divisible by "
1765         "feature_group_count (value %d); "
1766         "got <conv>(%s, %s)\n"
1767         "Dimension numbers: {%s}.",
1768         kernel_output_features, feature_group_count,
1769         ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
1770         dnums.DebugString());
1771   }
1772 
1773   if (input_batch % batch_group_count != 0) {
1774     return InvalidArgument(
1775         "Expected input batch dimension (value %d) to be divisible by "
1776         "batch_group_count (value %d); "
1777         "got <conv>(%s, %s)\n"
1778         "Dimension numbers: {%s}.",
1779         input_batch, batch_group_count, ShapeUtil::HumanString(lhs),
1780         ShapeUtil::HumanString(rhs), dnums.DebugString());
1781   }
1782 
1783   std::vector<int64> window_dims(num_spatial_dims);
1784   for (int i = 0; i < num_spatial_dims; ++i) {
1785     window_dims[i] = window.dimensions(i).size();
1786   }
1787   if (kernel_spatial_dims != window_dims) {
1788     return InvalidArgument(
1789         "Window dimensions do not match RHS shape:\n\t"
1790         "RHS shape: %s\n\t"
1791         "Window: {%s}\n\t"
1792         "Dimension numbers: {%s}.",
1793         ShapeUtil::HumanString(rhs), window.ShortDebugString(),
1794         dnums.ShortDebugString());
1795   }
1796 
1797   Shape base_shape =
1798       ShapeUtil::MakeShape(lhs.element_type(), input_spatial_dims);
1799   TF_ASSIGN_OR_RETURN(
1800       Shape window_output_shape,
1801       InferWindowOutputShape(base_shape, window, lhs.element_type(),
1802                              /*allow_negative_padding=*/true));
1803 
1804   std::vector<int64> dimensions(num_dims);
1805   dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count;
1806   dimensions[dnums.output_feature_dimension()] = kernel_output_features;
1807 
1808   for (int i = 0; i < num_spatial_dims; ++i) {
1809     dimensions[dnums.output_spatial_dimensions(i)] =
1810         window_output_shape.dimensions(i);
1811   }
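  // As an illustration, with NHWC/HWIO dimension numbers an f32[8,32,32,16]
  // input convolved with an f32[3,3,16,64] kernel (stride 1, no padding, no
  // dilation, feature_group_count = 1, batch_group_count = 1) produces an
  // f32[8,30,30,64] output.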
1812   std::vector<bool> is_dynamic(num_dims);
1813   for (int i = 0; i < num_dims; i++) {
1814     if (lhs.is_dynamic_dimension(i)) {
1815       if (i == dnums.input_batch_dimension()) {
1816         is_dynamic[dnums.output_batch_dimension()] = true;
1817       } else if (i == dnums.input_feature_dimension()) {
1818         // Input feature dimension is a contracting dimension, which does not
1819         // affect the output dimension size. So we need to do nothing.
1820       } else {
1821         return InvalidArgument(
1822             "Dynamic Spatial Convolution is not supported: lhs shape is %s ",
1823             lhs.ToString());
1824       }
1825     }
1826     if (rhs.is_dynamic_dimension(i)) {
1827       if (i == dnums.kernel_input_feature_dimension()) {
1828         // Kernel feature dimension does not affect the output dimension size.
1829         // So we need to do nothing.
1830       } else {
1831         return InvalidArgument(
1832             "Dynamic Spatial Convolution is not supported: rhs shape is %s ",
1833             rhs.ToString());
1834       }
1835     }
1836   }
1837   return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
1838                               dimensions, is_dynamic);
1839 }
1840 
1841 /* static */ StatusOr<Shape> ShapeInference::InferFftShape(
1842     const Shape& in, const FftType fft_type,
1843     const absl::Span<const int64> fft_length) {
1844   const int64 fft_rank = fft_length.size();
1845   if (fft_rank < 1 || fft_rank > 3) {
1846     return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank);
1847   }
1848 #define RET_CHECK_RANK(x)                            \
1849   if (x.dimensions_size() < fft_rank) {              \
1850     return InvalidArgument(                          \
1851         "FFT of rank %d requires input of at least " \
1852         "same rank; got input of rank %d",           \
1853         fft_rank, x.dimensions_size());              \
1854   }
1855   switch (fft_type) {
1856     case FFT:
1857     case IFFT:
1858       if (in.element_type() != C64) {
1859         return InvalidArgument("%s requires complex input type, found %s.",
1860                                FftType_Name(fft_type),
1861                                PrimitiveType_Name(in.element_type()));
1862       }
1863       RET_CHECK_RANK(in);
1864       return in;
1865     case RFFT: {
1866       if (in.element_type() != F32) {
1867         return InvalidArgument("RFFT requires F32 input type, found %s.",
1868                                PrimitiveType_Name(in.element_type()));
1869       }
1870       RET_CHECK_RANK(in);
1871       for (int i = 0; i < fft_rank; i++) {
1872         if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1873             fft_length[i]) {
1874           return InvalidArgument(
1875               "RFFT requires innermost dimensions match fft_length but "
1876               "dimension %d is %d and should be %d.",
1877               in.dimensions_size() - fft_rank + i,
1878               in.dimensions(in.dimensions_size() - fft_rank + i),
1879               fft_length[i]);
1880         }
1881       }
1882       Shape result = ShapeUtil::ChangeElementType(in, C64);
1883       // Preserve the size of zero-sized dimensions.
1884       if (fft_length[fft_rank - 1] != 0) {
1885         result.set_dimensions(result.dimensions_size() - 1,
1886                               fft_length[fft_rank - 1] / 2 + 1);
1887       }
1888       return result;
1889     }
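    // For example, an RFFT of an f32[16,8] input with fft_length = {8}
    // produces a c64[16,5] result (8 / 2 + 1 = 5).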
1890     case IRFFT: {
1891       if (in.element_type() != C64) {
1892         return InvalidArgument("IRFFT requires C64 input type, found %s.",
1893                                PrimitiveType_Name(in.element_type()));
1894       }
1895       RET_CHECK_RANK(in);
1896       Shape result = ShapeUtil::ComplexComponentShape(in);
1897       for (int i = 0; i < fft_rank - 1; i++) {
1898         if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1899             fft_length[i]) {
1900           return InvalidArgument(
1901               "IRFFT requires all but one innermost dimensions match "
1902               "fft_length, but dimension %d is %d and should be %d.",
1903               in.dimensions_size() - fft_rank + i,
1904               in.dimensions(in.dimensions_size() - fft_rank + i),
1905               fft_length[i]);
1906         }
1907       }
1908       // The size of zero-sized dimensions is preserved.
1909       if ((in.dimensions(in.dimensions_size() - 1) != 0 ||
1910            fft_length[fft_rank - 1] != 0) &&
1911           in.dimensions(in.dimensions_size() - 1) !=
1912               fft_length[fft_rank - 1] / 2 + 1) {
1913         return InvalidArgument(
1914             "IRFFT requires innermost dimension matches fft_length/2+1, but "
1915             "dimension %d is %d and should be %d.",
1916             in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1),
1917             fft_length[fft_rank - 1] / 2 + 1);
1918       }
1919       result.set_dimensions(result.dimensions_size() - 1,
1920                             fft_length[fft_rank - 1]);
1921       return result;
1922     }
1923     default:
1924       LOG(FATAL) << "Unexpected fft_type: " << fft_type;
1925   }
1926 #undef RET_CHECK_RANK
1927 }
1928 
1929 /* static */ StatusOr<Shape> ShapeInference::InferTriangularSolveShape(
1930     const Shape& a, const Shape& b, const TriangularSolveOptions& options) {
1931   if ((!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) ||
1932       a.element_type() != b.element_type()) {
1933     return InvalidArgument(
1934         "Expected element types in shape to be floating or complex and "
1935         "identical for TriangularSolve; got %s and %s.",
1936         PrimitiveType_Name(a.element_type()),
1937         PrimitiveType_Name(b.element_type()));
1938   }
1939   if (a.rank() < 2) {
1940     return InvalidArgument(
1941         "The 'a' argument to TriangularSolve must have rank >= 2, got shape %s",
1942         a.ToString());
1943   }
1944   if (b.rank() != a.rank()) {
1945     return InvalidArgument(
1946         "Arguments to triangular solve must have equal rank; got %s and %s.",
1947         b.ToString(), a.ToString());
1948   }
1949   if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
1950     return InvalidArgument(
1951         "The two minor dimensions of 'a' must have equal size, got %s.",
1952         a.ToString());
1953   }
1954   if (a.dimensions(a.rank() - 1) !=
1955       b.dimensions(b.rank() - (options.left_side() ? 2 : 1))) {
1956     return InvalidArgument(
1957         "The shared dimension of 'a' and 'b' does not match, got shapes %s and "
1958         "%s",
1959         a.ToString(), b.ToString());
1960   }
1961   absl::Span<const int64> a_batch_dims(a.dimensions());
1962   absl::Span<const int64> b_batch_dims(b.dimensions());
1963   a_batch_dims.remove_suffix(2);
1964   b_batch_dims.remove_suffix(2);
1965   if (a_batch_dims != b_batch_dims) {
1966     return InvalidArgument(
1967         "The leading batch dimensions of the arguments to triangular solve "
1968         "must be equal; got %s and %s.",
1969         b.ToString(), a.ToString());
1970   }
1971   if (!TriangularSolveOptions_Transpose_IsValid(options.transpose_a()) ||
1972       options.transpose_a() == TriangularSolveOptions::TRANSPOSE_INVALID) {
1973     return InvalidArgument(
1974         "Invalid transpose option value for triangular solve (%d).\n",
1975         options.transpose_a());
1976   }
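  // The solution has the same shape as 'b'; e.g. a = f32[4,8,8] and
  // b = f32[4,8,2] with left_side = true give an f32[4,8,2] result.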
1977   return b;
1978 }
1979 
1980 /* static */ StatusOr<Shape> ShapeInference::InferCholeskyShape(
1981     const Shape& a) {
1982   if (!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) {
1983     return InvalidArgument(
1984         "Expected element type in shape to be floating or complex for "
1985         "Cholesky; got %s.",
1986         PrimitiveType_Name(a.element_type()));
1987   }
1988   if (a.rank() < 2) {
1989     return InvalidArgument(
1990         "The 'a' argument to Cholesky must have rank >= 2, got shape %s",
1991         a.ToString());
1992   }
1993   if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
1994     return InvalidArgument(
1995         "The two minor dimensions of 'a' must have equal size, got %s.",
1996         a.ToString());
1997   }
1998   return a;
1999 }
2000 
2001 /* static */ StatusOr<Shape> ShapeInference::InferAllReduceShape(
2002     absl::Span<const Shape* const> operand_shapes) {
2003   for (const Shape* operand_shape : operand_shapes) {
2004     TF_RETURN_IF_ERROR(
2005         ExpectArray(*operand_shape, "operand of cross replica sum"));
2006   }
2007   if (operand_shapes.size() == 1) {
2008     return *operand_shapes[0];
2009   }
2010   std::vector<Shape> operand_shape_values;
2011   for (const Shape* operand_shape : operand_shapes) {
2012     operand_shape_values.push_back(*operand_shape);
2013   }
2014   return ShapeUtil::MakeTupleShape(operand_shape_values);
2015 }
2016 
2017 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllShape(
2018     const Shape& shape, int64 split_dimension, int64 concat_dimension,
2019     int64 split_count) {
2020   TF_RET_CHECK(split_count > 0);
2021   if (split_dimension >= shape.rank() || split_dimension < 0) {
2022     return InvalidArgument(
2023         "AllToAll split_dimension %d is out-of-bounds in shape %s.",
2024         split_dimension, ShapeUtil::HumanString(shape));
2025   }
2026   if (concat_dimension >= shape.rank() || concat_dimension < 0) {
2027     return InvalidArgument(
2028         "AllToAll concat_dimension %d is out-of-bounds in shape %s.",
2029         concat_dimension, ShapeUtil::HumanString(shape));
2030   }
2031   if (shape.dimensions(split_dimension) % split_count != 0) {
2032     return InvalidArgument(
2033         "AllToAll split dimension size %d must be divisible by split_count "
2034         "%d.",
2035         shape.dimensions(split_dimension), split_count);
2036   }
2037   std::vector<int64> new_dimensions(shape.dimensions().begin(),
2038                                     shape.dimensions().end());
2039   new_dimensions[split_dimension] /= split_count;
2040   new_dimensions[concat_dimension] *= split_count;
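  // For example, an f32[8,16] operand with split_dimension = 1,
  // concat_dimension = 0 and split_count = 4 becomes f32[32,4].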
2041   return ShapeUtil::MakeShape(shape.element_type(), new_dimensions);
2042 }
2043 
2044 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
2045     absl::Span<const Shape* const> operand_shapes) {
2046   // An AllToAll HLO instruction receives N operands (with the same shape) and
2047   // returns a tuple that contains N array shapes.
2048   TF_RET_CHECK(!operand_shapes.empty());
2049   for (int i = 0; i < operand_shapes.size(); i++) {
2050     if (!ShapeUtil::Equal(*operand_shapes[0], *operand_shapes[i])) {
2051       return InvalidArgument(
2052           "HLO all-to-all has operands with different shapes: the 0th "
2053           "operand has shape %s, but the %dth operand has shape %s.",
2054           ShapeUtil::HumanString(*operand_shapes[0]), i,
2055           ShapeUtil::HumanString(*operand_shapes[i]));
2056     }
2057   }
2058 
2059   return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes);
2060 }
2061 
2062 /* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteShape(
2063     const Shape& shape) {
2064   TF_RET_CHECK(shape.IsArray());
2065   return shape;
2066 }
2067 
2068 /* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
2069     absl::Span<const Shape* const> arg_shapes,
2070     absl::Span<const int64> dimensions_to_reduce,
2071     const ProgramShape& to_apply) {
2072   if (arg_shapes.empty()) {
2073     return InvalidArgument("Reduce must have at least 2 arguments, has 0");
2074   }
2075   if (arg_shapes.size() % 2) {
2076     return InvalidArgument(
2077         "Reduce must have an even number of arguments, has %lu",
2078         arg_shapes.size());
2079   }
2080   int64 num_reduced_args = arg_shapes.size() / 2;
2081 
2082   auto reduced_args = arg_shapes.subspan(0, num_reduced_args);
2083   // Check that all of the reduced tensors have the same dimensions. The element
2084   // types may be different.
2085   for (int64 i = 1; i < num_reduced_args; ++i) {
2086     if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) {
2087       return InvalidArgument(
2088           "All reduced tensors must have the same dimensions. Tensor 0 has "
2089           "shape %s, Tensor %d has shape %s",
2090           ShapeUtil::HumanString(*reduced_args[0]), i,
2091           ShapeUtil::HumanString(*reduced_args[i]));
2092     }
2093   }
2094 
2095   // Check that the dimensions to reduce are in-bounds for the given shape.
2096   // We've already verified all reduced tensors have the same dimensions, so it
2097   // doesn't matter which one we choose.
2098   const Shape& arg = *reduced_args[0];
2099   for (int64 dimension : dimensions_to_reduce) {
2100     if (dimension >= arg.rank() || dimension < 0) {
2101       return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.",
2102                              dimension, ShapeUtil::HumanString(arg));
2103     }
2104   }
2105 
2106   auto init_values = arg_shapes.subspan(num_reduced_args, arg_shapes.size());
2107   std::vector<PrimitiveType> element_types;
2108   for (const Shape* arg : reduced_args) {
2109     element_types.push_back(arg->element_type());
2110   }
2111   TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types,
2112                                         num_reduced_args));
2113 
2114   std::set<int64> dimensions_to_reduce_set(dimensions_to_reduce.begin(),
2115                                            dimensions_to_reduce.end());
2116   std::vector<int64> new_dimensions;
2117   std::vector<bool> new_is_dynamic;
2118   for (int i = 0; i < arg.rank(); ++i) {
2119     if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) {
2120       new_dimensions.push_back(arg.dimensions(i));
2121       new_is_dynamic.push_back(arg.is_dynamic_dimension(i));
2122     }
2123   }
2124 
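  // For example, reducing an f32[4,5,6] input over dimensions {1} with a
  // scalar-result reducer yields f32[4,6]; a tuple-result reducer produces a
  // tuple containing one such array per reduced input.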
2125   if (ShapeUtil::IsScalar(to_apply.result())) {
2126     return ShapeUtil::MakeShape(to_apply.result().element_type(),
2127                                 new_dimensions, new_is_dynamic);
2128   } else {
2129     std::vector<Shape> result_subshapes;
2130     for (const Shape& subshape : to_apply.result().tuple_shapes()) {
2131       result_subshapes.push_back(ShapeUtil::MakeShape(
2132           subshape.element_type(), new_dimensions, new_is_dynamic));
2133     }
2134     return ShapeUtil::MakeTupleShape(result_subshapes);
2135   }
2136 }
2137 
2138 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2139     const Shape& operand_shape, const Shape& init_value_shape,
2140     const Window& window, const ProgramShape& to_apply_shape) {
2141   TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
2142                                         {operand_shape.element_type()},
2143                                         /*inputs=*/1));
2144   return InferReduceWindowShape(operand_shape, init_value_shape, window);
2145 }
2146 
2147 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2148     const Shape& operand_shape, const Shape& init_value_shape,
2149     const Window& window) {
2150   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window"));
2151   return InferWindowOutputShape(operand_shape, window,
2152                                 init_value_shape.element_type(),
2153                                 /*allow_negative_padding=*/false);
2154 }
2155 
2156 /* static */ StatusOr<Shape> ShapeInference::InferSelectAndScatterShape(
2157     const Shape& operand_shape, const ProgramShape& select_shape,
2158     const Window& window, const Shape& source_shape,
2159     const Shape& init_value_shape, const ProgramShape& scatter_shape) {
2160   TF_RETURN_IF_ERROR(
2161       ExpectArray(operand_shape, "operand of select-and-scatter"));
2162 
2163   // Check if the select function has a proper shape of (T,T) -> PRED.
2164   if (select_shape.parameters_size() != 2) {
2165     return InvalidArgument(
2166         "Select function must take 2 parameters, but "
2167         "takes %d parameter(s).",
2168         select_shape.parameters_size());
2169   }
2170   const Shape& select_result_shape = select_shape.result();
2171   if (!ShapeUtil::Compatible(select_result_shape,
2172                              ShapeUtil::MakeShape(PRED, {}))) {
2173     return InvalidArgument("Select function must have rank-0 PRED result.");
2174   }
2175   const Shape& operand_element_shape =
2176       ShapeUtil::MakeShape(operand_shape.element_type(), {});
2177   if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2178                                                 select_shape.parameters(0))) {
2179     return InvalidArgument(
2180         "Select function's first parameter shape currently must "
2181         "match the operand element shape, but got %s vs %s.",
2182         ShapeUtil::HumanString(select_shape.parameters(0)),
2183         ShapeUtil::HumanString(operand_element_shape));
2184   }
2185   if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2186                                                 select_shape.parameters(1))) {
2187     return InvalidArgument(
2188         "Select function's second parameter shape currently must "
2189         "match the operand element shape, but got %s vs %s.",
2190         ShapeUtil::HumanString(select_shape.parameters(1)),
2191         ShapeUtil::HumanString(operand_element_shape));
2192   }
2193 
2194   // Check if the scatter function has a proper shape as a reduction.
2195   TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape},
2196                                         {source_shape.element_type()},
2197                                         /*inputs=*/1));
2198 
2199   // Check if the result shape of window operation matches the source shape.
2200   TF_ASSIGN_OR_RETURN(const Shape& window_result_shape,
2201                       InferWindowOutputShape(operand_shape, window,
2202                                              operand_shape.element_type(),
2203                                              /*allow_negative_padding=*/false));
2204   if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape,
2205                                                 window_result_shape)) {
2206     return InvalidArgument(
2207         "Source shape does not match the shape of window-reduced operand: "
2208         "source(%s), window-reduced operand(%s).",
2209         ShapeUtil::HumanString(source_shape),
2210         ShapeUtil::HumanString(window_result_shape));
2211   }
2212 
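  // Select-and-scatter scatters the source values back into an array of the
  // operand's shape, so that is the inferred result shape.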
2213   return operand_shape;
2214 }
2215 
2216 /* static */ StatusOr<Shape> ShapeInference::InferGetDimensionSizeShape(
2217     const Shape& shape, int64 dimension) {
2218   if (dimension < 0 || dimension >= shape.rank()) {
2219     return InvalidArgument("GetDimensionSize dimension out of bounds: %d.",
2220                            dimension);
2221   }
2222 
2223   // TODO(b/119580730): Remove this restriction when very large dimension size
2224   // is needed.
2225   if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
2226     return InvalidArgument(
2227         "GetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2228         "INT_MAX limit.",
2229         ShapeUtil::HumanString(shape), dimension);
2230   }
2231 
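  // The dimension size is returned as a signed 32-bit scalar, which is why
  // larger dimension sizes are rejected above.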
2232   return ShapeUtil::MakeShape(S32, {});
2233 }
2234 
2235 /* static */ StatusOr<Shape> ShapeInference::InferSetDimensionSizeShape(
2236     const Shape& shape, int64 dimension) {
2237   if (dimension < 0 || dimension >= shape.rank()) {
2238     return InvalidArgument("SetDimensionSize dimension out of bounds: %d.",
2239                            dimension);
2240   }
2241 
2242   // TODO(b/119580730): Remove this restriction when very large dimension size
2243   // is needed.
2244   if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
2245     return InvalidArgument(
2246         "SetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2247         "INT_MAX limit.",
2248         ShapeUtil::HumanString(shape), dimension);
2249   }
2250 
2251   Shape result = shape;
2252   result.set_dynamic_dimension(dimension, true);
2253   return result;
2254 }
2255 
2256 /* static */ StatusOr<Window> ShapeInference::InferWindowFromDimensions(
2257     absl::Span<const int64> window_dimensions,
2258     absl::Span<const int64> window_strides,
2259     absl::Span<const std::pair<int64, int64>> padding,
2260     absl::Span<const int64> lhs_dilation,
2261     absl::Span<const int64> rhs_dilation) {
2262   const auto verify_size = [&](const size_t x, const char* x_name) {
2263     if (x == 0 || x == window_dimensions.size()) {
2264       return Status::OK();
2265     } else {
2266       return InvalidArgument(
2267           "%s", absl::StrCat(
2268                     "Window has different number of window dimensions than of ",
2269                     x_name,
2270                     "\nNumber of window dimensions: ", window_dimensions.size(),
2271                     "\nNumber of ", x_name, ": ", x, "\n"));
2272     }
2273   };
2274   TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides"));
2275   TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries"));
2276   TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors"));
2277   TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors"));
2278 
2279   Window window;
2280   for (size_t i = 0; i < window_dimensions.size(); i++) {
2281     auto dim = window.add_dimensions();
2282     dim->set_size(window_dimensions[i]);
2283     if (!window_strides.empty()) {
2284       dim->set_stride(window_strides[i]);
2285     } else {
2286       dim->set_stride(1);
2287     }
2288     if (!padding.empty()) {
2289       dim->set_padding_low(padding[i].first);
2290       dim->set_padding_high(padding[i].second);
2291     } else {
2292       dim->set_padding_low(0);
2293       dim->set_padding_high(0);
2294     }
2295     if (!lhs_dilation.empty()) {
2296       dim->set_base_dilation(lhs_dilation[i]);
2297     } else {
2298       dim->set_base_dilation(1);
2299     }
2300     if (!rhs_dilation.empty()) {
2301       dim->set_window_dilation(rhs_dilation[i]);
2302     } else {
2303       dim->set_window_dilation(1);
2304     }
2305     dim->set_window_reversal(false);
2306   }
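  // Arguments left empty default, per dimension, to stride 1, zero padding,
  // and base/window dilation 1, as set above.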
2307   return window;
2308 }
2309 
2310 /* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
2311     const Shape& arg, absl::Span<const int64> starts,
2312     absl::Span<const int64> limits, absl::Span<const int64> strides) {
2313   auto error = [&](const string& message) {
2314     return InvalidArgument(
2315         "%s in slice operation; argument shape: %s; starts: {%s}; limits: "
2316         "{%s}; strides: {%s}.",
2317         message, ShapeUtil::HumanString(arg), StrJoin(starts, ","),
2318         StrJoin(limits, ","), StrJoin(strides, ","));
2319   };
2320   TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice"));
2321   VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}",
2322                        ShapeUtil::HumanString(arg), StrJoin(starts, ", "),
2323                        StrJoin(limits, ", "));
2324 
2325   if (starts.size() != limits.size()) {
2326     return error(StrFormat("slice start and limit sizes differ: %u vs %u",
2327                            starts.size(), limits.size()));
2328   }
2329 
2330   if (starts.size() != strides.size()) {
2331     return error(StrFormat("slice start and strides sizes differ: %u vs %u",
2332                            starts.size(), strides.size()));
2333   }
2334 
2335   if (starts.size() != arg.rank()) {
2336     return InvalidArgument(
2337         "Slice index count does not match argument rank: %u vs %d.",
2338         starts.size(), arg.rank());
2339   }
2340 
2341   std::vector<int64> sizes;
2342   for (int64 dimension = 0; dimension < starts.size(); ++dimension) {
2343     int64 start_index = starts[dimension];
2344     int64 limit_index = limits[dimension];
2345     int64 stride = strides[dimension];
2346     if (start_index < 0) {
2347       return InvalidArgument("Negative start index to slice: %d.", start_index);
2348     }
2349     if (limit_index > arg.dimensions(dimension)) {
2350       return error(
2351           StrFormat("limit index (%d) must be less than or equal to dimension "
2352                     "size (%d)",
2353                     limit_index, arg.dimensions(dimension)));
2354     }
2355     VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index);
2356     VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index);
2357     if (start_index > limit_index) {
2358       return error(
2359           StrFormat("limit index (%d) must be greater or equal to "
2360                     "start index (%d) in slice with positive stride",
2361                     limit_index, start_index));
2362     }
2363     if (stride <= 0) {
2364       return InvalidArgument("Stride (%d) must be positive.", stride);
2365     }
2366     sizes.push_back((limit_index - start_index + stride - 1) / stride);
2367   }
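  // For example, slicing f32[10] with starts = {2}, limits = {9} and
  // strides = {3} keeps ceil((9 - 2) / 3) = 3 elements, giving f32[3].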
2368 
2369   std::vector<bool> is_dynamic(arg.rank());
2370   for (int64 i = 0; i < arg.dimensions_size(); ++i) {
2371     // Slicing 1 out of a dynamic dimension eliminates the dynamic dimension.
2372     if (sizes[i] == 1) {
2373       continue;
2374     }
2375     is_dynamic[i] = arg.is_dynamic_dimension(i);
2376   }
2377 
2378   return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic);
2379 }
2380 
2381 /* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
2382     const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
2383     absl::Span<const int64> slice_sizes, bool allow_scalar_indices) {
2384   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
2385   auto number_of_indices = start_index_shapes.size();
2386   // TODO(b/118437727): Remove this path.
2387   if (!allow_scalar_indices ||
2388       (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
2389     if (number_of_indices != 1) {
2390       return InvalidArgument(
2391           "Dynamic slice should have exactly 1 index operand, has %d.",
2392           number_of_indices);
2393     }
2394 
2395     const Shape& start_indices_shape = start_index_shapes[0];
2396     VLOG(2) << StrFormat(
2397         "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
2398         ShapeUtil::HumanString(operand_shape),
2399         ShapeUtil::HumanString(start_indices_shape),
2400         StrJoin(slice_sizes, ", "));
2401 
2402     TF_RETURN_IF_ERROR(
2403         ExpectArray(start_indices_shape, "start indices of dynamic slice"));
2404 
2405     if (start_indices_shape.rank() != 1) {
2406       return InvalidArgument(
2407           "Dynamic slice start indices of rank %d must be rank 1.",
2408           start_indices_shape.rank());
2409     }
2410 
2411     if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2412       return InvalidArgument(
2413           "Dynamic slice start indices must be of integral type.");
2414     }
2415 
2416     const int64 start_num_dims = start_indices_shape.dimensions(0);
2417     if (operand_shape.rank() != start_num_dims) {
2418       return InvalidArgument(
2419           "Dynamic slice start number of dimensions %d (%s) must match rank "
2420           "%d of slice input (%s).",
2421           start_num_dims, ShapeUtil::HumanString(start_indices_shape),
2422           operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
2423     }
2424   } else {
2425     VLOG(2) << StrFormat("slicing shape %s with slice_sizes={%s}",
2426                          ShapeUtil::HumanString(operand_shape),
2427                          StrJoin(slice_sizes, ", "));
2428 
2429     if (operand_shape.rank() != number_of_indices) {
2430       return InvalidArgument(
2431           "Dynamic slice start number of dimensions %d must match rank "
2432           "%d of slice input (%s).",
2433           number_of_indices, operand_shape.rank(),
2434           ShapeUtil::HumanString(operand_shape));
2435     }
2436 
2437     if (number_of_indices > 0) {
2438       const Shape& first_index_shape = start_index_shapes[0];
2439       if (!ShapeUtil::IsScalar(first_index_shape)) {
2440         return InvalidArgument("Dynamic slice indices must be scalar, not %s.",
2441                                ShapeUtil::HumanString(first_index_shape));
2442       }
2443       if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
2444         return InvalidArgument(
2445             "Dynamic slice start indices must be of integral type.");
2446       }
2447       for (const Shape& index_shape : start_index_shapes) {
2448         if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
2449           return InvalidArgument(
2450               "Dynamic slice start indices must all have the same shape, got "
2451               "mismatching indices with shapes %s and %s.",
2452               ShapeUtil::HumanString(first_index_shape),
2453               ShapeUtil::HumanString(index_shape));
2454         }
2455       }
2456     }
2457   }
2458 
2459   if (slice_sizes.size() != operand_shape.rank()) {
2460     return InvalidArgument(
2461         "Dynamic slice index count does not match argument rank: %u vs %d.",
2462         slice_sizes.size(), operand_shape.rank());
2463   }
2464 
2465   for (int64 dim = 0; dim < slice_sizes.size(); ++dim) {
2466     const int64 input_dim_size = operand_shape.dimensions(dim);
2467     const int64 slice_dim_size = slice_sizes[dim];
2468     if (slice_dim_size < 0) {
2469       return InvalidArgument("Negative size index to dynamic slice: %d.",
2470                              slice_dim_size);
2471     }
2472     if (slice_dim_size > input_dim_size) {
2473       return InvalidArgument(
2474           "Slice dim size %d greater than dynamic slice dimension: %d.",
2475           slice_dim_size, input_dim_size);
2476     }
2477     VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size);
2478   }
2479 
2480   return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
2481 }
2482 
2483 /* static */ StatusOr<Shape> ShapeInference::InferDynamicUpdateSliceShape(
2484     const Shape& operand_shape, const Shape& update_shape,
2485     absl::Span<const Shape> start_index_shapes, bool allow_scalar_indices) {
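  // The inferred result has the operand's shape. The update must have the
  // same rank, an element type matching the operand (ignoring floating-point
  // precision), and each update dimension must fit within the corresponding
  // operand dimension, e.g. (illustrative values) operand f32[10,10] with
  // update f32[3,4] infers f32[10,10].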
2486   TF_RETURN_IF_ERROR(
2487       ExpectArray(operand_shape, "operand of dynamic update slice"));
2488   TF_RETURN_IF_ERROR(
2489       ExpectArray(update_shape, "update of dynamic update slice"));
2490 
2491   auto number_of_indices = start_index_shapes.size();
2492   // TODO(b/118437727): Remove this path.
2493   if (!allow_scalar_indices ||
2494       (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
2495     if (number_of_indices != 1) {
2496       return InvalidArgument(
2497           "Dynamic update slice should have exactly 1 index operand, has %d.",
2498           number_of_indices);
2499     }
2500     const Shape& start_indices_shape = start_index_shapes[0];
2501     TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape,
2502                                    "start indices of dynamic update slice"));
2503 
2504     VLOG(2) << StrFormat(
2505         "updating slice of shape %s at dynamic start_indices %s with update "
2506         "shape %s",
2507         ShapeUtil::HumanString(operand_shape),
2508         ShapeUtil::HumanString(start_indices_shape),
2509         ShapeUtil::HumanString(update_shape));
2510 
2511     if (start_indices_shape.rank() != 1) {
2512       return InvalidArgument(
2513           "Dynamic update slice start indices of rank %d must be rank 1.",
2514           start_indices_shape.rank());
2515     }
2516 
2517     if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2518       return InvalidArgument(
2519           "Dynamic update slice start indices must be of integral type.");
2520     }
2521 
2522     const int64 start_num_dims = start_indices_shape.dimensions(0);
2523     if (operand_shape.rank() != start_num_dims) {
2524       return InvalidArgument(
2525           "Dynamic update slice start number of dimensions %d (%s) must match "
2526           "rank %d of slice input (%s).",
2527           start_num_dims, ShapeUtil::HumanString(start_indices_shape),
2528           operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
2529     }
2530   } else {
2531     VLOG(2) << StrFormat("updating slice of shape %s with update shape %s",
2532                          ShapeUtil::HumanString(operand_shape),
2533                          ShapeUtil::HumanString(update_shape));
2534 
2535     if (operand_shape.rank() != number_of_indices) {
2536       return InvalidArgument(
2537           "Dynamic update slice start number of dimensions %d must match "
2538           "rank %d of slice input (%s).",
2539           number_of_indices, operand_shape.rank(),
2540           ShapeUtil::HumanString(operand_shape));
2541     }
2542 
2543     if (number_of_indices > 0) {
2544       const Shape& first_index_shape = start_index_shapes[0];
2545       if (!ShapeUtil::IsScalar(first_index_shape)) {
2546         return InvalidArgument(
2547             "Dynamic update slice indices must be scalar, not %s.",
2548             ShapeUtil::HumanString(first_index_shape));
2549       }
2550       if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
2551         return InvalidArgument(
2552             "Dynamic update slice start indices must be of integral type.");
2553       }
2554       for (const Shape& index_shape : start_index_shapes) {
2555         if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
2556           return InvalidArgument(
2557               "Dynamic update slice start indices must all have the same "
2558               "shape, got mismatching indices with shapes %s and %s.",
2559               ShapeUtil::HumanString(first_index_shape),
2560               ShapeUtil::HumanString(index_shape));
2561         }
2562       }
2563     }
2564   }
2565 
2566   if (update_shape.rank() != operand_shape.rank()) {
2567     return InvalidArgument(
2568         "Dynamic update slice update rank does not match argument rank: "
2569         "%d vs %d.",
2570         update_shape.rank(), operand_shape.rank());
2571   }
2572 
2573   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
2574                                                      update_shape)) {
2575     return InvalidArgument(
2576         "Dynamic update slice update element type does not match argument. "
2577         "operand.element_type: %s vs update.element_type: %s.",
2578         PrimitiveType_Name(operand_shape.element_type()),
2579         PrimitiveType_Name(update_shape.element_type()));
2580   }
2581 
2582   for (int64 dim = 0; dim < operand_shape.rank(); ++dim) {
2583     const int64 input_dim_size = operand_shape.dimensions(dim);
2584     const int64 update_dim_size = update_shape.dimensions(dim);
2585     if (update_dim_size < 0) {
2586       return InvalidArgument(
2587           "Size index %d to dynamic update slice must be >= 0.",
2588           update_dim_size);
2589     }
2590     if (update_dim_size > input_dim_size) {
2591       return InvalidArgument(
2592           "Update dim size %d greater than dynamic slice dimension: %d.",
2593           update_dim_size, input_dim_size);
2594     }
2595     VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size);
2596   }
2597 
2598   return operand_shape;
2599 }
2600 
2601 /*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
2602     const Shape& operand_shape, absl::Span<const int64> dimensions) {
2603   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
2604   if (!AllUnique(dimensions)) {
2605     return InvalidArgument("a dimension number is duplicated in reverse");
2606   }
2607   for (int64 dimension : dimensions) {
2608     if (dimension >= operand_shape.rank() || dimension < 0) {
2609       return InvalidArgument(
2610           "One of the reverse dimensions (%d) is out-of-bounds in shape %s.",
2611           dimension, ShapeUtil::HumanString(operand_shape));
2612     }
2613   }
2614   return operand_shape;
2615 }
2616 
2617 /* static */ StatusOr<Shape> ShapeInference::InferGetTupleElementShape(
2618     const Shape& arg, int64 index) {
2619   if (!arg.IsTuple()) {
2620     return InvalidArgument(
2621         "Cannot infer shape: attempting to index into non-tuple: %s.",
2622         ShapeUtil::HumanString(arg));
2623   }
2624 
2625   if (index < 0 || index >= arg.tuple_shapes_size()) {
2626     return InvalidArgument(
2627         "Cannot infer shape: attempt to index out of tuple bounds: %d "
2628         ">= %d in shape %s.",
2629         index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg));
2630   }
2631 
2632   return arg.tuple_shapes(index);
2633 }
2634 
2635 /* static */ StatusOr<Shape> ShapeInference::InferWhileShape(
2636     const ProgramShape& condition, const ProgramShape& body,
2637     const Shape& init) {
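  // Expected signatures: condition: T -> PRED and body: T -> T, where T is
  // the shape of init; the inferred while shape is T itself. For instance
  // (illustrative values), init (s32[], f32[10]) requires a condition
  // returning pred[] and a body returning (s32[], f32[10]).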
2638   // Check the number of parameters for given computations.
2639   if (condition.parameters_size() != 1) {
2640     return InvalidArgument("Condition must take 1 argument; got %d.",
2641                            condition.parameters_size());
2642   }
2643   if (body.parameters_size() != 1) {
2644     return InvalidArgument("Body must take 1 argument; got %d.",
2645                            body.parameters_size());
2646   }
2647 
2648   auto shape_string = [&]() {
2649     return StrFormat(
2650         "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition),
2651         ShapeUtil::HumanString(body), ShapeUtil::HumanString(init));
2652   };
2653 
2654   // Check the shapes of computation parameters and return types.
2655   if (!ShapeUtil::Compatible(condition.result(),
2656                              ShapeUtil::MakeShape(PRED, {}))) {
2657     return InvalidArgument("Condition must return a boolean; got %s.",
2658                            shape_string());
2659   }
2660   if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) ||
2661       !ShapeUtil::Compatible(body.result(), body.parameters(0)) ||
2662       !ShapeUtil::Compatible(body.result(), init)) {
2663     return InvalidArgument(
2664         "The parameter of condition and body, the result of the body, and init "
2665         "must all have the same shape; got %s.",
2666         shape_string());
2667   }
2668 
2669   return init;
2670 }
2671 
2672 /* static */ StatusOr<Shape> ShapeInference::InferConditionalShape(
2673     const Shape& branch_index,
2674     absl::Span<const ProgramShape> branch_computations,
2675     absl::Span<const Shape> branch_operands) {
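  // branch_index is either a PRED scalar (exactly two branches) or an S32
  // scalar (one branch per index value). Each branch computation must take a
  // single parameter matching its operand, and all branch results must be
  // compatible; the inferred shape is the result shape of branch 0.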
2676   if (!ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(PRED, {})) &&
2677       !ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(S32, {}))) {
2678     return InvalidArgument("branch_index must be bool or int32; got %s.",
2679                            ShapeUtil::HumanString(branch_index));
2680   }
2681   if (branch_index.element_type() == PRED) {
2682     TF_RET_CHECK(2 == branch_computations.size());
2683   } else {
2684     TF_RET_CHECK(!branch_computations.empty());
2685   }
2686   TF_RET_CHECK(branch_computations.size() == branch_operands.size());
2687 
2688   for (int j = 0; j < branch_computations.size(); ++j) {
2689     if (branch_computations[j].parameters_size() != 1) {
2690       return InvalidArgument(
2691           "branch computation %d must take 1 argument; got %d.", j,
2692           branch_computations[j].parameters_size());
2693     }
2694     if (!ShapeUtil::Compatible(branch_computations[j].parameters(0),
2695                                branch_operands[j])) {
2696       auto shape_string = [&]() {
2697         return StrFormat("operand: %s; computation: %s",
2698                          ShapeUtil::HumanString(branch_operands[j]),
2699                          ShapeUtil::HumanString(branch_computations[j]));
2700       };
2701       return InvalidArgument(
2702           "branch operand %d must match the shape of the only parameter of "
2703           "branch computation %d: got %s.",
2704           j, j, shape_string());
2705     }
2706 
2707     if (!ShapeUtil::Compatible(branch_computations[0].result(),
2708                                branch_computations[j].result())) {
2709       auto shape_string = [&]() {
2710         return StrFormat(
2711             "branch 0 computation result: %s; branch %d computation result: %s",
2712             ShapeUtil::HumanString(branch_computations[0].result()), j,
2713             ShapeUtil::HumanString(branch_computations[j].result()));
2714       };
2715       return InvalidArgument(
2716           "the result of branch 0 computation and branch %d computation must "
2717           "have the same shape: got %s.",
2718           j, shape_string());
2719     }
2720   }
2721   return branch_computations[0].result();
2722 }
2723 
2724 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2725     const Shape& operand, absl::Span<const int64> broadcast_sizes) {
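  // The broadcast sizes become the new major dimensions, followed by the
  // operand's dimensions, e.g. (illustrative values) operand f32[3,4] with
  // broadcast_sizes={5,6} infers f32[5,6,3,4]. Dynamic-ness of operand
  // dimensions is preserved at their shifted positions.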
2726   TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
2727   for (int64 size : broadcast_sizes) {
2728     if (size < 0) {
2729       return InvalidArgument("Broadcast with negative dimension size %d.",
2730                              size);
2731     }
2732   }
2733 
2734   std::vector<int64> dimensions(operand.dimensions_size() +
2735                                 broadcast_sizes.size());
2736   std::copy(broadcast_sizes.begin(), broadcast_sizes.end(), dimensions.begin());
2737   std::copy(operand.dimensions().begin(), operand.dimensions().end(),
2738             dimensions.begin() + broadcast_sizes.size());
2739 
2740   Shape result = ShapeUtil::MakeShape(operand.element_type(), dimensions);
2741   for (int64 i = 0; i < operand.dimensions_size(); ++i) {
2742     result.set_dynamic_dimension(broadcast_sizes.size() + i,
2743                                  operand.is_dynamic_dimension(i));
2744   }
2745   return result;
2746 }
2747 
2748 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2749     const Shape& operand_shape, const Shape& output_shape,
2750     absl::Span<const int64> broadcast_dimensions) {
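  // broadcast_dimensions[i] names the output dimension that operand dimension
  // i maps to; the mapping must be strictly increasing, and each mapped
  // operand dimension must be 1 or equal to its output dimension, e.g.
  // (illustrative values) operand f32[3] broadcast into f32[2,3] uses
  // broadcast_dimensions={1}.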
2751   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast"));
2752   TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast"));
2753   const int64 operand_rank = operand_shape.rank();
2754   const int64 output_rank = output_shape.rank();
2755   if (operand_rank > output_rank) {
2756     return InvalidArgument(
2757         "InDim style broadcast must be to an equal or higher ranked shape; "
2758         "operand rank: %lld; output rank: %lld",
2759         operand_rank, output_rank);
2760   }
2761   if (operand_rank != broadcast_dimensions.size()) {
2762     return InvalidArgument(
2763         "Size of broadcast_dimensions has to match operand's rank; operand "
2764         "rank: %lld, size of broadcast_dimensions %u.",
2765         operand_rank, broadcast_dimensions.size());
2766   }
2767   for (int64 i = 0; i < operand_rank; i++) {
2768     if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] >= output_rank) {
2769       return InvalidArgument("Broadcast dimension %lld is out of bounds",
2770                              broadcast_dimensions[i]);
2771     }
2772     if (operand_shape.dimensions(i) !=
2773             output_shape.dimensions(broadcast_dimensions[i]) &&
2774         operand_shape.dimensions(i) != 1) {
2775       return InvalidArgument(
2776           "Input dimension should be either 1 or equal to the output dimension "
2777           "it is broadcasting into; the %lldth operand dimension is %lld, the "
2778           "%lldth output dimension is %lld.",
2779           i, operand_shape.dimensions(i), broadcast_dimensions[i],
2780           output_shape.dimensions(broadcast_dimensions[i]));
2781     }
2782     if (operand_shape.is_dynamic_dimension(i) !=
2783         output_shape.is_dynamic_dimension(broadcast_dimensions[i])) {
2784       return InvalidArgument(
2785           "Broadcast input and output dynamism mismatch: %s and %s",
2786           operand_shape.ToString(), output_shape.ToString());
2787     }
2788     // Make sure the broadcast dimensions are listed in a strictly increasing
2789     // order.
2790     if (i > 0 && broadcast_dimensions[i - 1] >= broadcast_dimensions[i]) {
2791       return InvalidArgument(
2792           "Broadcast dimensions order is wrong: %d comes after %d.",
2793           broadcast_dimensions[i], broadcast_dimensions.at(i - 1));
2794     }
2795   }
2796 
2797   return output_shape;
2798 }
2799 
2800 /* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
2801     const Shape& operand, absl::Span<const int64> dimensions,
2802     absl::Span<const int64> new_sizes, int64 inferred_dimension) {
2803   TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));
2804 
2805   Shape inferred_shape =
2806       ShapeUtil::MakeShape(operand.element_type(), new_sizes);
2807   VLOG(3) << "Reshape inferred shape: "
2808           << ShapeUtil::HumanString(inferred_shape);
2809 
2810   if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
2811     return InvalidArgument(
2812         "Reshape operation has mismatched element counts: from=%d (%s) "
2813         "to=%d (%s).",
2814         ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
2815         ShapeUtil::ElementsIn(inferred_shape),
2816         ShapeUtil::HumanString(inferred_shape));
2817   }
2818 
2819   std::vector<int64> indices(operand.rank());
2820   std::iota(indices.begin(), indices.end(), 0);
2821   if (dimensions.size() != operand.rank() ||
2822       !std::is_permutation(dimensions.begin(), dimensions.end(),
2823                            indices.begin())) {
2824     return InvalidArgument(
2825         "Reshape dimensions [%s] are not a permutation of the operand "
2826         "dimensions (operand shape is %s).",
2827         StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
2828   }
2829 
2830   // Propagate dynamic dimension.
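  // CommonFactors is assumed to return boundary pairs (input_dim, output_dim)
  // delimiting the smallest groups of input and output dimensions whose
  // element counts agree, e.g. dimensions [2,3,4] reshaped to [6,4] would
  // yield {(0,0), (2,1), (3,2)}.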
2831   auto common_factors = CommonFactors(operand.dimensions(), new_sizes);
2832   for (int64 input_dim = 0; input_dim < operand.rank(); ++input_dim) {
2833     if (!operand.is_dynamic_dimension(input_dim)) {
2834       continue;
2835     }
2836 
2837     string reshape_debug_str = absl::StrFormat(
2838         "output: %s, input: %s, input_dim: "
2839         "%lld",
2840         ShapeUtil::HumanString(inferred_shape), ShapeUtil::HumanString(operand),
2841         input_dim);
2842 
2843     int64 input_dim_start = -1;
2844     int64 input_dim_end = -1;
2845     int64 output_dim_start = -1;
2846     int64 output_dim_end = -1;
2847     // Find common_factors that the input_dim belongs to.
2848     for (int64 i = 0; i < common_factors.size() - 1; ++i) {
2849       auto start = common_factors[i];
2850       auto end = common_factors[i + 1];
2851       if (input_dim >= start.first && input_dim < end.first) {
2852         input_dim_start = start.first;
2853         input_dim_end = end.first;
2854         output_dim_start = start.second;
2855         output_dim_end = end.second;
2856         break;
2857       }
2858     }
2859     if ((input_dim_end - input_dim_start) > 1 &&
2860         (output_dim_end - output_dim_start) > 1) {
2861       // We don't support the case when a dynamic dimension is both combined
2862       // with and split into other dimensions:
2863       //
2864       //  [x, yz]
2865       //     | Reshape
2866       //  [xy, z]
2867       return Unimplemented(
2868           "Dynamic input dimension to reshape that is both split and "
2869           "combined is not supported: %s",
2870           reshape_debug_str);
2871     }
2872 
2873     for (auto common_factor : common_factors) {
2874       //
2875       // For reshapes like:
2876       //  [<=5]
2877       //    | Reshape
2878       //  [1, 5]
2879       //
2880       //  The input dynamic dimension can map to either output dimension.
2881       //  However, the return value of CommonFactors only includes
2882       //  input: 5
2883       //  output: 5
2884       //
2885       //  We need to expand the common factor to include degenerate output
2886       //  dimensions:
2887       //  input: 5
2888       //  output: 1, 5
2889       //
2890       //  so that the logic below can consider both output dimensions as
2891       //  candidates.
2892       if (common_factor.first == input_dim_start) {
2893         output_dim_start = std::min(output_dim_start, common_factor.second);
2894       }
2895       if (common_factor.first == input_dim_end) {
2896         output_dim_end = std::max(output_dim_end, common_factor.second);
2897       }
2898     }
2899 
2900     // Calculate output dynamic reshape dimension.
2901     int64 output_dynamic_dimension = -1;
2902 
2903     if (operand.dimensions(input_dim) == 1 && !new_sizes.empty()) {
2904       // If the dynamic dimension has size 1, it can only be the most-major
2905       // or most-minor dimension.
2906       if (input_dim == 0) {
2907         output_dynamic_dimension = 0;
2908       }
2909       if (input_dim == operand.rank() - 1) {
2910         output_dynamic_dimension = new_sizes.size() - 1;
2911       }
2912 
2913       if (output_dynamic_dimension == -1) {
2914         return Unimplemented(
2915             "Dynamic degenerate dimension that is neither most-minor nor "
2916             "most-major is not supported: %s",
2917             reshape_debug_str);
2918       }
2919     }
2920 
2921     if (output_dynamic_dimension == -1 &&
2922         output_dim_end - output_dim_start == 1) {
2923       // Only one possible output dimension.
2924       output_dynamic_dimension = output_dim_start;
2925     }
2926     if (output_dynamic_dimension == -1 &&
2927         output_dim_end - output_dim_start > 1) {
2928       // Multiple outputs can be dynamic, use "inferred_dimension" to tie-break.
2929       output_dynamic_dimension = inferred_dimension;
2930     }
2931 
2932     if (output_dynamic_dimension != -1) {
2933       // TODO(yunxing): Turn this into a CHECK.
2934       inferred_shape.set_dynamic_dimension(output_dynamic_dimension, true);
2935     } else {
2936       std::vector<int64> output_non_degenerated;
2937       for (int64 i = output_dim_start; i < output_dim_end; ++i) {
2938         if (new_sizes[i] != 1) {
2939           output_non_degenerated.push_back(i);
2940         }
2941       }
2942       if (output_non_degenerated.size() == 1) {
2943         inferred_shape.set_dynamic_dimension(output_non_degenerated[0], true);
2944       }
2945     }
2946   }
2947 
2948   return inferred_shape;
2949 }
2950 
2951 /* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
2952     const Shape& operand, absl::Span<const int64> dimensions) {
2953   TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
2954 
2955   if (!IsPermutation(dimensions, operand.rank())) {
2956     return InvalidArgument(
2957         "Transpose dimensions [%s] are not a permutation of the operand "
2958         "dimensions (operand shape is %s).",
2959         StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
2960   }
2961 
2962   // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
2963   // we need output[i]=input[dimensions[i]] which is
2964   // Permute(Inverse(dimensions),input).
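  // For instance (illustrative values), operand f32[10,20] with
  // dimensions={1,0} infers f32[20,10].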
2965   return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
2966 }
2967 
2968 /* static */ StatusOr<Shape> ShapeInference::InferClampShape(
2969     const Shape& min, const Shape& operand, const Shape& max) {
2970   TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min"));
2971   TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand"));
2972   TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max"));
2973 
2974   if (!ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
2975       !ShapeUtil::CompatibleIgnoringFpPrecision(max, operand)) {
2976     return InvalidArgument(
2977         "Clamp with different shapes: %s, %s, %s.", ShapeUtil::HumanString(min),
2978         ShapeUtil::HumanString(operand), ShapeUtil::HumanString(max));
2979   }
2980   return operand;
2981 }
2982 
2983 /* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
2984     const Shape& pred, const Shape& on_true, const Shape& on_false) {
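  // pred, on_true, and on_false must all have the same dimensions, pred must
  // have PRED element type, and the result takes the higher-precision element
  // type of the two branches, e.g. (illustrative values) pred[10] selecting
  // between f32[10] and f32[10] infers f32[10].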
2985   TF_RETURN_IF_ERROR(ExpectArray(pred, "select pred"));
2986   TF_RETURN_IF_ERROR(ExpectArray(on_true, "select on-true"));
2987   TF_RETURN_IF_ERROR(ExpectArray(on_false, "select on-false"));
2988 
2989   if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) {
2990     return InvalidArgument(
2991         "Operands to select must be the same shape; got %s and %s.",
2992         ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
2993   }
2994   if (pred.element_type() != PRED) {
2995     return InvalidArgument(
2996         "Select's pred operand must have PRED element type; got %s.",
2997         ShapeUtil::HumanString(pred));
2998   }
2999   if (!Shape::Equal()
3000            .IgnoreElementType()
3001            .IgnoreLayout()
3002            .IgnoreDynamicDimension()(pred, on_true)) {
3003     return InvalidArgument(
3004         "Operands to select and predicate must be the same shape; got %s and "
3005         "%s.",
3006         ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(pred));
3007   }
3008 
3009   return ShapeUtil::ChangeElementType(
3010       pred, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
3011 }
3012 
3013 /* static */ StatusOr<Shape> ShapeInference::InferTupleSelectShape(
3014     const Shape& pred, const Shape& on_true, const Shape& on_false) {
3015   // Select only defines the top-level buffer, so if it's a tuple, the two
3016   // inputs must match exactly.
3017   if (!ShapeUtil::Compatible(on_true, on_false)) {
3018     return InvalidArgument(
3019         "Operands to tuple-select must be the same shape; got %s and %s.",
3020         ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
3021   }
3022   if (pred.element_type() != PRED) {
3023     return InvalidArgument(
3024         "TupleSelect's pred operand must have PRED element type; got %s.",
3025         ShapeUtil::HumanString(pred));
3026   }
3027   if (!ShapeUtil::IsScalar(pred)) {
3028     return InvalidArgument(
3029         "TupleSelect operation with non-scalar predicate: %s.",
3030         ShapeUtil::HumanString(pred));
3031   }
3032   return on_true;
3033 }
3034 
3035 /* static */ StatusOr<Shape> ShapeInference::InferCallShape(
3036     absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply) {
3037   // The applied function's arity equals the number of arguments.
3038   if (arg_shapes.size() != to_apply.parameters_size()) {
3039     string computation_signature = ShapeUtil::HumanString(to_apply);
3040     string argument_shapes =
3041         StrJoin(arg_shapes, ", ", [](string* out, const Shape* shape) {
3042           absl::StrAppend(out, ShapeUtil::HumanString(*shape));
3043         });
3044     return InvalidArgument(
3045         "Call applied function arity must match number of arguments; got: "
3046         "arity: %d, arguments: %u; computation signature: %s; argument "
3047         "shapes: [%s].",
3048         to_apply.parameters_size(), arg_shapes.size(), computation_signature,
3049         argument_shapes);
3050   }
3051 
3052   // All arguments must be compatible with the program shape.
3053   for (int i = 0; i < arg_shapes.size(); ++i) {
3054     const Shape& arg_shape = *arg_shapes[i];
3055     const Shape& param_shape = to_apply.parameters(i);
3056     if (!ShapeUtil::Compatible(arg_shape, param_shape)) {
3057       return InvalidArgument(
3058           "Call parameter must match argument; got parameter %d shape: %s, "
3059           "argument shape: %s.",
3060           i, ShapeUtil::HumanString(param_shape),
3061           ShapeUtil::HumanString(arg_shape));
3062     }
3063   }
3064 
3065   return to_apply.result();
3066 }
3067 
3068 static Status ValidateGatherDimensionNumbers(
3069     const Shape& input_shape, absl::Span<const int64> start_indices_shape,
3070     const GatherDimensionNumbers& dim_numbers) {
3071   if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
3072     return InvalidArgument(
3073         "Output window dimensions in gather op must be ascending; got: %s.",
3074         StrJoin(dim_numbers.offset_dims(), ", "));
3075   }
3076 
3077   if (absl::c_adjacent_find(dim_numbers.offset_dims()) !=
3078       dim_numbers.offset_dims().end()) {
3079     return InvalidArgument(
3080         "Output window dimensions in gather op must not repeat; got: %s.",
3081         StrJoin(dim_numbers.offset_dims(), ", "));
3082   }
3083 
3084   const int64 output_offset_dim_count = dim_numbers.offset_dims_size();
3085   const int64 output_shape_rank =
3086       output_offset_dim_count + start_indices_shape.size() - 1;
3087 
3088   for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) {
3089     int64 offset_dim = dim_numbers.offset_dims(i);
3090     if (offset_dim < 0 || offset_dim >= output_shape_rank) {
3091       return InvalidArgument(
3092           "Offset dimension %d in gather op is out of bounds; got %d, but "
3093           "should "
3094           "have been in [0,%d).",
3095           i, offset_dim, output_shape_rank);
3096     }
3097   }
3098 
3099   if (dim_numbers.start_index_map_size() !=
3100       start_indices_shape[dim_numbers.index_vector_dim()]) {
3101     return InvalidArgument(
3102         "Gather op has %d elements in start_index_map and the "
3103         "bound of dimension index_vector_dim=%d of start_indices is "
3104         "%d. These two numbers must be equal.",
3105         dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(),
3106         start_indices_shape[dim_numbers.index_vector_dim()]);
3107   }
3108 
3109   for (int i = 0; i < dim_numbers.start_index_map_size(); i++) {
3110     int64 operand_dim_for_start_index_i = dim_numbers.start_index_map(i);
3111     if (operand_dim_for_start_index_i < 0 ||
3112         operand_dim_for_start_index_i >= input_shape.dimensions_size()) {
3113       return InvalidArgument(
3114           "Invalid start_index_map; domain is [0, %d), got: %d->%d.",
3115           input_shape.dimensions_size(), i, operand_dim_for_start_index_i);
3116     }
3117   }
3118 
3119   std::vector<int64> sorted_start_index_map(
3120       dim_numbers.start_index_map().begin(),
3121       dim_numbers.start_index_map().end());
3122 
3123   absl::c_sort(sorted_start_index_map);
3124 
3125   if (absl::c_adjacent_find(sorted_start_index_map) !=
3126       sorted_start_index_map.end()) {
3127     return InvalidArgument(
3128         "Repeated dimensions are not allowed in start_index_map; "
3129         "got: %s.",
3130         StrJoin(dim_numbers.start_index_map(), ", "));
3131   }
3132 
3133   for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) {
3134     if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
3135       return InvalidArgument(
3136           "Invalid collapsed_slice_dims set in gather op; valid range is [0, "
3137           "%d), got: %d.",
3138           input_shape.dimensions_size(), collapsed_dim);
3139     }
3140   }
3141 
3142   if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) {
3143     return InvalidArgument(
3144         "collapsed_slice_dims in gather op must be sorted; got: %s",
3145         StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
3146   }
3147 
3148   if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
3149       dim_numbers.collapsed_slice_dims().end()) {
3150     return InvalidArgument(
3151         "Repeated dimensions not allowed in collapsed_slice_dims in gather op; "
3152         "got: %s.",
3153         StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
3154   }
3155 
3156   return Status::OK();
3157 }
3158 
3159 /*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
3160     const Shape& input_shape, const Shape& start_indices_shape,
3161     const GatherDimensionNumbers& gather_dim_numbers,
3162     absl::Span<const int64> slice_sizes) {
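  // The inferred gather output has rank offset_dims_size plus the rank of the
  // (implicitly expanded) start_indices minus one. Offset dimensions take
  // their bounds from slice_sizes (skipping collapsed_slice_dims); the
  // remaining dimensions come from start_indices (skipping index_vector_dim).
  // For instance (illustrative values), input f32[5,7], start_indices
  // s32[4,2], index_vector_dim=1, offset_dims={1}, collapsed_slice_dims={0},
  // start_index_map={0,1}, slice_sizes={1,7} infers f32[4,7].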
3163   TF_RETURN_IF_ERROR(
3164       ExpectArray(input_shape, "input tensor operand of gather op"));
3165   TF_RETURN_IF_ERROR(
3166       ExpectArray(start_indices_shape, "gather indices operand of gather op"));
3167 
3168   if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
3169     return InvalidArgument(
3170         "Gather indices parameter must be an integral tensor; got %s.",
3171         ShapeUtil::HumanString(start_indices_shape));
3172   }
3173 
3174   // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
3175   // index_vector_dim is rank(P).  The bounds of this expanded shape are
3176   // stored in expanded_start_indices_shape.
3177 
3178   if (start_indices_shape.dimensions_size() <
3179           gather_dim_numbers.index_vector_dim() ||
3180       gather_dim_numbers.index_vector_dim() < 0) {
3181     return InvalidArgument(
3182         "Gather index leaf dimension must be within [0, rank(start_indices) + "
3183         "1). rank(start_indices) is %d and gather index leaf dimension is "
3184         "%d.",
3185         start_indices_shape.dimensions_size(),
3186         gather_dim_numbers.index_vector_dim());
3187   }
3188 
3189   std::vector<int64> expanded_start_indices_shape;
3190   // Also tracks if an output dimension is dynamic.
3191   std::vector<bool> expanded_start_indices_shape_dynamic_dimensions;
3192   expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size());
3193   expanded_start_indices_shape_dynamic_dimensions.reserve(
3194       start_indices_shape.dimensions_size());
3195   absl::c_copy(start_indices_shape.dimensions(),
3196                std::back_inserter(expanded_start_indices_shape));
3197   absl::c_copy(
3198       start_indices_shape.dynamic_dimensions(),
3199       std::back_inserter(expanded_start_indices_shape_dynamic_dimensions));
3200   if (expanded_start_indices_shape.size() ==
3201       gather_dim_numbers.index_vector_dim()) {
3202     expanded_start_indices_shape.push_back(1);
3203     expanded_start_indices_shape_dynamic_dimensions.push_back(false);
3204   }
3205 
3206   TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
3207       input_shape, expanded_start_indices_shape, gather_dim_numbers));
3208 
3209   if (slice_sizes.size() != input_shape.dimensions_size()) {
3210     return InvalidArgument(
3211         "Gather op must have one slice size for every input dimension; got: "
3212         "len(slice_sizes)=%lu, input_shape.rank=%d.",
3213         slice_sizes.size(), input_shape.dimensions_size());
3214   }
3215 
3216   if (slice_sizes.size() !=
3217       gather_dim_numbers.offset_dims_size() +
3218           gather_dim_numbers.collapsed_slice_dims_size()) {
3219     return InvalidArgument(
3220         "All components of the offset index in a gather op must either be an "
3221         "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
3222         "output_slice_sizes=%s, collapsed_slice_dims=%s.",
3223         slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","),
3224         StrJoin(gather_dim_numbers.collapsed_slice_dims(), ","));
3225   }
3226 
3227   for (int i = 0; i < slice_sizes.size(); i++) {
3228     int64 slice_size = slice_sizes[i];
3229     int64 corresponding_input_size = input_shape.dimensions(i);
3230     if (slice_size < 0 || slice_size > corresponding_input_size) {
3231       return InvalidArgument(
3232           "Slice size at index %d in gather op is out of range, must be "
3233           "within [0, %d), got %d.",
3234           i, corresponding_input_size + 1, slice_size);
3235     }
3236   }
3237 
3238   for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) {
3239     if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] > 1) {
3240       return InvalidArgument(
3241           "Gather op can only collapse slice dims with bound 1 or 0, but bound "
3242           "is %d for index %d at position %d.",
3243           slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)],
3244           gather_dim_numbers.collapsed_slice_dims(i), i);
3245     }
3246   }
3247 
3248   int64 result_rank = gather_dim_numbers.offset_dims_size() +
3249                       (expanded_start_indices_shape.size() - 1);
3250   int64 offset_dims_seen = 0;
3251   int64 gather_dims_seen = 0;
3252   std::vector<int64> output_dim_bounds;
3253   output_dim_bounds.reserve(result_rank);
3254 
3255   std::vector<bool> output_dim_is_dynamic;
3256   output_dim_is_dynamic.reserve(result_rank);
3257   for (int64 i = 0; i < result_rank; i++) {
3258     int64 current_bound;
3259     bool dim_dynamic = false;
3260     bool is_window_index =
3261         absl::c_binary_search(gather_dim_numbers.offset_dims(), i);
3262     if (is_window_index) {
3263       while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(),
3264                                    offset_dims_seen)) {
3265         offset_dims_seen++;
3266       }
3267       // Gathering an entire dynamic dimension creates a dynamic dimension.
3268       //
3269       // e.g.,:
3270       //
3271       // gather(input: [1,<=2,1], slice_sizes={1,2,1})
3272       //
3273       // creates
3274       //
3275       // [<=2, 1]
3276       if (slice_sizes[offset_dims_seen] ==
3277           input_shape.dimensions(offset_dims_seen)) {
3278         dim_dynamic = input_shape.is_dynamic_dimension(offset_dims_seen);
3279       }
3280       current_bound = slice_sizes[offset_dims_seen++];
3281     } else {
3282       if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) {
3283         gather_dims_seen++;
3284       }
3285       // Forward dynamic dimensions from indices.
3286       dim_dynamic =
3287           expanded_start_indices_shape_dynamic_dimensions[gather_dims_seen];
3288 
3289       current_bound = expanded_start_indices_shape[gather_dims_seen++];
3290     }
3291     output_dim_is_dynamic.push_back(dim_dynamic);
3292     output_dim_bounds.push_back(current_bound);
3293   }
3294 
3295   return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds,
3296                               output_dim_is_dynamic);
3297 }
3298 
3299 namespace {
3300 
3301 Status ValidateScatterDimensionNumbers(
3302     const Shape& operand_shape, absl::Span<const int64> scatter_indices_shape,
3303     const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
3304   // Validate update_window_dims in ScatterDimensionNumbers.
3305   if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {
3306     return InvalidArgument(
3307         "update_window_dims in scatter op must be sorted; got: %s.",
3308         StrJoin(dim_numbers.update_window_dims(), ", "));
3309   }
3310   if (absl::c_adjacent_find(dim_numbers.update_window_dims()) !=
3311       dim_numbers.update_window_dims().end()) {
3312     return InvalidArgument(
3313         "update_window_dims in scatter op must not repeat; got: %s.",
3314         StrJoin(dim_numbers.update_window_dims(), ", "));
3315   }
3316   const int64 updates_rank = updates_shape.rank();
3317   for (int64 window_dim : dim_numbers.update_window_dims()) {
3318     if (window_dim < 0 || window_dim >= updates_rank) {
3319       return InvalidArgument(
3320           "Invalid update_window_dims set in scatter op; valid range is [0, "
3321           "%d). got: %d.",
3322           updates_rank, window_dim);
3323     }
3324   }
3325 
3326   // Validate inserted_window_dims in ScatterDimensionNumbers.
3327   if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) {
3328     return InvalidArgument(
3329         "inserted_window_dims in scatter op must be sorted; got: %s.",
3330         StrJoin(dim_numbers.inserted_window_dims(), ", "));
3331   }
3332   if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) !=
3333       dim_numbers.inserted_window_dims().end()) {
3334     return InvalidArgument(
3335         "inserted_window_dims in scatter op must not repeat; got: %s.",
3336         StrJoin(dim_numbers.inserted_window_dims(), ", "));
3337   }
3338   for (int64 inserted_dim : dim_numbers.inserted_window_dims()) {
3339     if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) {
3340       return InvalidArgument(
3341           "Invalid inserted_window_dims set in scatter op; valid range is [0, "
3342           "%d), got: %d.",
3343           operand_shape.dimensions_size(), inserted_dim);
3344     }
3345   }
3346 
3347   // Validate window size.
3348   auto window_size = dim_numbers.update_window_dims_size() +
3349                      dim_numbers.inserted_window_dims_size();
3350   if (window_size != operand_shape.rank()) {
3351     return InvalidArgument(
3352         "Scatter op has window of size %d; doesn't match operand of rank %d.",
3353         window_size, operand_shape.rank());
3354   }
3355 
3356   // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers.
3357   if (dim_numbers.scatter_dims_to_operand_dims_size() !=
3358       scatter_indices_shape[dim_numbers.index_vector_dim()]) {
3359     return InvalidArgument(
3360         "Scatter op has %d elements in scatter_dims_to_operand_dims and the "
3361         "bound of dimension index_vector_dim=%d of scatter_indices is %d. "
3362         "These two numbers must be equal.",
3363         dim_numbers.scatter_dims_to_operand_dims_size(),
3364         dim_numbers.index_vector_dim(),
3365         scatter_indices_shape[dim_numbers.index_vector_dim()]);
3366   }
3367   for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) {
3368     int64 scatter_dim_to_operand_dim =
3369         dim_numbers.scatter_dims_to_operand_dims(i);
3370     if (scatter_dim_to_operand_dim < 0 ||
3371         scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) {
3372       return InvalidArgument(
3373           "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), "
3374           "got: %d->%d.",
3375           operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim);
3376     }
3377   }
3378   std::vector<int64> sorted_scatter_dims_to_operand_dims(
3379       dim_numbers.scatter_dims_to_operand_dims().begin(),
3380       dim_numbers.scatter_dims_to_operand_dims().end());
3381   absl::c_sort(sorted_scatter_dims_to_operand_dims);
3382   if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
3383       sorted_scatter_dims_to_operand_dims.end()) {
3384     return InvalidArgument(
3385         "Repeated dimensions not allowed in scatter_dims_to_operand_dims; "
3386         "got: %s.",
3387         StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", "));
3388   }
3389 
3390   return Status::OK();
3391 }
3392 
3393 }  // namespace
3394 
3395 /*static*/ StatusOr<Shape> ShapeInference::InferScatterShape(
3396     const Shape& operand_shape, const Shape& scatter_indices_shape,
3397     const Shape& updates_shape, const ProgramShape& to_apply_shape,
3398     const ScatterDimensionNumbers& scatter_dim_numbers) {
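  // The inferred scatter shape is simply operand_shape. The checks below
  // verify that to_apply_shape is a valid reduction over the operand element
  // type, that updates_shape has rank equal to the rank of the (implicitly
  // expanded) scatter_indices minus one plus update_window_dims_size, and
  // that the update window and scatter dimensions are bounded by the operand
  // and scatter_indices respectively.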
3399   TF_RETURN_IF_ERROR(
3400       ExpectArray(operand_shape, "operand tensor of scatter op"));
3401   TF_RETURN_IF_ERROR(
3402       ExpectArray(scatter_indices_shape, "scatter indices of scatter op"));
3403   TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op"));
3404 
3405   if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) {
3406     return InvalidArgument(
3407         "Scatter indices parameter must be an integral tensor; got %s.",
3408         ShapeUtil::HumanString(scatter_indices_shape));
3409   }
3410 
3411   if (scatter_indices_shape.dimensions_size() <
3412           scatter_dim_numbers.index_vector_dim() ||
3413       scatter_dim_numbers.index_vector_dim() < 0) {
3414     return InvalidArgument(
3415         "Scatter index leaf dimension must be within [0, rank(scatter_indices)"
3416         " + 1). rank(scatter_indices) is %d and scatter index leaf dimension "
3417         "is %d.",
3418         scatter_indices_shape.dimensions_size(),
3419         scatter_dim_numbers.index_vector_dim());
3420   }
3421 
3422   // Check if the update computation has a proper shape as a reduction.
3423   const Shape init_value_shape =
3424       ShapeUtil::MakeShape(operand_shape.element_type(), {});
3425   TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
3426                                         {updates_shape.element_type()},
3427                                         /*inputs=*/1));
3428 
3429   std::vector<int64> expanded_scatter_indices_shape =
3430       SpanToVector(scatter_indices_shape.dimensions());
3431   if (expanded_scatter_indices_shape.size() ==
3432       scatter_dim_numbers.index_vector_dim()) {
3433     expanded_scatter_indices_shape.push_back(1);
3434   }
3435 
3436   int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 +
3437                                 scatter_dim_numbers.update_window_dims_size();
3438   if (updates_shape.rank() != expected_updates_rank) {
3439     return InvalidArgument("Updates tensor must be of rank %d; got %d.",
3440                            expected_updates_rank, updates_shape.rank());
3441   }
3442 
3443   TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers(
3444       operand_shape, expanded_scatter_indices_shape, updates_shape,
3445       scatter_dim_numbers));
3446 
3447   int64 inserted_dims_seen = 0;
3448   std::vector<int64> max_update_slice_sizes;
3449   for (int i = 0; i < operand_shape.dimensions_size(); ++i) {
3450     if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() &&
3451         scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) {
3452       ++inserted_dims_seen;
3453     } else {
3454       max_update_slice_sizes.push_back(operand_shape.dimensions(i));
3455     }
3456   }
3457   for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) {
3458     auto update_window_dim = scatter_dim_numbers.update_window_dims(i);
3459     if (updates_shape.dimensions(update_window_dim) >
3460         max_update_slice_sizes[i]) {
3461       return InvalidArgument(
3462           "Bounds of the window dimensions of updates must not exceed the "
3463           "bounds of the corresponding dimensions of operand. For dimension "
3464           "%d, updates bound is %d, operand bound is %d.",
3465           update_window_dim, updates_shape.dimensions(update_window_dim),
3466           max_update_slice_sizes[i]);
3467     }
3468   }
3469 
3470   int64 scatter_dims_seen = 0;
3471   for (int64 i = 0; i < updates_shape.rank(); ++i) {
3472     bool is_update_window_dim =
3473         absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i);
3474     if (is_update_window_dim) {
3475       continue;
3476     }
3477     if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) {
3478       ++scatter_dims_seen;
3479     }
3480     if (updates_shape.dimensions(i) !=
3481         expanded_scatter_indices_shape[scatter_dims_seen]) {
3482       return InvalidArgument(
3483           "Bounds of the scatter dimensions of updates must be the same as the "
3484           "bounds of the corresponding dimensions of scatter indices. For "
3485           "scatter dimension %d, updates bound is %d, scatter_indices "
3486           "bound is %d.",
3487           i, updates_shape.dimensions(i),
3488           expanded_scatter_indices_shape[scatter_dims_seen]);
3489     }
3490     ++scatter_dims_seen;
3491   }
3492 
3493   return operand_shape;
3494 }
3495 
3496 }  // namespace xla
3497