1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <numeric>
21 #include <set>
22 #include <string>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 #include "absl/strings/str_join.h"
29 #include "absl/strings/string_view.h"
30 #include "tensorflow/compiler/xla/permutation_util.h"
31 #include "tensorflow/compiler/xla/primitive_util.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/compiler/xla/window_util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/math/math_util.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/protobuf.h"
42 #include "tensorflow/core/platform/statusor.h"
43 
44 namespace xla {
45 namespace {
46 
47 using absl::StrFormat;
48 using absl::StrJoin;
49 
50 // Returns true if no element is present in slice more than once.
51 bool AllUnique(absl::Span<const int64> slice) {
52   return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
53 }
54 
55 Status ExpectArray(const Shape& shape, absl::string_view op_type) {
56   if (!shape.IsArray()) {
57     return InvalidArgument("Expected array argument for %s, but got %s.",
58                            string(op_type), ShapeUtil::HumanString(shape));
59   }
60   return Status::OK();
61 }
62 
63 Status VerifyReducerShape(const ProgramShape& reducer_shape,
64                           absl::Span<const Shape* const> init_value_shapes,
65                           absl::Span<const PrimitiveType> input_element_types,
66                           int64_t inputs) {
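  // For example, a variadic reduce over an f32 array and an s32 array
  // (inputs == 2) requires a reducer of the form
  //   (f32[], s32[], f32[], s32[]) -> (f32[], s32[]),
  // i.e. the scalar accumulators come first, then the scalar input elements,
  // and the result is a tuple of the accumulator types.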
67   if (reducer_shape.parameters_size() != inputs * 2) {
68     return InvalidArgument(
69         "Reduction function must take %d parameters, but "
70         "takes %d parameter(s).",
71         inputs * 2, reducer_shape.parameters_size());
72   }
73 
74   const Shape& accumulator_shape = reducer_shape.result();
75   std::vector<const Shape*> accumulator_subshapes;
76   if (accumulator_shape.IsArray()) {
77     if (inputs != 1) {
78       return InvalidArgument(
79           "Reduction function must produce a tuple with %d elements, but "
80           "produces a scalar",
81           inputs);
82     }
83     accumulator_subshapes.push_back(&accumulator_shape);
84   } else if (accumulator_shape.IsTuple()) {
85     if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) {
86       return InvalidArgument(
87           "Reduction function must produce a tuple with %d elements, but has "
88           "%d elements",
89           inputs, ShapeUtil::TupleElementCount(accumulator_shape));
90     }
91     for (const Shape& element_shape : accumulator_shape.tuple_shapes()) {
92       accumulator_subshapes.push_back(&element_shape);
93     }
94   } else {
95     return InvalidArgument(
96         "Reduction function must produce a scalar or tuple of scalars, but has "
97         "shape: %s",
98         ShapeUtil::HumanString(accumulator_shape));
99   }
100 
101   for (const Shape* element_shape : accumulator_subshapes) {
102     if (element_shape->rank() != 0) {
103       return InvalidArgument(
104           "Reduction function must return a scalar or tuple of scalars but "
105           "returns shape: %s",
106           ShapeUtil::HumanString(accumulator_shape));
107     }
108   }
109 
110   for (int64_t i = 0; i < inputs; ++i) {
111     // Check that the accumulator can be passed in as the first argument.
112     // Note: comparing here and below with Compatible since we don't care about
113     // layout in scalars - see b/26668201 for a longer-term vision.
114     if (!ShapeUtil::Compatible(*accumulator_subshapes[i],
115                                reducer_shape.parameters(i))) {
116       return InvalidArgument(
117           "Reduction function's %d-th parameter shape differs from the "
118           "result shape: %s vs %s",
119           i, ShapeUtil::HumanString(reducer_shape.parameters(i)),
120           ShapeUtil::HumanString(*accumulator_subshapes[i]));
121     }
122     // Check that init_value's shapes are suitable for reducer_shape.
123     if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i],
124                                                   *init_value_shapes[i])) {
125       return InvalidArgument(
126           "Reduction function's accumulator shape at index %d differs from "
127           "the init_value shape: %s vs %s",
128           i, ShapeUtil::HumanString(*accumulator_subshapes[i]),
129           ShapeUtil::HumanString(*init_value_shapes[i]));
130     }
131     // Check that the inputs can be passed in as the non-accumulator arguments.
132     const Shape input_element_shape =
133         ShapeUtil::MakeShape(input_element_types[i], {});
134     if (!ShapeUtil::CompatibleIgnoringFpPrecision(
135             input_element_shape, reducer_shape.parameters(inputs + i))) {
136       return InvalidArgument(
137           "Reduction function's %d-th parameter shape differs from the "
138           "corresponding input element type: %s vs %s",
139           inputs + i,
140           ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
141           ShapeUtil::HumanString(input_element_shape));
142     }
143     // Check that the accumulator and inputs to the reducer function match.
144     // If the accumulator is scalar, it must have the same type as the inputs
145     // (up to fp precision). If it is a tuple, then the k-th element of the
146     // tuple must have the same type as the k-th input (again, up to fp
147     // precision).
148     if (!ShapeUtil::CompatibleIgnoringFpPrecision(
149             *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) {
150       return InvalidArgument(
151           "Reduction function's %d-th parameter shape must "
152           "match the result shape, but got %s vs %s.",
153           inputs + i,
154           ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
155           ShapeUtil::HumanString(*accumulator_subshapes[i]));
156     }
157   }
158 
159   return Status::OK();
160 }
161 
162 StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
163                                        const Window& window,
164                                        PrimitiveType element_type,
165                                        bool allow_negative_padding) {
166   if (window.dimensions_size() != base_shape.rank()) {
167     return InvalidArgument(
168         "Window has dimension %d but base shape has dimension %d.",
169         window.dimensions_size(), base_shape.rank());
170   }
171 
172   std::vector<int64> output_dimensions(window.dimensions_size());
173   std::vector<bool> output_is_dynamic(window.dimensions_size());
174   for (int64_t i = 0; i < window.dimensions_size(); ++i) {
175     const auto& dim = window.dimensions(i);
176     if (dim.size() <= 0) {
177       return InvalidArgument("Window %s has a non-positive dimension.",
178                              window.DebugString());
179     }
180     if (dim.stride() <= 0) {
181       return InvalidArgument("Window %s has a non-positive stride.",
182                              window.DebugString());
183     }
184     if (!allow_negative_padding && dim.padding_low() < 0) {
185       return InvalidArgument("Window %s has a negative low padding.",
186                              window.DebugString());
187     }
188     if (!allow_negative_padding && dim.padding_high() < 0) {
189       return InvalidArgument("Window %s has a negative high padding.",
190                              window.DebugString());
191     }
192     if (dim.base_dilation() < 1) {
193       return InvalidArgument(
194           "Window %s has a non-positive base area dilation factor.",
195           window.DebugString());
196     }
197     if (dim.window_dilation() < 1) {
198       return InvalidArgument(
199           "Window %s has a non-positive window dilation factor.",
200           window.DebugString());
201     }
202 
203     const int64_t dilated_base = window_util::DilatedBound(
204         ShapeUtil::GetDimension(base_shape, i), dim.base_dilation());
205     const int64_t padded_dilated_base =
206         dim.padding_low() + dilated_base + dim.padding_high();
207     const int64_t dilated_window =
208         window_util::DilatedBound(dim.size(), dim.window_dilation());
209 
210     output_dimensions[i] = window_util::StridedBound(
211         padded_dilated_base, dilated_window, dim.stride());
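    // For example, a base dimension of size 10 with window size 3, stride 2,
    // and no padding or dilation gives dilated_base = 10,
    // padded_dilated_base = 10, dilated_window = 3, and an output bound of
    // (10 - 3) / 2 + 1 = 4.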
212     output_is_dynamic[i] = base_shape.is_dynamic_dimension(i);
213   }
214 
215   return ShapeUtil::MakeValidatedShape(element_type, output_dimensions,
216                                        output_is_dynamic);
217 }
218 
219 StatusOr<PrimitiveType> MaybeUpcast(
220     PrimitiveType from_type,
221     absl::optional<PrimitiveType> preferred_element_type) {
222   if (!preferred_element_type.has_value() ||
223       *preferred_element_type == from_type) {
224     return from_type;
225   }
226   if (!primitive_util::IsFloatingPointType(from_type) &&
227       primitive_util::BitWidth(*preferred_element_type) <
228           primitive_util::BitWidth(from_type)) {
229     return InvalidArgument(
230         "`preferred_element_type` must not be narrower than the original "
231         "type.");
232   }
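  // For example, requesting S8 for an S32 operand is rejected above (integer
  // narrowing), while requesting BF16 for an F32 operand is allowed, since
  // floating-point operands may be narrowed.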
233   return *preferred_element_type;
234 }
235 
236 }  // namespace
237 
238 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
239     HloOpcode opcode, const HloInstruction* operand) {
240   return InferUnaryOpShape(opcode, operand->shape());
241 }
242 
243 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
244     HloOpcode opcode, const Shape& shape) {
245   // There is no copy operation at the proto level, so handle copy explicitly.
246   // A domain shape is the same as the input one.
247   if (opcode == HloOpcode::kCopy || opcode == HloOpcode::kDomain) {
248     return shape;
249   }
250 
251   TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation"));
252 
253   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
254   switch (opcode) {
255     case HloOpcode::kFloor:
256     case HloOpcode::kCeil:
257     case HloOpcode::kRoundNearestAfz:
258       if (!ShapeUtil::ElementIsFloating(shape)) {
259         return InvalidArgument(
260             "Expected element type in shape to be floating for %s operation; "
261             "got %s.",
262             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
263       }
264       return shape;
265     case HloOpcode::kCos:
266     case HloOpcode::kSin:
267     case HloOpcode::kExp:
268     case HloOpcode::kExpm1:
269     case HloOpcode::kLog:
270     case HloOpcode::kLog1p:
271     case HloOpcode::kLogistic:
272     case HloOpcode::kRsqrt:
273     case HloOpcode::kSqrt:
274     case HloOpcode::kCbrt:
275     case HloOpcode::kTanh:
276       if (!ShapeUtil::ElementIsFloating(shape) &&
277           !ShapeUtil::ElementIsComplex(shape)) {
278         return InvalidArgument(
279             "Expected element type in shape to be floating or complex for %s "
280             "operation; got %s.",
281             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
282       }
283       return shape;
284     case HloOpcode::kReal:
285     case HloOpcode::kImag:
286       if (ShapeUtil::ElementIsComplex(shape)) {
287         return ShapeUtil::ComplexComponentShape(shape);
288       } else if (ShapeUtil::ElementIsFloating(shape)) {
289         return shape;
290       } else {
291         return InvalidArgument(
292             "Expected element type in shape to be floating or complex for "
293             "%s operation; got %s.",
294             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
295       }
296     case HloOpcode::kAbs:
297       if (ShapeUtil::ElementIsComplex(shape)) {
298         return ShapeUtil::ChangeElementType(
299             shape, primitive_util::ComplexComponentType(shape.element_type()));
300       } else if (ShapeUtil::ElementIsSigned(shape)) {
301         return shape;
302       } else {
303         return InvalidArgument(
304             "Expected element type in shape to be floating or complex for "
305             "%s operation; got %s.",
306             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
307       }
308     case HloOpcode::kClz:
309       if (!ShapeUtil::ElementIsIntegral(shape)) {
310         return InvalidArgument(
311             "Expected an integral element type in argument to Clz "
312             "operation; got %s.",
313             PrimitiveType_Name(shape.element_type()));
314       }
315       return shape;
316     case HloOpcode::kNegate:
317       if (!ShapeUtil::ElementIsIntegral(shape) &&
318           !ShapeUtil::ElementIsFloating(shape) &&
319           !ShapeUtil::ElementIsComplex(shape)) {
320         return InvalidArgument(
321             "Expected element type in shape to be integral, floating or "
322             "complex for %s operation; got %s.",
323             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
324       }
325       return shape;
326     case HloOpcode::kPopulationCount:
327       if (!ShapeUtil::ElementIsIntegral(shape)) {
328         return InvalidArgument(
329             "Expected an integral element type in argument to PopulationCount "
330             "operation; got %s.",
331             PrimitiveType_Name(shape.element_type()));
332       }
333       return shape;
334     case HloOpcode::kSign:
335       if (!ShapeUtil::ElementIsSigned(shape) &&
336           !ShapeUtil::ElementIsComplex(shape)) {
337         return InvalidArgument(
338             "Expected element type in shape to be signed or complex for "
339             "%s operation; got %s.",
340             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
341       }
342       return shape;
343 
344     case HloOpcode::kNot:
345       if (shape.element_type() != PRED &&
346           !primitive_util::IsIntegralType(shape.element_type())) {
347         return InvalidArgument(
348             "Expected pred or an integral element type in argument to Not "
349             "operation; got %s.",
350             PrimitiveType_Name(shape.element_type()));
351       }
352       return shape;
353 
354     case HloOpcode::kIsFinite:
355       if (!ShapeUtil::ElementIsFloating(shape)) {
356         return InvalidArgument(
357             "Expected element type in shape to be floating "
358             "point for IsFinite "
359             "operation; got %s.",
360             PrimitiveType_Name(shape.element_type()));
361       }
362       return ShapeUtil::ChangeElementType(shape, PRED);
363 
364     default:
365       return InvalidArgument(
366           "Unknown operation for unary shape inference: \"%s\".",
367           HloOpcodeString(opcode));
368   }
369 }
370 
371 /* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
372     absl::Span<const Shape* const> arg_shapes, const int64_t dimension) {
373   if (arg_shapes.empty()) {
374     return InvalidArgument("Concatenate expects at least one argument.");
375   }
376   if (dimension < 0 || dimension >= arg_shapes[0]->rank()) {
377     return InvalidArgument("Concatenate dimension out of bounds: %d.",
378                            dimension);
379   }
380   const Shape* arg_shape = nullptr;
381   PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
382   for (const Shape* shape : arg_shapes) {
383     TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation"));
384     if (!arg_shape) {
385       arg_shape = shape;
386       element_type = arg_shape->element_type();
387       continue;
388     }
389     if (arg_shape->rank() != shape->rank()) {
390       return InvalidArgument(
391           "Cannot concatenate arrays with different ranks: %d (%s) vs %d "
392           "(%s).",
393           arg_shape->rank(), ShapeUtil::HumanString(*arg_shape), shape->rank(),
394           ShapeUtil::HumanString(*shape));
395     }
396     if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
397       return InvalidArgument(
398           "Cannot concatenate arrays with different element types: %s vs %s.",
399           PrimitiveType_Name(arg_shape->element_type()),
400           PrimitiveType_Name(shape->element_type()));
401     }
402     for (int64_t dimension_number = 0; dimension_number < arg_shape->rank();
403          ++dimension_number) {
404       if (arg_shape->dimensions(dimension_number) !=
405           shape->dimensions(dimension_number)) {
406         if (dimension_number == dimension) {
407           continue;  // It's okay to differ in the dimension we're
408                      // concatenating.
409         }
410         return InvalidArgument(
411             "Cannot concatenate arrays that differ in dimensions other than "
412             "the one being concatenated (the other array dimensions must be "
413             "the same): %s vs %s in dimension %d.",
414             ShapeUtil::HumanString(*arg_shape), ShapeUtil::HumanString(*shape),
415             dimension);
416       }
417     }
418     element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
419   }
420 
421   std::vector<int64> new_dimensions(arg_shape->dimensions().begin(),
422                                     arg_shape->dimensions().end());
423   for (size_t i = 1; i < arg_shapes.size(); ++i) {
424     new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
425   }
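  // For example, concatenating f32[2, 3] and f32[2, 5] along dimension 1
  // yields new_dimensions = {2, 8}.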
426 
427   Shape result = ShapeUtil::MakeShape(element_type, new_dimensions);
428 
429   // Set dynamic dimensions if any input has dynamic dimension.
430   for (const Shape* shape : arg_shapes) {
431     for (int64_t i = 0; i < shape->dimensions_size(); ++i) {
432       if (shape->is_dynamic_dimension(i)) {
433         result.set_dynamic_dimension(i, true);
434       }
435     }
436   }
437   return result;
438 }
439 
440 /* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
441     const Shape& operand_shape, PrimitiveType new_element_type) {
442   auto old_element_type = operand_shape.element_type();
443   if (primitive_util::IsComplexType(old_element_type) &&
444       !primitive_util::IsComplexType(new_element_type)) {
445     return Unimplemented(
446         "Conversion from complex to real type %s => %s is not implemented.",
447         ShapeUtil::HumanString(operand_shape),
448         PrimitiveType_Name(new_element_type));
449   }
450   if (!operand_shape.IsArray() ||
451       !primitive_util::IsArrayType(new_element_type)) {
452     // Note: we may want to support tuple conversions via this operation in the
453     // future, by recursing into the tuple elements to check all sub-conversions
454     // are valid. For now we just reject them, though.
455     return InvalidArgument(
456         "Convert does not allow non-arrays, so cannot convert from %s to %s.",
457         ShapeUtil::HumanString(operand_shape),
458         PrimitiveType_Name(new_element_type));
459   }
460 
461   return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
462 }
463 
464 /* static */ StatusOr<Shape> ShapeInference::InferBitcastConvertShape(
465     const Shape& operand_shape, PrimitiveType new_element_type) {
466   auto old_element_type = operand_shape.element_type();
467   if (primitive_util::IsComplexType(old_element_type) !=
468       primitive_util::IsComplexType(new_element_type)) {
469     return InvalidArgument("Conversion between real and complex types %s => %s is not supported.",
470                            ShapeUtil::HumanString(operand_shape),
471                            PrimitiveType_Name(new_element_type));
472   }
473   if (!operand_shape.IsArray() ||
474       !primitive_util::IsArrayType(new_element_type)) {
475     // Note: we may want to support tuple conversions via this operation in the
476     // future, by recursing into the tuple elements to check all sub-conversions
477     // are valid. For now we just reject them, though.
478     return InvalidArgument(
479         "Cannot convert from or to tuple type; requested conversion: %s => %s.",
480         ShapeUtil::HumanString(operand_shape),
481         PrimitiveType_Name(new_element_type));
482   }
483   if (primitive_util::BitWidth(old_element_type) !=
484       primitive_util::BitWidth(new_element_type)) {
485     return InvalidArgument(
486         "Cannot bitcast types with different bit-widths: %s => %s.",
487         PrimitiveType_Name(old_element_type),
488         PrimitiveType_Name(new_element_type));
489   }
490 
491   return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
492 }
493 
494 /* static */ StatusOr<Shape> ShapeInference::InferReducePrecisionShape(
495     const Shape& operand_shape, const int exponent_bits,
496     const int mantissa_bits) {
497   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
498     return InvalidArgument(
499         "Expected element type in shape to be floating point for "
500         "ReducePrecision operation; got %s.",
501         PrimitiveType_Name(operand_shape.element_type()));
502   }
503   if (exponent_bits < 1) {
504     // One exponent bit is necessary to distinguish 0 from infinity.  Having
505     // no exponent bits doesn't produce a sensible number, so we require at
506     // least one.
507     return InvalidArgument("Expected exponent_bits >= 1; got %d.",
508                            exponent_bits);
509   }
510   if (mantissa_bits < 0) {
511     // A number with no mantissa bits is still meaningful, however.
512     return InvalidArgument("Expected non-negative mantissa_bits; got %d.",
513                            mantissa_bits);
514   }
515   return operand_shape;
516 }
517 
518 /* static */ StatusOr<Shape> ShapeInference::InferPadShape(
519     const Shape& operand_shape, const Shape& padding_value_shape,
520     const PaddingConfig& padding_config) {
521   if (!operand_shape.IsArray()) {
522     return InvalidArgument(
523         "Pad operation does not support tuple-shape operands.");
524   }
525   if (!ShapeUtil::IsScalar(padding_value_shape)) {
526     return InvalidArgument(
527         "Pad operation does not support non-scalar padding values.");
528   }
529   if (operand_shape.rank() != padding_config.dimensions_size()) {
530     return InvalidArgument(
531         "The rank of the operand and the padding configuration do not match: "
532         "%s vs %s.",
533         ShapeUtil::HumanString(operand_shape),
534         padding_config.ShortDebugString());
535   }
536   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
537                                                      padding_value_shape)) {
538     return InvalidArgument(
539         "The element types of the operands to Pad do not match.");
540   }
541   if (absl::c_any_of(padding_config.dimensions(),
542                      [](const PaddingConfig::PaddingConfigDimension& p) {
543                        return p.interior_padding() < 0;
544                      })) {
545     return InvalidArgument("Interior padding cannot be negative: %s",
546                            padding_config.ShortDebugString());
547   }
548 
549   if (!padding_value_shape.is_static()) {
550     return InvalidArgument("Dynamic padding value is not supported");
551   }
552 
553   std::vector<int64> dimensions(operand_shape.rank());
554   std::vector<bool> is_dynamic(operand_shape.rank());
555   for (int64_t i = 0; i < operand_shape.dimensions_size(); ++i) {
556     const auto& p = padding_config.dimensions(i);
557     dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() +
558                     p.edge_padding_high() +
559                     std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
560                         p.interior_padding();
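    // For example, a dimension of size 3 with edge_padding_low = 1,
    // edge_padding_high = 2, and interior_padding = 1 becomes
    // 3 + 1 + 2 + (3 - 1) * 1 = 8.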
561     if (dimensions[i] < 0) {
562       return InvalidArgument("Padding results in negative size for dimension %d",
563                              i);
564     }
565     is_dynamic[i] = operand_shape.is_dynamic_dimension(i);
566   }
567 
568   return ShapeUtil::MakeShape(
569       ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
570       dimensions, is_dynamic);
571 }
572 
573 // Current DotDimensionNumbers Requirements:
574 //
575 // Contracting Dimensions:
576 // *) Same number of contracting dimensions on both lhs and rhs.
577 // *) Contracting dimension size must be the same on both lhs and rhs.
578 //
579 // Batch Dimensions:
580 // *) Same number of batch dimensions on both lhs and rhs.
581 // *) Same batch dimension sizes on both lhs and rhs.
582 //
583 
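// For example, a batched matmul with lhs = f32[2, 3, 4] and rhs = f32[2, 4, 5]
// uses lhs_batch_dimensions = {0}, rhs_batch_dimensions = {0},
// lhs_contracting_dimensions = {2}, and rhs_contracting_dimensions = {1};
// InferDotOpShape below then produces f32[2, 3, 5].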
584 namespace {
585 
586 Status ValidateDotDimensionNumbers(
587     const Shape& lhs, const Shape& rhs,
588     const DotDimensionNumbers& dimension_numbers) {
589   // Check that dimension numbers are in range.
590   auto dims_in_range = [](const int64_t rank,
591                           absl::Span<const int64> contracting_dims,
592                           absl::Span<const int64> batch_dims) -> bool {
593     auto in_range = [&rank](int64_t i) -> bool { return 0 <= i && i < rank; };
594     return absl::c_all_of(contracting_dims, in_range) &&
595            absl::c_all_of(batch_dims, in_range);
596   };
597 
598   absl::Span<const int64> lhs_contracting_dimensions =
599       AsInt64Slice(dimension_numbers.lhs_contracting_dimensions());
600   absl::Span<const int64> rhs_contracting_dimensions =
601       AsInt64Slice(dimension_numbers.rhs_contracting_dimensions());
602   absl::Span<const int64> lhs_batch_dimensions =
603       AsInt64Slice(dimension_numbers.lhs_batch_dimensions());
604   absl::Span<const int64> rhs_batch_dimensions =
605       AsInt64Slice(dimension_numbers.rhs_batch_dimensions());
606 
607   if (!dims_in_range(lhs.rank(), lhs_contracting_dimensions,
608                      lhs_batch_dimensions) ||
609       !dims_in_range(rhs.rank(), rhs_contracting_dimensions,
610                      rhs_batch_dimensions)) {
611     return InvalidArgument("A dimension number is out of range in Dot: %s.",
612                            dimension_numbers.DebugString());
613   }
614 
615   // Check that dimension numbers are unique.
616   auto dims_unique = [](absl::Span<const int64> contracting_dims,
617                         absl::Span<const int64> batch_dims) -> bool {
618     absl::flat_hash_set<int64> dim_set;
619     auto is_unique = [&dim_set](int64_t i) -> bool {
620       return dim_set.insert(i).second;
621     };
622     return absl::c_all_of(contracting_dims, is_unique) &&
623            absl::c_all_of(batch_dims, is_unique);
624   };
625 
626   if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
627       !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) {
628     return InvalidArgument("A dimension number is not unique in Dot: %s.",
629                            dimension_numbers.DebugString());
630   }
631 
632   return Status::OK();
633 }
634 
635 }  // namespace
636 
637 /* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
638     const Shape& lhs, const Shape& rhs,
639     const DotDimensionNumbers& dimension_numbers,
640     absl::optional<PrimitiveType> preferred_element_type) {
641   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot"));
642   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
643 
644   auto fail = [lhs, rhs](const string& addendum) -> Status {
645     string message =
646         StrFormat("Cannot infer shape for dot operation: %s <dot> %s.",
647                   ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
648     if (!addendum.empty()) {
649       message += " " + addendum;
650     }
651     return InvalidArgument("%s", message);
652   };
653 
654   // Validate basic properties of dot dimension numbers.
655   TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers));
656 
657   // Check that number of contracting dimensions match.
658   if (dimension_numbers.lhs_contracting_dimensions_size() !=
659       dimension_numbers.rhs_contracting_dimensions_size()) {
660     return fail(
661         "Must specify the same number of contracting dimensions for lhs and "
662         "rhs.");
663   }
664   // Check that contracting dimension sizes match.
665   for (int64_t i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
666        ++i) {
667     const int64_t lhs_contracting_dimension =
668         dimension_numbers.lhs_contracting_dimensions(i);
669     const int64_t rhs_contracting_dimension =
670         dimension_numbers.rhs_contracting_dimensions(i);
671     if (lhs.dimensions(lhs_contracting_dimension) !=
672         rhs.dimensions(rhs_contracting_dimension)) {
673       return fail("Contracting dimension sizes do not match.");
674     }
675   }
676 
677   // Check that number of batch dimensions match.
678   if (dimension_numbers.lhs_batch_dimensions_size() !=
679       dimension_numbers.rhs_batch_dimensions_size()) {
680     return fail("Must specify the same number of batch dimensions for lhs and rhs.");
681   }
682 
683   // Check that batch dimension numbers and sizes match.
684   for (int64_t i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
685     if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
686         rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
687       return fail("Batch dimension sizes must match for lhs/rhs.");
688     }
689   }
690 
691   // The ranks of lhs and rhs are decremented by 1 respectively due to the
692   // contraction, and added for the rank of the result. When an input tensor is
693   // a scalar, its contribution to the rank of the result is 0.
694   // Generate the result dimensions in order: batch dimensions, then the lhs
695   // dimensions, then the rhs dimensions, excluding contracted and batch dims.
696   std::vector<int64> dimensions;
697   std::vector<bool> is_dynamic;
698   for (int64_t lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
699     dimensions.push_back(lhs.dimensions(lhs_dim));
700     is_dynamic.push_back(lhs.is_dynamic_dimension(lhs_dim));
701   }
702   for (int64_t i = 0; i < lhs.rank(); i++) {
703     if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
704                                i) &&
705         !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
706       dimensions.push_back(lhs.dimensions(i));
707       is_dynamic.push_back(lhs.is_dynamic_dimension(i));
708     }
709   }
710   for (int64_t i = 0; i < rhs.rank(); i++) {
711     if (!absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
712                                i) &&
713         !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
714       dimensions.push_back(rhs.dimensions(i));
715       is_dynamic.push_back(rhs.is_dynamic_dimension(i));
716     }
717   }
718   TF_ASSIGN_OR_RETURN(
719       PrimitiveType type,
720       MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
721                   preferred_element_type));
722   Shape result = ShapeUtil::MakeShape(type, dimensions, is_dynamic);
723 
724   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
725   VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
726   return result;
727 }
728 
729 /* static */ StatusOr<Shape>
730 ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
731                                                        const Shape& lhs,
732                                                        const Shape& rhs) {
733   TF_RET_CHECK(lhs.rank() == rhs.rank());
734 
735   // The shapes have to be compatible. That is, if some dimension d has a
736   // different size in the two shapes, one of them has to be 1 (a "degenerate"
737   // dimension). In that case, the output shape has the non-1 dimension size
738   // from the lhs/rhs pair in every index.
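  // For example, lhs = f32[2, 1, 4] and rhs = f32[2, 3, 1] produce an output
  // of f32[2, 3, 4].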
739   std::vector<int64> output_dimensions(lhs.rank());
740   std::vector<bool> output_dimensions_is_dynamic(lhs.rank());
741   for (int64_t i = 0; i < lhs.rank(); ++i) {
742     if (lhs.dimensions(i) == rhs.dimensions(i)) {
743       output_dimensions[i] = lhs.dimensions(i);
744     } else if (lhs.dimensions(i) == 1) {
745       output_dimensions[i] = rhs.dimensions(i);
746     } else if (rhs.dimensions(i) == 1) {
747       output_dimensions[i] = lhs.dimensions(i);
748     } else {
749       return InvalidArgument(
750           "Binary op %s with incompatible shapes: %s and %s.",
751           HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
752           ShapeUtil::HumanString(rhs));
753     }
754   }
755 
756   // Merge dynamic dimensions from two shapes.
757   for (int64_t i = 0; i < rhs.rank(); ++i) {
758     if (rhs.is_dynamic_dimension(i) || lhs.is_dynamic_dimension(i)) {
759       output_dimensions_is_dynamic[i] = true;
760     }
761   }
762 
763   return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
764                               output_dimensions, output_dimensions_is_dynamic);
765 }
766 
767 /* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
768     const Shape& smaller_shape, const Shape& larger_shape,
769     absl::Span<const int64> broadcast_dimensions) {
770   if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
771     // Reject "magic" inference for binops on different shapes, requiring
772     // the user to provide an explicit broadcast dimension in this case.
773     // See b/25177275 for more details.
774     return InvalidArgument("Automatic shape inference not supported: %s and %s",
775                            ShapeUtil::HumanString(smaller_shape),
776                            ShapeUtil::HumanString(larger_shape));
777   } else if (broadcast_dimensions.size() != smaller_shape.rank()) {
778     return InvalidArgument(
779         "Size of broadcast_dimensions has to match lower-rank operand's "
780         "rank; "
781         " lower-rank operand's rank is %d, size of broadcast_dimensions is "
782         "%u.",
783         smaller_shape.rank(), broadcast_dimensions.size());
784   }
785 
786   // broadcast_dimensions is a sequence of dimensions; its length is equal to
787   // the rank of the lower-rank operand. The lower-rank operand's dimensions
788   // have to be compatible with the higher-rank operand's dimensions at indices
789   // specified by broadcast_dimensions. Here compatible means the dimension
790   // sizes are equal or in one of the shapes the dimension size is
791   // one. Examples:
792   //
793   // smaller_shape   larger_shape   broadcast_dimensions   output_shape
794   //   []              [2, 3]          {}                    [2, 3]
795   //   [3]             [4, 3]          {1}                   [4, 3]
796   //   [2, 3]          [2, 3, 4]       {0, 1}                [2, 3, 4]
797   //   [2, 1]          [2, 3, 4]       {0, 2}                [2, 3, 1]
798   //   [2, 3]          [2, 1, 4]       {0, 1}                [2, 3, 4]
799   //
800   // The column output_shape may not be the final shape of the XLA
801   // operation. After the "InDim" broadcasting implemented in this function
802   // expands the rank, degenerate-dimension broadcasting (implemented in
803   // InferDegenerateDimensionBroadcastShape) broadcasts dimensions of size one
804   // up to match the dimension size of the other operand. For example, consider
805   // the row in the table above with a smaller_shape of [2, 1]. The shape
806   // returned by this function is [2, 3, 1] (output_shape) however, the result
807   // shape of the XLA operation is [2, 3, 4] after degenerate-dimension
808   // broadcasting.
809   //
810   // Invalid broadcasts:
811   //
812   // smaller_shape=[3], larger_shape=[4, 3], broadcast_dimensions={0}
813   // Reason: Dimension zero** of larger_shape (size 4) is not compatible with
814   //   dimension zero of smaller_shape(size 3). **Zero here comes from the value
815   //   in broadcast_dimensions.
816   //
817   // smaller_shape=[2, 1], larger_shape=[2, 3, 4], broadcast_dimensions={1, 2}
818   // Reason: Dimension one of larger_shape (size 3) is not compatible with
819   //   dimension zero of smaller_shape(size 2)
820 
821   // The output shape is initially the larger_shape. Sizes of dimensions
822   // specified in broadcast_dimensions are then changed to match the
823   // corresponding dimension size in smaller_shape.
824   Shape output_shape(larger_shape);
825   output_shape.set_element_type(
826       ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape));
827 
828   for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
829     int64_t dimension_to_match = broadcast_dimensions.at(i);
830     if (dimension_to_match < 0) {
831       return InvalidArgument(
832           "Broadcast dimension number (%d) cannot be negative.",
833           dimension_to_match);
834     }
835     if (dimension_to_match >= larger_shape.dimensions_size()) {
836       return InvalidArgument(
837           "Broadcast dimension number (%d) too large; higher-rank "
838           "operand has rank %d.",
839           dimension_to_match, larger_shape.dimensions_size());
840     }
841     int64_t small_dimension_size = smaller_shape.dimensions(i);
842     int64_t large_dimension_size = larger_shape.dimensions(dimension_to_match);
843     bool small_is_dynamic = smaller_shape.is_dynamic_dimension(i);
844     bool large_is_dynamic =
845         larger_shape.is_dynamic_dimension(dimension_to_match);
846     // Dimension sizes must be compatible: match or be degenerate (degenerate
847     // case is handled by degenerate dimension broadcasting which occurs after
848     // InDim broadcasting).
849     if (small_dimension_size != large_dimension_size &&
850         small_dimension_size != 1 && large_dimension_size != 1) {
851       return InvalidArgument(
852           "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i,
853           small_dimension_size, large_dimension_size,
854           ShapeUtil::HumanString(smaller_shape),
855           ShapeUtil::HumanString(larger_shape));
856     }
857     if (small_is_dynamic != large_is_dynamic) {
858       if (small_dimension_size == large_dimension_size ||
859           (small_dimension_size == 1 && !small_is_dynamic) ||
860           (large_dimension_size == 1 && !large_is_dynamic)) {
861         // Do nothing. It's OK when the size-1 dimension is not static.
862       } else {
863         return InvalidArgument(
864             "Broadcast dimension %d dynamism mismatch: %s and %s.", i,
865             ShapeUtil::HumanString(smaller_shape),
866             ShapeUtil::HumanString(larger_shape));
867       }
868     }
869     // Make sure the broadcast dimensions are listed in a strictly increasing
870     // order.
871     if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) {
872       return InvalidArgument(
873           "Broadcast dimensions order is wrong: %d comes after %d.",
874           dimension_to_match, broadcast_dimensions.at(i - 1));
875     }
876 
877     output_shape.set_dimensions(dimension_to_match, small_dimension_size);
878     output_shape.set_dynamic_dimension(dimension_to_match, small_is_dynamic);
879   }
880 
881   return output_shape;
882 }
883 
884 /* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
885     HloOpcode operation, const Shape& lhs, const Shape& rhs,
886     absl::Span<const int64> broadcast_dimensions) {
887   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
888   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
889 
890   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
891     return InvalidArgument(
892         "Binary op %s with different element types: %s and %s.",
893         HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
894         ShapeUtil::HumanString(rhs));
895   }
896 
897   if (lhs.rank() == rhs.rank()) {
898     std::vector<int64> identity_dims(lhs.rank());
899     std::iota(identity_dims.begin(), identity_dims.end(), 0);
900     if (!broadcast_dimensions.empty() &&
901         broadcast_dimensions != identity_dims) {
902       return InvalidArgument(
903           "Broadcast dimensions field must either be not set or be the "
904           "identity on binary operations with operands of the same rank.");
905     }
906   }
907 
908   if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
909     // If the shapes are the same other than layout, the output shape is the
910     // same (elementwise op).
911     Shape result = ShapeUtil::ChangeElementType(
912         lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
913 
914     for (int64_t i = 0; i < rhs.rank(); ++i) {
915       if (rhs.is_dynamic_dimension(i)) {
916         result.set_dynamic_dimension(i, true);
917       }
918     }
919 
920     return result;
921 
922   } else if (lhs.rank() == rhs.rank()) {
923     return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
924   } else {
925     // Ranks do not match, so perform InDim broadcasting using
926     // broadcast_dimensions. Scalar broadcasting is a special case of this.
927     const Shape& larger_shape = lhs.rank() > rhs.rank() ? lhs : rhs;
928     const Shape& smaller_shape = lhs.rank() > rhs.rank() ? rhs : lhs;
929 
930     // After InDim broadcasting, perform degenerate dimensions broadcasting.
931     TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape,
932                         InferInDimBroadcastShape(smaller_shape, larger_shape,
933                                                  broadcast_dimensions));
934 
935     return InferDegenerateDimensionBroadcastShape(
936         operation, indim_broadcast_shape, larger_shape);
937   }
938 }
939 
940 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
941     HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) {
942   return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(),
943                             /*broadcast_dimensions=*/{});
944 }
945 
946 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
947     HloOpcode opcode, const Shape& lhs, const Shape& rhs,
948     absl::Span<const int64> broadcast_dimensions) {
949   VLOG(2) << StrFormat(
950       "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
951       HloOpcodeString(opcode), ShapeUtil::HumanStringWithLayout(lhs),
952       ShapeUtil::HumanStringWithLayout(rhs),
953       StrJoin(broadcast_dimensions, ", "));
954 
955   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
956   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
957 
958   TF_RETURN_IF_ERROR(ExpectArray(
959       lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode))));
960   TF_RETURN_IF_ERROR(ExpectArray(
961       rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode))));
962   switch (opcode) {
963     case HloOpcode::kAdd:
964     case HloOpcode::kMaximum:
965     case HloOpcode::kMinimum:
966     case HloOpcode::kMultiply:
967       return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
968                                            broadcast_dimensions);
969 
970     case HloOpcode::kSubtract:
971     case HloOpcode::kAtan2:
972     case HloOpcode::kPower:
973     case HloOpcode::kDivide:
974     case HloOpcode::kRemainder:
975     case HloOpcode::kShiftLeft:
976     case HloOpcode::kShiftRightArithmetic:
977     case HloOpcode::kShiftRightLogical:
978       if (lhs.element_type() == PRED || rhs.element_type() == PRED) {
979         return InvalidArgument(
980             "Expected element type in shape to be arithmetic type for "
981             "operation %s; got PRED.",
982             HloOpcodeString(opcode));
983       }
984       return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
985                                            broadcast_dimensions);
986 
987     case HloOpcode::kComplex: {
988       if (!ShapeUtil::ElementIsFloating(lhs)) {
989         return InvalidArgument(
990             "Expected element type in shape to be floating for complex compose "
991             "operation; got %s.",
992             PrimitiveType_Name(lhs.element_type()));
993       }
994       TF_ASSIGN_OR_RETURN(const Shape& shape,
995                           InferElementwiseBinaryOpShape(opcode, lhs, rhs,
996                                                         broadcast_dimensions));
997       if (lhs.element_type() == F32 && rhs.element_type() == F32) {
998         return ShapeUtil::ChangeElementType(shape, C64);
999       } else if (lhs.element_type() == F64 && rhs.element_type() == F64) {
1000         return ShapeUtil::ChangeElementType(shape, C128);
1001       } else {
1002         return Unimplemented("Complex component type is not implemented.");
1003       }
1004     }
1005     case HloOpcode::kAnd:
1006     case HloOpcode::kOr:
1007     case HloOpcode::kXor:
1008       if (lhs.element_type() != PRED &&
1009           !primitive_util::IsIntegralType(lhs.element_type())) {
1010         return InvalidArgument(
1011             "Expected pred or integral type in argument to and/or operation; "
1012             "got %s.",
1013             PrimitiveType_Name(lhs.element_type()));
1014       }
1015       return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
1016                                            broadcast_dimensions);
1017     case HloOpcode::kCompare: {
1018       TF_ASSIGN_OR_RETURN(const Shape& shape,
1019                           InferElementwiseBinaryOpShape(opcode, lhs, rhs,
1020                                                         broadcast_dimensions));
1021       return ShapeUtil::ChangeElementType(shape, PRED);
1022     }
1023     default:
1024       return Unimplemented(
1025           "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.",
1026           HloOpcodeString(opcode), lhs.ShortDebugString(),
1027           rhs.ShortDebugString());
1028   }
1029 }
1030 
1031 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1032     HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs,
1033     const HloInstruction* ehs) {
1034   return InferTernaryOpShape(opcode, lhs->shape(), rhs->shape(), ehs->shape());
1035 }
1036 
1037 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1038     HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) {
1039   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1040   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1041   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs));
1042   switch (opcode) {
1043     case HloOpcode::kClamp:
1044       return InferClampShape(lhs, rhs, ehs);
1045     case HloOpcode::kSelect:
1046       return InferSelectShape(lhs, rhs, ehs);
1047     case HloOpcode::kTupleSelect:
1048       return InferTupleSelectShape(lhs, rhs, ehs);
1049     default:
1050       return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1051   }
1052 }
1053 
1054 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1055     HloOpcode opcode, absl::Span<const HloInstruction* const> operands) {
1056   std::vector<const Shape*> operand_shapes;
1057   operand_shapes.reserve(operands.size());
1058   for (const HloInstruction* operand : operands) {
1059     operand_shapes.push_back(&operand->shape());
1060   }
1061   return InferVariadicOpShape(opcode, operand_shapes);
1062 }
1063 
1064 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1065     HloOpcode opcode, absl::Span<const Shape* const> operand_shapes) {
1066   for (const Shape* shape : operand_shapes) {
1067     TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
1068   }
1069   switch (opcode) {
1070     case HloOpcode::kTuple: {
1071       Shape result = ShapeUtil::MakeTupleShape({});
1072       result.mutable_tuple_shapes()->reserve(operand_shapes.size());
1073       for (const Shape* shape : operand_shapes) {
1074         ShapeUtil::AppendShapeToTuple(*shape, &result);
1075       }
1076       return result;
1077     }
1078     case HloOpcode::kSort: {
1079       if (operand_shapes.size() == 1) {
1080         return *operand_shapes[0];
1081       } else {
1082         for (int64_t operand = 1; operand < operand_shapes.size(); ++operand) {
1083           if (!ShapeUtil::SameDimensions(*operand_shapes[0],
1084                                          *operand_shapes[operand])) {
1085             return InvalidArgument(
1086                 "Sort keys and values dimensions must match. "
1087                 "Keys shape is: %s\n, Values shape (operand index %lld) is: %s",
1088                 ShapeUtil::HumanString(*operand_shapes[0]), operand,
1089                 ShapeUtil::HumanString(*operand_shapes[operand]));
1090           }
1091         }
1092         std::vector<Shape> operand_shape_values;
1093         for (const Shape* operand_shape : operand_shapes) {
1094           operand_shape_values.push_back(*operand_shape);
1095         }
1096         return ShapeUtil::MakeTupleShape(operand_shape_values);
1097       }
1098       return InvalidArgument("Unexpected number of operands for sort");
1099     }
1100     default:
1101       return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1102   }
1103 }
1104 
1105 /* static */ StatusOr<Shape> ShapeInference::InferMapShape(
1106     absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
1107     absl::Span<const int64> dimensions) {
1108   if (arg_shapes.empty()) {
1109     return InvalidArgument("Map expects at least one argument.");
1110   }
1111 
1112   // All arguments must have the same shape.
1113   const Shape* arg_shape = arg_shapes[0];
1114   for (size_t i = 1; i < arg_shapes.size(); ++i) {
1115     TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map"));
1116 
1117     if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
1118       continue;
1119     }
1120     if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
1121                                                       *arg_shape)) {
1122       if (ShapeUtil::IsScalar(*arg_shapes[i])) {
1123         continue;
1124       }
1125       if (ShapeUtil::IsScalar(*arg_shape)) {
1126         arg_shape = arg_shapes[i];
1127         continue;
1128       }
1129     }
1130 
1131     std::vector<string> pieces;
1132     for (const Shape* shape : arg_shapes) {
1133       pieces.push_back(ShapeUtil::HumanString(*shape));
1134     }
1135     return InvalidArgument(
1136         "Map operation requires all operands to have the same shape; got: "
1137         "%s.",
1138         StrJoin(pieces, ", "));
1139   }
1140 
1141   // Check that dimensions.size == arg_shape.dimensions_size() (we currently
1142   // only support mapping across all dimensions: i.e. scalar map functions).
1143   if (dimensions.size() != arg_shape->dimensions_size()) {
1144     return InvalidArgument(
1145         "Map applied to a subset of dimensions currently not supported: "
1146         "arg_dimension_size: %d, requested_map_dimensions_size: %u.",
1147         arg_shape->dimensions_size(), dimensions.size());
1148   }
1149 
1150   // Check that requested map dimensions numbers are monotonically increasing.
1151   for (int i = 0; i < dimensions.size(); ++i) {
1152     if (dimensions[i] != i) {
1153       return InvalidArgument(
1154           "Map requires monotonically increasing dimension numbers; got: %s.",
1155           StrJoin(dimensions, ", "));
1156     }
1157   }
1158 
1159   // The applied function's arity equals the number of arguments.
1160   if (arg_shapes.size() != to_apply.parameters_size()) {
1161     return InvalidArgument(
1162         "Map applied function arity must match number of arguments; got: "
1163         "arity: %d, arguments: %u.",
1164         to_apply.parameters_size(), arg_shapes.size());
1165   }
1166 
1167   // The parameters should all be scalars, and the output too.
1168   const Shape& output_shape = to_apply.result();
1169   if (!ShapeUtil::IsScalar(output_shape)) {
1170     return InvalidArgument(
1171         "Mapped computation's result has to be a scalar; got: %s.",
1172         ShapeUtil::HumanString(output_shape));
1173   }
1174 
1175   for (int i = 0; i < to_apply.parameters_size(); ++i) {
1176     const Shape& parameter_shape = to_apply.parameters(i);
1177 
1178     if (!ShapeUtil::IsScalar(parameter_shape)) {
1179       return InvalidArgument(
1180           "Mapped computation's parameter has to be a scalar; "
1181           "got parameter %d shape: %s.",
1182           i, ShapeUtil::HumanString(parameter_shape));
1183     }
1184 
1185     if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
1186                                                        *arg_shape)) {
1187       return InvalidArgument(
1188           "Mapped computation's parameter type has to match argument element "
1189           "type; got parameter %d shape: %s, argument shape: %s.",
1190           i, ShapeUtil::HumanString(parameter_shape),
1191           ShapeUtil::HumanString(*arg_shape));
1192     }
1193   }
1194 
1195   return ShapeUtil::MakeShape(output_shape.element_type(),
1196                               AsInt64Slice(arg_shape->dimensions()));
1197 }
1198 
1199 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormTrainingShape(
1200     const Shape& operand_shape, const Shape& scale_shape,
1201     const Shape& offset_shape, int64_t feature_index) {
1202   TF_RETURN_IF_ERROR(
1203       ExpectArray(operand_shape, "operand of batch norm training"));
1204   TF_RETURN_IF_ERROR(
1205       ExpectArray(offset_shape, "offset input of batch norm training"));
1206   TF_RETURN_IF_ERROR(
1207       ExpectArray(scale_shape, "scale input of batch norm training"));
1208 
1209   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
1210                Status::OK());
1211   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) ==
1212                Status::OK());
1213   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
1214                Status::OK());
1215 
1216   if (feature_index >= operand_shape.rank()) {
1217     return InvalidArgument(
1218         "Expected feature_index of batch-norm-training to be "
1219         "smaller than the rank of operand_shape; "
1220         "got feature_index %d, and rank %d.",
1221         feature_index, operand_shape.rank());
1222   }
1223 
1224   if (feature_index < 0) {
1225     return InvalidArgument(
1226         "Expected feature_index of batch-norm-training to "
1227         "be a non-negative number, got %d.",
1228         feature_index);
1229   }
1230 
1231   if (operand_shape.rank() < 1) {
1232     return InvalidArgument(
1233         "Expected the rank of operand to "
1234         "batch-norm-training to be at least 1; got %d.",
1235         operand_shape.rank());
1236   }
1237 
1238   if (offset_shape.rank() != 1) {
1239     return InvalidArgument(
1240         "Offset input of batch-norm-training must have"
1241         " rank 1, but has rank %d.",
1242         offset_shape.rank());
1243   }
1244 
1245   if (scale_shape.rank() != 1) {
1246     return InvalidArgument(
1247         "Scale input of batch-norm-training must have"
1248         " rank 1, but has rank %d.",
1249         scale_shape.rank());
1250   }
1251 
1252   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1253     return InvalidArgument(
1254         "The operand to batch-norm-training must have a floating point "
1255         "element type, but the shape is %s.",
1256         PrimitiveType_Name(operand_shape.element_type()));
1257   }
1258 
1259   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1260                                                      operand_shape)) {
1261     return InvalidArgument(
1262         "The inputs should have the same element type for batch-norm-training, "
1263         "but the shape of offset factor is %s "
1264         "and the shape of operand is %s.",
1265         PrimitiveType_Name(offset_shape.element_type()),
1266         PrimitiveType_Name(operand_shape.element_type()));
1267   }
1268 
1269   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1270                                                      operand_shape)) {
1271     return InvalidArgument(
1272         "The inputs should have the same element type for batch-norm-training, "
1273         "but the shape of scale factor is %s "
1274         "and the shape of operand is %s.",
1275         PrimitiveType_Name(scale_shape.element_type()),
1276         PrimitiveType_Name(operand_shape.element_type()));
1277   }
1278 
1279   const int64_t feature_count = operand_shape.dimensions(feature_index);
1280   Shape output_shape_for_mean_and_var =
1281       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1282 
1283   if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1284     return InvalidArgument(
1285         "The size of offset factor should be the same as feature count,"
1286         "but the size of offset factor is %d "
1287         "and the feature count is %d.",
1288         ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1289   }
1290 
1291   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1292     return InvalidArgument(
1293         "The size of scale factor should be the same as feature count,"
1294         "but the size of scale factor is %d "
1295         "and the feature count is %d.",
1296         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1297   }
1298 
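  // For example, an f32[16,32,8] operand with feature_index=2 yields the tuple
  // (f32[16,32,8], f32[8], f32[8]), i.e. (output, batch mean, batch variance).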
1299   return ShapeUtil::MakeTupleShape({operand_shape,
1300                                     output_shape_for_mean_and_var,
1301                                     output_shape_for_mean_and_var});
1302 }
1303 
1304 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
1305     const Shape& operand_shape, const Shape& scale_shape,
1306     const Shape& offset_shape, const Shape& mean_shape,
1307     const Shape& variance_shape, int64_t feature_index) {
1308   TF_RETURN_IF_ERROR(
1309       ExpectArray(operand_shape, "operand of batch norm inference"));
1310   TF_RETURN_IF_ERROR(
1311       ExpectArray(offset_shape, "offset input of batch norm inference"));
1312   TF_RETURN_IF_ERROR(
1313       ExpectArray(scale_shape, "scale input of batch norm inference"));
1314 
1315   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1316   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape));
1317   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1318   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1319   TF_RETURN_IF_ERROR(
1320       ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape));
1321 
1322   if (feature_index >= operand_shape.rank()) {
1323     return InvalidArgument(
1324         "Expected feature_index of batch-norm-inference to be "
1325         "smaller than the rank of operand_shape; "
1326         "got feature_index %d, and rank %d.",
1327         feature_index, operand_shape.rank());
1328   }
1329 
1330   if (feature_index < 0) {
1331     return InvalidArgument(
1332         "Expected feature_index of batch-norm-inference to "
1333         "be a non-negative number, got %d.",
1334         feature_index);
1335   }
1336 
1337   if (operand_shape.rank() < 1) {
1338     return InvalidArgument(
1339         "Expected the rank of operand to "
1340         "batch-norm-inference to be at least 1; got %d.",
1341         operand_shape.rank());
1342   }
1343 
1344   if (offset_shape.rank() != 1) {
1345     return InvalidArgument(
1346         "Offset input of batch-norm-inference must have"
1347         " rank 1, but has rank %d.",
1348         offset_shape.rank());
1349   }
1350 
1351   if (scale_shape.rank() != 1) {
1352     return InvalidArgument(
1353         "Scale input of batch-norm-inference must have"
1354         " rank 1, but has rank %d.",
1355         scale_shape.rank());
1356   }
1357 
1358   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1359     return InvalidArgument(
1360         "The operand to batch-norm-inference must have a floating point "
1361         "element type, but the shape is %s.",
1362         PrimitiveType_Name(operand_shape.element_type()));
1363   }
1364 
1365   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1366                                                      operand_shape)) {
1367     return InvalidArgument(
1368         "The inputs should have the same element type for "
1369         "batch-norm-inference, "
1370         "but the shape of offset factor is %s "
1371         "and the shape of operand is %s.",
1372         PrimitiveType_Name(offset_shape.element_type()),
1373         PrimitiveType_Name(operand_shape.element_type()));
1374   }
1375 
1376   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1377                                                      operand_shape)) {
1378     return InvalidArgument(
1379         "The inputs should have the same element type for "
1380         "batch-norm-inference, "
1381         "but the shape of scale factor is %s "
1382         "and the shape of operand is %s.",
1383         PrimitiveType_Name(scale_shape.element_type()),
1384         PrimitiveType_Name(operand_shape.element_type()));
1385   }
1386 
1387   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1388                                                      operand_shape)) {
1389     return InvalidArgument(
1390         "The inputs should have the same element type for "
1391         "batch-norm-inference, "
1392         "but the shape of mean is %s "
1393         "and the shape of operand is %s.",
1394         PrimitiveType_Name(mean_shape.element_type()),
1395         PrimitiveType_Name(operand_shape.element_type()));
1396   }
1397 
1398   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
1399                                                      operand_shape)) {
1400     return InvalidArgument(
1401         "The inputs should have the same element type for "
1402         "batch-norm-inference, "
1403         "but the shape of variance is %s "
1404         "and the shape of operand is %s.",
1405         PrimitiveType_Name(variance_shape.element_type()),
1406         PrimitiveType_Name(operand_shape.element_type()));
1407   }
1408 
1409   const int64_t feature_count = operand_shape.dimensions(feature_index);
1410   Shape output_shape_for_mean_and_var =
1411       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1412 
1413   if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1414     return InvalidArgument(
1415         "The size of offset factor should be the same as feature count,"
1416         "but the size of offset factor is %d "
1417         "and the feature count is %d.",
1418         ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1419   }
1420 
1421   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1422     return InvalidArgument(
1423         "The size of scale factor should be the same as feature count,"
1424         "but the size of scale factor is %d "
1425         "and the feature count is %d.",
1426         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1427   }
1428 
1429   if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1430     return InvalidArgument(
1431         "The size of mean should be the same as feature count,"
1432         "but the size of mean is %d "
1433         "and the feature count is %d.",
1434         ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1435   }
1436 
1437   if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
1438     return InvalidArgument(
1439         "The size of variance should be the same as feature count,"
1440         "but the size of variance is %d "
1441         "and the feature count is %d.",
1442         ShapeUtil::GetDimension(variance_shape, 0), feature_count);
1443   }
1444 
1445   return operand_shape;
1446 }
1447 
1448 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
1449     const Shape& operand_shape, const Shape& scale_shape,
1450     const Shape& mean_shape, const Shape& var_shape,
1451     const Shape& output_grad_shape, int64_t feature_index) {
1452   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad"));
1453   TF_RETURN_IF_ERROR(
1454       ExpectArray(scale_shape, "scale input of batch norm grad"));
1455   TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad"));
1456   TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad"));
1457   TF_RETURN_IF_ERROR(
1458       ExpectArray(output_grad_shape, "output_grad input of batch norm grad"));
1459 
1460   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1461   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1462   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1463   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(var_shape));
1464   TF_RETURN_IF_ERROR(
1465       ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape));
1466 
1467   if (feature_index >= operand_shape.rank()) {
1468     return InvalidArgument(
1469         "Expected feature_index of batch-norm-grad to be "
1470         "smaller than the rank of operand_shape; "
1471         "got feature_index %d, and rank %d.",
1472         feature_index, operand_shape.rank());
1473   }
1474 
1475   if (operand_shape.rank() != output_grad_shape.rank()) {
1476     return InvalidArgument(
1477         "Expected operand_shape of batch-norm-grad to have the same rank as"
1478         " output_grad_shape; got rank(oprand_shape) %d, and"
1479         " rank(output_grad_shape) %d.",
1480         operand_shape.rank(), output_grad_shape.rank());
1481   }
1482 
1483   if (mean_shape.rank() != 1) {
1484     return InvalidArgument(
1485         "Mean input of batch-norm-grad must have"
1486         " rank 1, but has rank %d.",
1487         mean_shape.rank());
1488   }
1489 
1490   if (scale_shape.rank() != 1) {
1491     return InvalidArgument(
1492         "Scale input of batch-norm-grad must have"
1493         " rank 1, but has rank %d.",
1494         scale_shape.rank());
1495   }
1496 
1497   if (var_shape.rank() != 1) {
1498     return InvalidArgument(
1499         "Var input of batch-norm-grad must have"
1500         " rank 1, but has rank %d.",
1501         var_shape.rank());
1502   }
1503 
1504   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1505     return InvalidArgument(
1506         "The operand to batch-norm-grad must have a floating point "
1507         "element type, but the shape is %s.",
1508         PrimitiveType_Name(operand_shape.element_type()));
1509   }
1510 
1511   if (!ShapeUtil::ElementIsFloating(output_grad_shape)) {
1512     return InvalidArgument(
1513         "The output_grad to batch-norm-grad must have a floating point "
1514         "element type, but the shape is %s.",
1515         PrimitiveType_Name(output_grad_shape.element_type()));
1516   }
1517 
1518   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
1519                                                      operand_shape)) {
1520     return InvalidArgument(
1521         "The inputs should have the same element type for batch-norm-grad, "
1522         "but the element type of output_grad is %s "
1523         "and the element type of operand is %s.",
1524         PrimitiveType_Name(output_grad_shape.element_type()),
1525         PrimitiveType_Name(operand_shape.element_type()));
1526   }
1527 
1528   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1529                                                      operand_shape)) {
1530     return InvalidArgument(
1531         "The inputs should have the same element type for batch-norm-grad, "
1532         "but the element type of scale factor is %s "
1533         "and the element type of operand is %s.",
1534         PrimitiveType_Name(scale_shape.element_type()),
1535         PrimitiveType_Name(operand_shape.element_type()));
1536   }
1537 
1538   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1539                                                      operand_shape)) {
1540     return InvalidArgument(
1541         "The inputs should have the same element type for batch-norm-grad, "
1542         "but the element type of mean is %s "
1543         "and the element type of operand is %s.",
1544         PrimitiveType_Name(mean_shape.element_type()),
1545         PrimitiveType_Name(operand_shape.element_type()));
1546   }
1547 
1548   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
1549                                                      operand_shape)) {
1550     return InvalidArgument(
1551         "The inputs should have the same element type for batch-norm-grad, "
1552         "but the element type of mean is %s "
1553         "and the element type of operand is %s.",
1554         PrimitiveType_Name(mean_shape.element_type()),
1555         PrimitiveType_Name(operand_shape.element_type()));
1556   }
1557 
1558   const int64_t feature_count = operand_shape.dimensions(feature_index);
1559 
1560   Shape feature_shape =
1561       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1562 
1563   if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1564     return InvalidArgument(
1565         "The size of mean should be the same as feature count,"
1566         "but the size of offset factor is %d "
1567         "and the feature count is %d.",
1568         ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1569   }
1570 
1571   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1572     return InvalidArgument(
1573         "The size of scale factor should be the same as feature count,"
1574         "but the size of scale factor is %d "
1575         "and the feature count is %d.",
1576         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1577   }
1578 
1579   if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) {
1580     return InvalidArgument(
1581         "The size of variance should be the same as feature count,"
1582         "but the size of variance is %d "
1583         "and the feature count is %d.",
1584         ShapeUtil::GetDimension(var_shape, 0), feature_count);
1585   }
1586 
1587   // Verify operand_shape and output_grad_shape have same bounds.
1588   for (int64_t i = 0; i < operand_shape.rank(); ++i) {
1589     if (ShapeUtil::GetDimension(operand_shape, i) !=
1590         ShapeUtil::GetDimension(output_grad_shape, i)) {
1591       return InvalidArgument(
1592           "The bounds of operand shape should be the same as output_grad's,"
1593           "but the bound of operand_shape at dimension %d is %d "
1594           "and the bound of output_grad_shape is %d.",
1595           i, ShapeUtil::GetDimension(operand_shape, i),
1596           ShapeUtil::GetDimension(output_grad_shape, i));
1597     }
1598   }
1599 
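  // For example, an f32[16,32,8] operand with feature_index=2 yields the tuple
  // (f32[16,32,8], f32[8], f32[8]), i.e. the gradients with respect to the
  // operand, scale and offset.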
1600   return ShapeUtil::MakeTupleShape(
1601       {operand_shape, feature_shape, feature_shape});
1602 }
1603 
1604 /* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
1605     const Shape& lhs, const Shape& rhs, int64_t feature_group_count,
1606     int64_t batch_group_count, const Window& window,
1607     const ConvolutionDimensionNumbers& dnums,
1608     absl::optional<PrimitiveType> preferred_element_type) {
1609   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
1610   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
1611 
1612   if (feature_group_count <= 0) {
1613     return InvalidArgument(
1614         "feature_group_count must be a positive number, got %d",
1615         feature_group_count);
1616   }
1617 
1618   if (batch_group_count <= 0) {
1619     return InvalidArgument(
1620         "batch_group_count must be a positive number, got %d",
1621         batch_group_count);
1622   }
1623 
1624   if (batch_group_count > 1 && feature_group_count > 1) {
1625     return InvalidArgument(
1626         "both batch_group_count %d and feature_group_count %d cannot be "
1627         "greater than 1",
1628         batch_group_count, feature_group_count);
1629   }
1630 
1631   if (dnums.input_spatial_dimensions_size() !=
1632       dnums.kernel_spatial_dimensions_size()) {
1633     return InvalidArgument(
1634         "Both arguments to convolution must have same number of dimensions.\n"
1635         "Numbers: %s",
1636         dnums.DebugString());
1637   }
1638 
1639   if (dnums.input_spatial_dimensions_size() !=
1640       dnums.output_spatial_dimensions_size()) {
1641     return InvalidArgument(
1642         "Both input and output of convolution must have same number of "
1643         "dimensions.\nNumbers: %s",
1644         dnums.DebugString());
1645   }
1646 
1647   const int num_spatial_dims = dnums.input_spatial_dimensions_size();
1648   if (window.dimensions_size() != num_spatial_dims) {
1649     return InvalidArgument(
1650         "Window must have same number of dimensions as dimension numbers.\n"
1651         "Window: %s\nDimension numbers: %s.",
1652         window.DebugString(), dnums.DebugString());
1653   }
1654 
1655   const int num_dims = num_spatial_dims + 2;
1656   if (lhs.rank() != num_dims) {
1657     return InvalidArgument(
1658         "The LHS argument to a convolution should have rank %d; lhs: %s.",
1659         num_dims, ShapeUtil::HumanString(lhs));
1660   }
1661   if (rhs.rank() != num_dims) {
1662     return InvalidArgument(
1663         "The RHS argument to a convolution should have rank %d; rhs: %s.",
1664         num_dims, ShapeUtil::HumanString(rhs));
1665   }
1666   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1667   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1668 
1669   // Verifies that the input and window dimensions are a permutation of
1670   // the dimension numbers.
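  // For example, a rank-4 NHWC input (batch=0, feature=3, spatial={1,2}) sorts
  // to {0, 1, 2, 3}, which must equal expected_dnums below.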
1671   std::vector<int64> input_dnums(num_dims);
1672   input_dnums[0] = dnums.input_batch_dimension();
1673   input_dnums[1] = dnums.input_feature_dimension();
1674   std::copy(dnums.input_spatial_dimensions().begin(),
1675             dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
1676   absl::c_sort(input_dnums);
1677 
1678   std::vector<int64> window_dnums(num_dims);
1679   window_dnums[0] = dnums.kernel_input_feature_dimension();
1680   window_dnums[1] = dnums.kernel_output_feature_dimension();
1681   std::copy(dnums.kernel_spatial_dimensions().begin(),
1682             dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
1683   absl::c_sort(window_dnums);
1684 
1685   std::vector<int64> output_dnums(num_dims);
1686   output_dnums[0] = dnums.output_batch_dimension();
1687   output_dnums[1] = dnums.output_feature_dimension();
1688   std::copy(dnums.output_spatial_dimensions().begin(),
1689             dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
1690   absl::c_sort(output_dnums);
1691 
1692   std::vector<int64> expected_dnums(num_dims);
1693   std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
1694 
1695   const auto in_range = [num_dims](int64_t i) {
1696     return 0 <= i && i < num_dims;
1697   };
1698   if (!absl::c_all_of(input_dnums, in_range) ||
1699       !absl::c_all_of(window_dnums, in_range) ||
1700       !absl::c_all_of(output_dnums, in_range)) {
1701     return InvalidArgument(
1702         "A dimension number is out of range in convolution: %s.",
1703         dnums.DebugString());
1704   }
1705 
1706   if (input_dnums != expected_dnums) {
1707     return InvalidArgument(
1708         "Input dimensions of convolution must contain each dimension exactly "
1709         "once: %s.",
1710         dnums.DebugString());
1711   }
1712   if (window_dnums != expected_dnums) {
1713     return InvalidArgument(
1714         "Window dimensions of convolution must contain each dimension exactly "
1715         "once: %s.",
1716         dnums.DebugString());
1717   }
1718   if (output_dnums != expected_dnums) {
1719     return InvalidArgument(
1720         "Output dimensions of convolution must contain each dimension exactly "
1721         "once: %s.",
1722         dnums.DebugString());
1723   }
1724 
1725   std::vector<int64> input_spatial_dims(num_spatial_dims);
1726   for (int i = 0; i < num_spatial_dims; ++i) {
1727     input_spatial_dims[i] = lhs.dimensions(dnums.input_spatial_dimensions(i));
1728   }
1729   const int64_t input_features =
1730       lhs.dimensions(dnums.input_feature_dimension());
1731   const int64_t input_batch = lhs.dimensions(dnums.input_batch_dimension());
1732 
1733   std::vector<int64> kernel_spatial_dims(num_spatial_dims);
1734   for (int i = 0; i < num_spatial_dims; ++i) {
1735     kernel_spatial_dims[i] = rhs.dimensions(dnums.kernel_spatial_dimensions(i));
1736   }
1737   const int64_t kernel_input_features =
1738       rhs.dimensions(dnums.kernel_input_feature_dimension());
1739   const int64_t kernel_output_features =
1740       rhs.dimensions(dnums.kernel_output_feature_dimension());
1741 
1742   if (kernel_output_features % batch_group_count != 0) {
1743     return InvalidArgument(
1744         "Expected output feature dimension size (value %d) to be a multiple of "
1745         "batch group count %d; got <conv>(%s, %s)\n"
1746         "Dimension numbers: {%s}.",
1747         kernel_output_features, batch_group_count, ShapeUtil::HumanString(lhs),
1748         ShapeUtil::HumanString(rhs), dnums.DebugString());
1749   }
1750 
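  // For example, with feature_group_count=2 and an LHS feature dimension of
  // size 32, the RHS (kernel) input feature dimension must have size 16.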
1751   if (input_features % feature_group_count != 0 ||
1752       input_features / feature_group_count != kernel_input_features) {
1753     return InvalidArgument(
1754         "Expected LHS feature dimension (value %d) to be a multiple of "
1755         "feature_group_count (value %d), and LHS feature dimension / "
1756         "feature_group_count = RHS feature dimension (value %d); "
1757         "got <conv>(%s, %s)\n"
1758         "Dimension numbers: {%s}.",
1759         input_features, feature_group_count, kernel_input_features,
1760         ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
1761         dnums.DebugString());
1762   }
1763 
1764   if (kernel_output_features % feature_group_count > 0) {
1765     // A depthwise/grouped filter has the shape
1766     // [space0, .. spaceN, GROUP_SIZE, NUM_OUTPUT_FEATURES]. When
1767     // [space0, .. spaceN, GROUP_SIZE] is convolved with the input, a shape
1768     // [space0, .. spaceN, feature_group_count] is formed. Therefore, the output
1769     // feature count (which is equal to kernel output features) has to be a
1770     // multiple of feature_group_count.
1771     return InvalidArgument(
1772         "Expected output feature dimension (value %d) to be divisible by "
1773         "feature_group_count (value %d); "
1774         "got <conv>(%s, %s)\n"
1775         "Dimension numbers: {%s}.",
1776         kernel_output_features, feature_group_count,
1777         ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
1778         dnums.DebugString());
1779   }
1780 
1781   if (input_batch % batch_group_count != 0) {
1782     return InvalidArgument(
1783         "Expected input batch dimension (value %d) to be divisible by "
1784         "batch_group_count (value %d); "
1785         "got <conv>(%s, %s)\n"
1786         "Dimension numbers: {%s}.",
1787         input_batch, batch_group_count, ShapeUtil::HumanString(lhs),
1788         ShapeUtil::HumanString(rhs), dnums.DebugString());
1789   }
1790 
1791   std::vector<int64> window_dims(num_spatial_dims);
1792   for (int i = 0; i < num_spatial_dims; ++i) {
1793     window_dims[i] = window.dimensions(i).size();
1794   }
1795   if (kernel_spatial_dims != window_dims) {
1796     return InvalidArgument(
1797         "Window dimensions do not match RHS shape:\n\t"
1798         "RHS shape: %s\n\t"
1799         "Window: {%s}\n\t"
1800         "Dimension numbers: {%s}.",
1801         ShapeUtil::HumanString(rhs), window.ShortDebugString(),
1802         dnums.ShortDebugString());
1803   }
1804 
1805   Shape base_shape =
1806       ShapeUtil::MakeShape(lhs.element_type(), input_spatial_dims);
1807   TF_ASSIGN_OR_RETURN(
1808       Shape window_output_shape,
1809       InferWindowOutputShape(base_shape, window, lhs.element_type(),
1810                              /*allow_negative_padding=*/true));
1811 
1812   std::vector<int64> dimensions(num_dims);
1813   dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count;
1814   dimensions[dnums.output_feature_dimension()] = kernel_output_features;
1815 
1816   for (int i = 0; i < num_spatial_dims; ++i) {
1817     dimensions[dnums.output_spatial_dimensions(i)] =
1818         window_output_shape.dimensions(i);
1819   }
1820   std::vector<bool> is_dynamic(num_dims);
1821   for (int i = 0; i < num_dims; i++) {
1822     if (lhs.is_dynamic_dimension(i)) {
1823       if (i == dnums.input_batch_dimension()) {
1824         is_dynamic[dnums.output_batch_dimension()] = true;
1825       } else if (i == dnums.input_feature_dimension()) {
1826         // Input feature dimension is a contracting dimension, which does not
1827         // affect the output dimension size. So we need to do nothing.
1828       } else {
1829         for (int64_t j = 0; j < dnums.output_spatial_dimensions_size(); ++j) {
1830           if (i == dnums.input_spatial_dimensions(j)) {
1831             // i is a spatial dimension, find corresponding output spatial
1832             // dimension.
1833             is_dynamic[dnums.output_spatial_dimensions(j)] = true;
1834           }
1835         }
1836       }
1837     }
1838     if (rhs.is_dynamic_dimension(i)) {
1839       if (i == dnums.kernel_input_feature_dimension()) {
1840         // Kernel feature dimension does not affect the output dimension size.
1841         // So we need to do nothing.
1842       } else if (i == dnums.kernel_output_feature_dimension()) {
1843         return InvalidArgument(
1844             "Dynamic output feature dim on convolution kernel is not "
1845             "supported: rhs shape is %s ",
1846             rhs.ToString());
1847       } else {
1848         for (int64_t j = 0; j < dnums.kernel_spatial_dimensions_size(); ++j) {
1849           if (i == dnums.kernel_spatial_dimensions(j)) {
1850             // i is a spatial dimension, find corresponding output spatial
1851             // dimension.
1852             is_dynamic[dnums.output_spatial_dimensions(j)] = true;
1853           }
1854         }
1855       }
1856     }
1857   }
1858   TF_ASSIGN_OR_RETURN(
1859       PrimitiveType type,
1860       MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
1861                   preferred_element_type));
1862   return ShapeUtil::MakeShape(type, dimensions, is_dynamic);
1863 }
1864 
1865 /* static */ StatusOr<Shape> ShapeInference::InferFftShape(
1866     const Shape& in, const FftType fft_type,
1867     const absl::Span<const int64> fft_length) {
1868   const int64_t fft_rank = fft_length.size();
1869   if (fft_rank < 1 || fft_rank > 3) {
1870     return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank);
1871   }
1872 #define RET_CHECK_RANK(x)                            \
1873   if (x.dimensions_size() < fft_rank) {              \
1874     return InvalidArgument(                          \
1875         "FFT of rank %d requires input of at least " \
1876         "same rank; got input of rank %d",           \
1877         fft_rank, x.dimensions_size());              \
1878   }
1879   switch (fft_type) {
1880     case FFT:
1881     case IFFT:
1882       if (!primitive_util::IsComplexType(in.element_type())) {
1883         return InvalidArgument("%s requires complex input type, found %s.",
1884                                FftType_Name(fft_type),
1885                                PrimitiveType_Name(in.element_type()));
1886       }
1887       RET_CHECK_RANK(in);
1888       return in;
1889     case RFFT: {
1890       if (in.element_type() != F32 && in.element_type() != F64) {
1891         return InvalidArgument("RFFT requires F32 or F64 input type, found %s.",
1892                                PrimitiveType_Name(in.element_type()));
1893       }
1894       RET_CHECK_RANK(in);
1895       for (int i = 0; i < fft_rank; i++) {
1896         if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1897             fft_length[i]) {
1898           return InvalidArgument(
1899               "RFFT requires innermost dimensions match fft_length but "
1900               "dimension %d is %d and should be %d.",
1901               in.dimensions_size() - fft_rank + i,
1902               in.dimensions(in.dimensions_size() - fft_rank + i),
1903               fft_length[i]);
1904         }
1905       }
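      // For example, an RFFT of f32[8,16] with fft_length={16} produces
      // c64[8,9]: the innermost dimension becomes fft_length/2 + 1.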
1906       Shape result = ShapeUtil::ChangeElementType(
1907           in, in.element_type() == F32 ? C64 : C128);
1908       // Preserve the size of zero-sized dimensions.
1909       if (fft_length[fft_rank - 1] != 0) {
1910         result.set_dimensions(result.dimensions_size() - 1,
1911                               fft_length[fft_rank - 1] / 2 + 1);
1912       }
1913       return result;
1914     }
1915     case IRFFT: {
1916       if (!primitive_util::IsComplexType(in.element_type())) {
1917         return InvalidArgument("IRFFT requires complex input type, found %s.",
1918                                PrimitiveType_Name(in.element_type()));
1919       }
1920       RET_CHECK_RANK(in);
1921       Shape result = ShapeUtil::ComplexComponentShape(in);
1922       for (int i = 0; i < fft_rank - 1; i++) {
1923         if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1924             fft_length[i]) {
1925           return InvalidArgument(
1926               "IRFFT requires all but one innermost dimensions match "
1927               "fft_length, but dimension %d is %d and should be %d.",
1928               in.dimensions_size() - fft_rank + i,
1929               in.dimensions(in.dimensions_size() - fft_rank + i),
1930               fft_length[i]);
1931         }
1932       }
1933       // The size of zero-sized dimensions is preserved.
1934       if ((in.dimensions(in.dimensions_size() - 1) != 0 ||
1935            fft_length[fft_rank - 1] != 0) &&
1936           in.dimensions(in.dimensions_size() - 1) !=
1937               fft_length[fft_rank - 1] / 2 + 1) {
1938         return InvalidArgument(
1939             "IRFFT requires innermost dimension matches fft_length/2+1, but "
1940             "dimension %d is %d and should be %d.",
1941             in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1),
1942             fft_length[fft_rank - 1] / 2 + 1);
1943       }
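      // For example, an IRFFT of c64[8,9] with fft_length={16} produces
      // f32[8,16]: the innermost dimension is restored to fft_length.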
1944       result.set_dimensions(result.dimensions_size() - 1,
1945                             fft_length[fft_rank - 1]);
1946       return result;
1947     }
1948     default:
1949       LOG(FATAL) << "Unexpected fft_type: " << fft_type;
1950   }
1951 #undef RET_CHECK_RANK
1952 }
1953 
1954 /* static */ StatusOr<Shape> ShapeInference::InferTriangularSolveShape(
1955     const Shape& a, const Shape& b, const TriangularSolveOptions& options) {
1956   if ((!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) ||
1957       a.element_type() != b.element_type()) {
1958     return InvalidArgument(
1959         "Expected element types in shape to be floating or complex and "
1960         "identical for TriangularSolve; got %s and %s.",
1961         PrimitiveType_Name(a.element_type()),
1962         PrimitiveType_Name(b.element_type()));
1963   }
1964   if (a.rank() < 2) {
1965     return InvalidArgument(
1966         "The 'a' argument to TriangularSolve must have rank >= 2, got shape %s",
1967         a.ToString());
1968   }
1969   if (b.rank() != a.rank()) {
1970     return InvalidArgument(
1971         "Arguments to triangular solve must have equal rank; got %s and %s.",
1972         b.ToString(), a.ToString());
1973   }
1974   if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
1975     return InvalidArgument(
1976         "The two minor dimensions of 'a' must have equal size, got %s.",
1977         a.ToString());
1978   }
1979   if (a.dimensions(a.rank() - 1) !=
1980       b.dimensions(b.rank() - (options.left_side() ? 2 : 1))) {
1981     return InvalidArgument(
1982         "The shared dimension of 'a' and 'b' does not match, got shapes %s and "
1983         "%s",
1984         a.ToString(), b.ToString());
1985   }
1986   absl::Span<const int64> a_batch_dims(a.dimensions());
1987   absl::Span<const int64> b_batch_dims(b.dimensions());
1988   a_batch_dims.remove_suffix(2);
1989   b_batch_dims.remove_suffix(2);
1990   if (a_batch_dims != b_batch_dims) {
1991     return InvalidArgument(
1992         "The leading batch dimensions of the arguments to triangular solve "
1993         "must be equal; got %s and %s.",
1994         b.ToString(), a.ToString());
1995   }
1996   if (!TriangularSolveOptions_Transpose_IsValid(options.transpose_a()) ||
1997       options.transpose_a() == TriangularSolveOptions::TRANSPOSE_INVALID) {
1998     return InvalidArgument(
1999         "Invalid transpose option value for triangular solve (%d).\n",
2000         options.transpose_a());
2001   }
2002   return b;
2003 }
2004 
2005 /* static */ StatusOr<Shape> ShapeInference::InferCholeskyShape(
2006     const Shape& a) {
2007   if (!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) {
2008     return InvalidArgument(
2009         "Expected element type in shape to be floating or complex for "
2010         "Cholesky; got %s.",
2011         PrimitiveType_Name(a.element_type()));
2012   }
2013   if (a.rank() < 2) {
2014     return InvalidArgument(
2015         "The 'a' argument to Cholesky must have rank >= 2, got shape %s",
2016         a.ToString());
2017   }
2018   if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
2019     return InvalidArgument(
2020         "The two minor dimensions of 'a' must have equal size, got %s.",
2021         a.ToString());
2022   }
2023   return a;
2024 }
2025 
2026 /* static */ StatusOr<Shape> ShapeInference::InferAllGatherShape(
2027     absl::Span<const Shape* const> operand_shapes, int64_t all_gather_dimension,
2028     int64_t shard_count) {
2029   TF_RET_CHECK(all_gather_dimension >= 0);
2030   TF_RET_CHECK(shard_count > 0);
2031 
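  // For example, an all-gather of f32[4,16] over dimension 0 with
  // shard_count=8 produces f32[32,16].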
2032   std::vector<Shape> output_shapes;
2033   output_shapes.reserve(operand_shapes.size());
2034   for (const Shape* operand_shape : operand_shapes) {
2035     TF_RET_CHECK(all_gather_dimension < operand_shape->rank());
2036     TF_RETURN_IF_ERROR(ExpectArray(*operand_shape, "operand of all-gather"));
2037 
2038     Shape output_shape = *operand_shape;
2039     output_shape.set_dimensions(
2040         all_gather_dimension,
2041         shard_count * output_shape.dimensions(all_gather_dimension));
2042     output_shapes.push_back(output_shape);
2043   }
2044   if (output_shapes.size() == 1) {
2045     return output_shapes[0];
2046   }
2047   return ShapeUtil::MakeTupleShape(output_shapes);
2048 }
2049 
2050 /* static */ StatusOr<Shape> ShapeInference::InferAllGatherStartShape(
2051     absl::Span<const Shape* const> operand_shapes, int64_t all_gather_dimension,
2052     int64_t shard_count) {
2053   TF_ASSIGN_OR_RETURN(
2054       Shape ag_shape,
2055       InferAllGatherShape(operand_shapes, all_gather_dimension, shard_count));
2056   Shape input_shape;
2057   std::vector<Shape> op_shapes;
2058   op_shapes.reserve(operand_shapes.size());
2059   for (const Shape* shp : operand_shapes) {
2060     op_shapes.push_back(*shp);
2061   }
2062   if (op_shapes.size() == 1) {
2063     input_shape = op_shapes[0];
2064   } else {
2065     input_shape = ShapeUtil::MakeTupleShape(op_shapes);
2066   }
2067   return ShapeUtil::MakeTupleShape({input_shape, ag_shape});
2068 }
2069 
2070 /* static */ StatusOr<Shape> ShapeInference::InferAllGatherDoneShape(
2071     const Shape& all_gather_start_shape) {
2072   return ShapeUtil::GetTupleElementShape(all_gather_start_shape, 1);
2073 }
2074 
2075 /* static */ StatusOr<Shape> ShapeInference::InferAllReduceShape(
2076     absl::Span<const Shape* const> operand_shapes) {
2077   for (const Shape* operand_shape : operand_shapes) {
2078     TF_RETURN_IF_ERROR(
2079         ExpectArray(*operand_shape, "operand of cross replica sum"));
2080   }
2081   if (operand_shapes.size() == 1) {
2082     return *operand_shapes[0];
2083   }
2084   std::vector<Shape> operand_shape_values;
2085   for (const Shape* operand_shape : operand_shapes) {
2086     operand_shape_values.push_back(*operand_shape);
2087   }
2088   return ShapeUtil::MakeTupleShape(operand_shape_values);
2089 }
2090 
2091 /* static */ StatusOr<Shape> ShapeInference::InferReduceScatterShape(
2092     absl::Span<const Shape* const> operand_shapes, int64_t scatter_dimension,
2093     int64_t shard_count) {
2094   TF_RET_CHECK(scatter_dimension >= 0);
2095   TF_RET_CHECK(shard_count > 0);
2096 
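  // For example, a reduce-scatter of f32[32,16] over dimension 0 with
  // shard_count=8 produces f32[4,16].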
2097   std::vector<Shape> output_shapes;
2098   output_shapes.reserve(operand_shapes.size());
2099   for (const Shape* operand_shape : operand_shapes) {
2100     TF_RET_CHECK(scatter_dimension < operand_shape->rank());
2101     TF_RETURN_IF_ERROR(
2102         ExpectArray(*operand_shape, "operand of reduce-scatter"));
2103 
2104     int64_t scatter_dim_input_size =
2105         operand_shape->dimensions(scatter_dimension);
2106     if (scatter_dim_input_size % shard_count != 0) {
2107       return InvalidArgument(
2108           "ReduceScatter operand scatter dimension size %d must be "
2109           "dividable by shard_count "
2110           "%d.",
2111           scatter_dim_input_size, shard_count);
2112     }
2113 
2114     Shape output_shape = *operand_shape;
2115     output_shape.set_dimensions(scatter_dimension,
2116                                 scatter_dim_input_size / shard_count);
2117     output_shapes.push_back(output_shape);
2118   }
2119 
2120   if (output_shapes.size() == 1) {
2121     return output_shapes[0];
2122   }
2123   return ShapeUtil::MakeTupleShape(output_shapes);
2124 }
2125 
2126 /* static */ StatusOr<Shape> ShapeInference::InferAllReduceStartShape(
2127     absl::Span<const Shape* const> operand_shapes) {
2128   TF_ASSIGN_OR_RETURN(Shape shape, InferAllReduceShape(operand_shapes));
2129 
2130   return ShapeUtil::MakeTupleShape({shape, shape});
2131 }
2132 
2133 /* static */ StatusOr<Shape> ShapeInference::InferAllReduceDoneShape(
2134     const Shape& operand_shape) {
2135   // The returned value from AllReduceDone is determined from
2136   // HloDataflowAnalysis::UpdateAllReduceStartValueSet(). The operand to
2137   // AllReduceDone is a tuple of two elements and this function selects
2138   // ShapeIndex {1} as the value forwarded.
2139   return ShapeUtil::GetTupleElementShape(operand_shape, 1);
2140 }
2141 
2142 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllShape(
2143     const Shape& shape, int64_t split_dimension, int64_t concat_dimension,
2144     int64_t split_count) {
2145   TF_RET_CHECK(split_count > 0);
2146   if (split_dimension >= shape.rank() || split_dimension < 0) {
2147     return InvalidArgument(
2148         "AllToAll split_dimension %d is out-of-bounds in shape %s.",
2149         split_dimension, ShapeUtil::HumanString(shape));
2150   }
2151   if (concat_dimension >= shape.rank() || concat_dimension < 0) {
2152     return InvalidArgument(
2153         "AllToAll concat_dimension %d is out-of-bounds in shape %s.",
2154         concat_dimension, ShapeUtil::HumanString(shape));
2155   }
2156   if (shape.dimensions(split_dimension) % split_count != 0) {
2157     return InvalidArgument(
2158         "AllToAll split dimension size %d must be dividable by split_count "
2159         "%d.",
2160         shape.dimensions(split_dimension), split_count);
2161   }
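  // For example, an all-to-all of f32[8,4] with split_dimension=0,
  // concat_dimension=1 and split_count=4 produces f32[2,16].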
2162   std::vector<int64> new_dimensions(shape.dimensions().begin(),
2163                                     shape.dimensions().end());
2164   new_dimensions[split_dimension] /= split_count;
2165   new_dimensions[concat_dimension] *= split_count;
2166   return ShapeUtil::MakeShape(shape.element_type(), new_dimensions);
2167 }
2168 
2169 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
2170     absl::Span<const Shape* const> operand_shapes) {
2171   // An Alltoall HLO instruction receives N operands (with the same shape) and
2172   // returns a tuple that contains N array shapes.
2173   TF_RET_CHECK(!operand_shapes.empty());
2174   for (int i = 0; i < operand_shapes.size(); i++) {
2175     if (!Shape::Equal().IgnoreMemorySpaceInLayout()(*operand_shapes[0],
2176                                                     *operand_shapes[i])) {
2177       return InvalidArgument(
2178           "HLO all-to-all has operands with different shapes: the 0th "
2179           "operand shape %s, but the %dth operand has shape %s.",
2180           ShapeUtil::HumanString(*operand_shapes[0]), i,
2181           ShapeUtil::HumanString(*operand_shapes[i]));
2182     }
2183   }
2184 
2185   return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes);
2186 }
2187 
2188 /* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteShape(
2189     absl::Span<const Shape* const> operand_shapes) {
2190   if (operand_shapes.size() == 1) {
2191     TF_RETURN_IF_ERROR(
2192         ExpectArray(*(operand_shapes[0]), "operand of collective-permute"));
2193     return *(operand_shapes[0]);
2194   } else {
2195     TF_RET_CHECK(operand_shapes.size() == 4);
2196     return *(operand_shapes[1]);
2197   }
2198 }
2199 
2200 /* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteStartShape(
2201     absl::Span<const Shape* const> operand_shapes) {
2202   if (operand_shapes.size() == 1) {
2203     TF_RETURN_IF_ERROR(ExpectArray(*(operand_shapes[0]),
2204                                    "operand of collective-permute-start"));
2205     return ShapeUtil::MakeTupleShape(
2206         {*(operand_shapes[0]), *(operand_shapes[0]),
2207          ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
2208   } else {
2209     TF_RET_CHECK(operand_shapes.size() == 4);
2210     return ShapeUtil::MakeTupleShape(
2211         {*(operand_shapes[0]), *(operand_shapes[1]),
2212          ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {}),
2213          ShapeUtil::MakeShape(S32, {})});
2214   }
2215 }
2216 
2217 /* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteDoneShape(
2218     const Shape& operand_shape) {
2219   TF_RET_CHECK(operand_shape.IsTuple());
2220   if (operand_shape.tuple_shapes_size() == 4) {
2221     return ShapeUtil::GetTupleElementShape(operand_shape, 0);
2222   } else {
2223     return ShapeUtil::GetTupleElementShape(operand_shape, 1);
2224   }
2225 }
2226 
2227 /* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
2228     absl::Span<const Shape* const> arg_shapes,
2229     absl::Span<const int64> dimensions_to_reduce,
2230     const ProgramShape& to_apply) {
2231   if (arg_shapes.empty()) {
2232     return InvalidArgument("Reduce must have at least 2 arguments, has 0");
2233   }
2234   if (arg_shapes.size() % 2) {
2235     return InvalidArgument(
2236         "Reduce must have an even number of arguments, has %lu",
2237         arg_shapes.size());
2238   }
2239   int64_t num_reduced_args = arg_shapes.size() / 2;
2240   auto reduced_args = arg_shapes.subspan(0, num_reduced_args);
2241   // Check that all of the reduced tensors have the same dimensions. The element
2242   // types may be different.
2243   for (int64_t i = 1; i < num_reduced_args; ++i) {
2244     if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) {
2245       return InvalidArgument(
2246           "All reduced tensors must have the same dimension. Tensor 0 has "
2247           "shape %s, Tensor %d has shape %s",
2248           ShapeUtil::HumanString(*reduced_args[0]), i,
2249           ShapeUtil::HumanString(*reduced_args[i]));
2250     }
2251   }
2252   // Check that the dimensions to reduce are in-bounds for the given shape.
2253   // We've already verified all reduced tensors have the same dimensions, so it
2254   // doesn't matter which one we choose.
2255   const Shape& arg = *reduced_args[0];
2256   for (int64_t dimension : dimensions_to_reduce) {
2257     if (dimension >= arg.rank() || dimension < 0) {
2258       return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.",
2259                              dimension, ShapeUtil::HumanString(arg));
2260     }
2261   }
2262 
2263   auto init_values = arg_shapes.subspan(num_reduced_args, arg_shapes.size());
2264   std::vector<PrimitiveType> element_types;
2265   for (const Shape* arg : reduced_args) {
2266     element_types.push_back(arg->element_type());
2267   }
2268   TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types,
2269                                         num_reduced_args));
2270 
2271   absl::flat_hash_set<int64> dimensions_to_reduce_set;
2272   for (int64_t dim_to_reduce : dimensions_to_reduce) {
2273     if (!dimensions_to_reduce_set.insert(dim_to_reduce).second) {
2274       return InvalidArgument("Duplicate reduction dimension: %d",
2275                              dim_to_reduce);
2276     }
2277   }
2278 
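  // For example, reducing f32[10,20,30] over dimensions {0, 2} with a scalar
  // f32 reducer result produces f32[20].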
2279   std::vector<int64> new_dimensions;
2280   std::vector<bool> new_is_dynamic;
2281   for (int i = 0; i < arg.rank(); ++i) {
2282     if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) {
2283       new_dimensions.push_back(arg.dimensions(i));
2284       new_is_dynamic.push_back(arg.is_dynamic_dimension(i));
2285     }
2286   }
2287 
2288   if (ShapeUtil::IsScalar(to_apply.result())) {
2289     return ShapeUtil::MakeShape(to_apply.result().element_type(),
2290                                 new_dimensions, new_is_dynamic);
2291   } else {
2292     std::vector<Shape> result_subshapes;
2293     for (const Shape& subshape : to_apply.result().tuple_shapes()) {
2294       result_subshapes.push_back(ShapeUtil::MakeShape(
2295           subshape.element_type(), new_dimensions, new_is_dynamic));
2296     }
2297     return ShapeUtil::MakeTupleShape(result_subshapes);
2298   }
2299 }
2300 
2301 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2302     const Shape& operand_shape, const Shape& init_value_shape,
2303     const Window& window, const ProgramShape& to_apply_shape) {
2304   TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
2305                                         {operand_shape.element_type()},
2306                                         /*inputs=*/1));
2307   return InferReduceWindowShape(operand_shape, init_value_shape, window);
2308 }
2309 
2310 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2311     absl::Span<const Shape* const> operands,
2312     absl::Span<const Shape* const> init_values, const Window& window,
2313     const ProgramShape& to_apply_shape) {
2314   auto number_of_input = operands.size();
2315   // Check that all of the reduced tensors have the same dimensions. The element
2316   // types may be different.
2317   for (int64_t i = 1; i < number_of_input; ++i) {
2318     if (!ShapeUtil::SameDimensions(*operands[0], *operands[i])) {
2319       return InvalidArgument(
2320           "All reduced tensors must have the same dimension. Tensor 0 has "
2321           "shape %s, Tensor %d has shape %s",
2322           ShapeUtil::HumanString(*operands[0]), i,
2323           ShapeUtil::HumanString(*operands[i]));
2324     }
2325   }
2326   std::vector<PrimitiveType> operand_element_type_vec;
2327   for (const Shape* s : operands) {
2328     operand_element_type_vec.push_back(s->element_type());
2329   }
2330   TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_values,
2331                                         operand_element_type_vec,
2332                                         /*inputs=*/number_of_input));
2333   std::vector<Shape> output_shape_vec;
2334   for (int i = 0; i < operands.size(); ++i) {
2335     TF_ASSIGN_OR_RETURN(
2336         auto cur_output_shape,
2337         InferReduceWindowShape(*operands[i], *init_values[i], window));
2338     output_shape_vec.push_back(cur_output_shape);
2339   }
2340   if (ShapeUtil::IsScalar(to_apply_shape.result())) {
2341     CHECK_EQ(output_shape_vec.size(), 1);
2342     return output_shape_vec[0];
2343   } else {
2344     return ShapeUtil::MakeTupleShape(output_shape_vec);
2345   }
2346 }
2347 
2348 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2349     const Shape& operand_shape, const Shape& init_value_shape,
2350     const Window& window) {
2351   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window"));
2352   return InferWindowOutputShape(operand_shape, window,
2353                                 init_value_shape.element_type(),
2354                                 /*allow_negative_padding=*/false);
2355 }
2356 
2357 /* static */ StatusOr<Shape> ShapeInference::InferSelectAndScatterShape(
2358     const Shape& operand_shape, const ProgramShape& select_shape,
2359     const Window& window, const Shape& source_shape,
2360     const Shape& init_value_shape, const ProgramShape& scatter_shape) {
2361   TF_RETURN_IF_ERROR(
2362       ExpectArray(operand_shape, "operand of select-and-scatter"));
2363 
2364   // Check if the select function has a proper shape of (T,T) -> PRED.
2365   if (select_shape.parameters_size() != 2) {
2366     return InvalidArgument(
2367         "Select function must take 2 parameters, but "
2368         "takes %d parameter(s).",
2369         select_shape.parameters_size());
2370   }
2371   const Shape& select_result_shape = select_shape.result();
2372   if (!ShapeUtil::Compatible(select_result_shape,
2373                              ShapeUtil::MakeShape(PRED, {}))) {
2374     return InvalidArgument("Select function must have rank-0 PRED result.");
2375   }
2376   const Shape& operand_element_shape =
2377       ShapeUtil::MakeShape(operand_shape.element_type(), {});
2378   if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2379                                                 select_shape.parameters(0))) {
2380     return InvalidArgument(
2381         "Select function's first parameter shape currently must "
2382         "match the operand element shape, but got %s vs %s.",
2383         ShapeUtil::HumanString(select_shape.parameters(0)),
2384         ShapeUtil::HumanString(operand_element_shape));
2385   }
2386   if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2387                                                 select_shape.parameters(1))) {
2388     return InvalidArgument(
2389         "Select function's second parameter shape currently must "
2390         "match the operand element shape, but got %s vs %s.",
2391         ShapeUtil::HumanString(select_shape.parameters(1)),
2392         ShapeUtil::HumanString(operand_element_shape));
2393   }
2394 
2395   // Check if the scatter function has a proper shape as a reduction.
2396   TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape},
2397                                         {source_shape.element_type()},
2398                                         /*inputs=*/1));
2399 
2400   // Check if the result shape of window operation matches the source shape.
2401   TF_ASSIGN_OR_RETURN(const Shape& window_result_shape,
2402                       InferWindowOutputShape(operand_shape, window,
2403                                              operand_shape.element_type(),
2404                                              /*allow_negative_padding=*/false));
2405   if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape,
2406                                                 window_result_shape)) {
2407     return InvalidArgument(
2408         "Source shape does not match the shape of window-reduced operand: "
2409         "source(%s), window-reduced operand(%s).",
2410         ShapeUtil::HumanString(source_shape),
2411         ShapeUtil::HumanString(window_result_shape));
2412   }
2413 
2414   return operand_shape;
2415 }
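// Illustrative sketch (hypothetical shapes): the source operand must match the
// window-reduced operand and the result is always the operand shape. For an
// operand f32[8] with a window of size 2 and stride 2, the window-reduced
// shape would be f32[4], so source must be f32[4] and the inferred result is
// f32[8].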
2416 
2417 /* static */ StatusOr<Shape> ShapeInference::InferGetDimensionSizeShape(
2418     const Shape& shape, int64_t dimension) {
2419   if (dimension < 0 || dimension >= shape.rank()) {
2420     return InvalidArgument("GetDimensionSize dimension out of bounds: %d.",
2421                            dimension);
2422   }
2423 
2424   // TODO(b/119580730): Remove this restriction when very large dimension size
2425   // is needed.
2426   if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
2427     return InvalidArgument(
2428         "GetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2429         "INT_MAX limit.",
2430         ShapeUtil::HumanString(shape), dimension);
2431   }
2432 
2433   return ShapeUtil::MakeShape(S32, {});
2434 }
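// Illustrative sketch (hypothetical shapes): the result is always a scalar
// S32, independent of the queried dimension, e.g.
//
//   InferGetDimensionSizeShape(f32[4,<=8], /*dimension=*/1)  =>  s32[]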
2435 
2436 /* static */ StatusOr<Shape> ShapeInference::InferSetDimensionSizeShape(
2437     const Shape& shape, const Shape& val_shape, int64_t dimension) {
2438   if (dimension < 0 || dimension >= shape.rank()) {
2439     return InvalidArgument("SetDimensionSize dimension out of bounds: %d.",
2440                            dimension);
2441   }
2442 
2443   if (val_shape.rank() != 0 || val_shape.element_type() != S32) {
2444     return InvalidArgument(
2445         "SetDimensionSize's value has to be S32 scalar, got %s",
2446         val_shape.ToString());
2447   }
2448   // TODO(b/119580730): Remove this restriction when very large dimension size
2449   // is needed.
2450   if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
2451     return InvalidArgument(
2452         "SetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2453         "INT_MAX limit.",
2454         ShapeUtil::HumanString(shape), dimension);
2455   }
2456 
2457   Shape result = shape;
2458   result.set_dynamic_dimension(dimension, true);
2459   return result;
2460 }
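// Illustrative sketch (hypothetical shapes): the result keeps the operand's
// bounds and marks the chosen dimension dynamic, e.g.
//
//   InferSetDimensionSizeShape(f32[4,8], /*val_shape=*/s32[], /*dimension=*/1)
//     =>  f32[4,<=8]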
2461 
2462 /* static */ StatusOr<Window> ShapeInference::InferWindowFromDimensions(
2463     absl::Span<const int64> window_dimensions,
2464     absl::Span<const int64> window_strides,
2465     absl::Span<const std::pair<int64, int64>> padding,
2466     absl::Span<const int64> lhs_dilation,
2467     absl::Span<const int64> rhs_dilation) {
2468   const auto verify_size = [&](const size_t x, const char* x_name) {
2469     if (x == 0 || x == window_dimensions.size()) {
2470       return Status::OK();
2471     } else {
2472       return InvalidArgument(
2473           "%s", absl::StrCat(
2474                     "Window has different number of window dimensions than of ",
2475                     x_name,
2476                     "\nNumber of window dimensions: ", window_dimensions.size(),
2477                     "\nNumber of ", x_name, ": ", x, "\n"));
2478     }
2479   };
2480   TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides"));
2481   TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries"));
2482   TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors"));
2483   TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors"));
2484 
2485   Window window;
2486   for (size_t i = 0; i < window_dimensions.size(); i++) {
2487     auto dim = window.add_dimensions();
2488     dim->set_size(window_dimensions[i]);
2489     if (!window_strides.empty()) {
2490       dim->set_stride(window_strides[i]);
2491     } else {
2492       dim->set_stride(1);
2493     }
2494     if (!padding.empty()) {
2495       dim->set_padding_low(padding[i].first);
2496       dim->set_padding_high(padding[i].second);
2497     } else {
2498       dim->set_padding_low(0);
2499       dim->set_padding_high(0);
2500     }
2501     if (!lhs_dilation.empty()) {
2502       dim->set_base_dilation(lhs_dilation[i]);
2503     } else {
2504       dim->set_base_dilation(1);
2505     }
2506     if (!rhs_dilation.empty()) {
2507       dim->set_window_dilation(rhs_dilation[i]);
2508     } else {
2509       dim->set_window_dilation(1);
2510     }
2511     dim->set_window_reversal(false);
2512   }
2513   return window;
2514 }
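// Illustrative sketch (hypothetical arguments): empty spans select the
// identity defaults above, e.g.
//
//   InferWindowFromDimensions(/*window_dimensions=*/{3, 3},
//                             /*window_strides=*/{}, /*padding=*/{},
//                             /*lhs_dilation=*/{}, /*rhs_dilation=*/{})
//
// yields a Window whose two dimensions have size 3, stride 1, zero padding,
// and base/window dilation 1.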
2515 
2516 /* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
2517     const Shape& arg, absl::Span<const int64> starts,
2518     absl::Span<const int64> limits, absl::Span<const int64> strides) {
2519   auto error = [&](const string& message) {
2520     return InvalidArgument(
2521         "%s in slice operation; argument shape: %s; starts: {%s}; limits: "
2522         "{%s}; strides: {%s}.",
2523         message, ShapeUtil::HumanString(arg), StrJoin(starts, ","),
2524         StrJoin(limits, ","), StrJoin(strides, ","));
2525   };
2526   TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice"));
2527   VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}",
2528                        ShapeUtil::HumanString(arg), StrJoin(starts, ", "),
2529                        StrJoin(limits, ", "));
2530 
2531   if (starts.size() != limits.size()) {
2532     return error(StrFormat("slice start and limit sizes differ: %u vs %u",
2533                            starts.size(), limits.size()));
2534   }
2535 
2536   if (starts.size() != strides.size()) {
2537     return error(StrFormat("slice start and strides sizes differ: %u vs %u",
2538                            starts.size(), strides.size()));
2539   }
2540 
2541   if (starts.size() != arg.rank()) {
2542     return InvalidArgument(
2543         "Slice index count does not match argument rank: %u vs %d.",
2544         starts.size(), arg.rank());
2545   }
2546 
2547   std::vector<int64> sizes;
2548   for (int64_t dimension = 0; dimension < starts.size(); ++dimension) {
2549     int64_t start_index = starts[dimension];
2550     int64_t limit_index = limits[dimension];
2551     int64_t stride = strides[dimension];
2552     if (start_index < 0) {
2553       return InvalidArgument("Negative start index to slice: %d.", start_index);
2554     }
2555     if (limit_index > arg.dimensions(dimension)) {
2556       return error(
2557           StrFormat("limit index (%d) must be less than or equal to dimension "
2558                     "size (%d)",
2559                     limit_index, arg.dimensions(dimension)));
2560     }
2561     VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index);
2562     VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index);
2563     if (start_index > limit_index) {
2564       return error(
2565           StrFormat("limit index (%d) must be greater or equal to "
2566                     "start index (%d) in slice with positive stride",
2567                     limit_index, start_index));
2568     }
2569     if (stride <= 0) {
2570       return InvalidArgument("Stride (%d) must be positive.", stride);
2571     }
2572     sizes.push_back((limit_index - start_index + stride - 1) / stride);
2573   }
2574 
2575   std::vector<bool> is_dynamic(arg.rank());
2576   for (int64_t i = 0; i < arg.dimensions_size(); ++i) {
2577     // Slicing 1 out of a dynamic dimension eliminates the dynamic dimension.
2578     if (sizes[i] == 1) {
2579       continue;
2580     }
2581     is_dynamic[i] = arg.is_dynamic_dimension(i);
2582   }
2583 
2584   return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic);
2585 }
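// Illustrative sketch (hypothetical shapes): each output dimension is
// ceil((limit - start) / stride), e.g.
//
//   InferSliceShape(f32[10], /*starts=*/{2}, /*limits=*/{9}, /*strides=*/{3})
//     =>  f32[3]   // (9 - 2 + 3 - 1) / 3 == 3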
2586 
2587 /* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
2588     const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
2589     absl::Span<const int64> slice_sizes, bool allow_scalar_indices) {
2590   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
2591   auto number_of_indices = start_index_shapes.size();
2592   // TODO(b/118437727): Remove this path.
2593   if (!allow_scalar_indices ||
2594       (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
2595     if (number_of_indices != 1) {
2596       return InvalidArgument(
2597           "Dynamic slice should have exactly 1 index operand, has %d.",
2598           number_of_indices);
2599     }
2600 
2601     const Shape& start_indices_shape = start_index_shapes[0];
2602     VLOG(2) << StrFormat(
2603         "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
2604         ShapeUtil::HumanString(operand_shape),
2605         ShapeUtil::HumanString(start_indices_shape),
2606         StrJoin(slice_sizes, ", "));
2607 
2608     TF_RETURN_IF_ERROR(
2609         ExpectArray(start_indices_shape, "start indices of dynamic slice"));
2610 
2611     if (start_indices_shape.rank() != 1) {
2612       return InvalidArgument(
2613           "Dynamic slice start indices of rank %d must be rank 1.",
2614           start_indices_shape.rank());
2615     }
2616 
2617     if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2618       return InvalidArgument(
2619           "Dynamic slice start indices must be of integral type.");
2620     }
2621 
2622     const int64_t start_num_dims = start_indices_shape.dimensions(0);
2623     if (operand_shape.rank() != start_num_dims) {
2624       return InvalidArgument(
2625           "Dynamic slice start number of dimensions %d (%s) must match rank "
2626           "%d of slice input (%s).",
2627           start_num_dims, ShapeUtil::HumanString(start_indices_shape),
2628           operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
2629     }
2630   } else {
2631     VLOG(2) << StrFormat("slicing shape %s with slice_sizes={%s}",
2632                          ShapeUtil::HumanString(operand_shape),
2633                          StrJoin(slice_sizes, ", "));
2634 
2635     if (operand_shape.rank() != number_of_indices) {
2636       return InvalidArgument(
2637           "Dynamic slice start number of dimensions %d must match rank "
2638           "%d of slice input (%s).",
2639           number_of_indices, operand_shape.rank(),
2640           ShapeUtil::HumanString(operand_shape));
2641     }
2642 
2643     if (number_of_indices > 0) {
2644       const Shape& first_index_shape = start_index_shapes[0];
2645       if (!ShapeUtil::IsScalar(first_index_shape)) {
2646         return InvalidArgument("Dynamic slice indices must be scalar, not %s.",
2647                                ShapeUtil::HumanString(first_index_shape));
2648       }
2649       if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
2650         return InvalidArgument(
2651             "Dynamic slice start indices must be of integral type.");
2652       }
2653       for (const Shape& index_shape : start_index_shapes) {
2654         if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
2655           return InvalidArgument(
2656               "Dynamic slice start indices must all have the same shape, got "
2657               "mismatching indices with shapes %s and %s.",
2658               ShapeUtil::HumanString(first_index_shape),
2659               ShapeUtil::HumanString(index_shape));
2660         }
2661       }
2662     }
2663   }
2664 
2665   if (slice_sizes.size() != operand_shape.rank()) {
2666     return InvalidArgument(
2667         "Dynamic slice index count does not match argument rank: %u vs %d.",
2668         slice_sizes.size(), operand_shape.rank());
2669   }
2670 
2671   for (int64_t dim = 0; dim < slice_sizes.size(); ++dim) {
2672     const int64_t input_dim_size = operand_shape.dimensions(dim);
2673     const int64_t slice_dim_size = slice_sizes[dim];
2674     if (slice_dim_size < 0) {
2675       return InvalidArgument("Negative size index to dynamic slice: %d.",
2676                              slice_dim_size);
2677     }
2678     if (slice_dim_size > input_dim_size) {
2679       return InvalidArgument(
2680           "Slice dim size %d greater than dynamic slice dimension: %d.",
2681           slice_dim_size, input_dim_size);
2682     }
2683     VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size);
2684   }
2685 
2686   return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
2687 }
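// Illustrative sketch (hypothetical shapes): with scalar indices allowed, the
// result takes the operand element type and the requested slice sizes, e.g.
//
//   InferDynamicSliceShape(f32[10,10],
//                          /*start_index_shapes=*/{s32[], s32[]},
//                          /*slice_sizes=*/{3, 4},
//                          /*allow_scalar_indices=*/true)
//     =>  f32[3,4]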
2688 
2689 /* static */ StatusOr<Shape> ShapeInference::InferDynamicUpdateSliceShape(
2690     const Shape& operand_shape, const Shape& update_shape,
2691     absl::Span<const Shape> start_index_shapes, bool allow_scalar_indices) {
2692   TF_RETURN_IF_ERROR(
2693       ExpectArray(operand_shape, "operand of dynamic update slice"));
2694   TF_RETURN_IF_ERROR(
2695       ExpectArray(update_shape, "update of dynamic update slice"));
2696 
2697   auto number_of_indices = start_index_shapes.size();
2698   // TODO(b/118437727): Remove this path.
2699   if (!allow_scalar_indices ||
2700       (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
2701     if (number_of_indices != 1) {
2702       return InvalidArgument(
2703           "Dynamic update slice should have exactly 1 index operand, has %d.",
2704           number_of_indices);
2705     }
2706     const Shape& start_indices_shape = start_index_shapes[0];
2707     TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape,
2708                                    "start indices of dynamic update slice"));
2709 
2710     VLOG(2) << StrFormat(
2711         "updating slice of shape %s at dynamic start_indices %s with update "
2712         "shape %s",
2713         ShapeUtil::HumanString(operand_shape),
2714         ShapeUtil::HumanString(start_indices_shape),
2715         ShapeUtil::HumanString(update_shape));
2716 
2717     if (start_indices_shape.rank() != 1) {
2718       return InvalidArgument(
2719           "Dynamic update slice start indices of rank %d must be rank 1.",
2720           start_indices_shape.rank());
2721     }
2722 
2723     if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2724       return InvalidArgument(
2725           "Dynamic update slice start indices must be of integral type.");
2726     }
2727 
2728     const int64_t start_num_dims = start_indices_shape.dimensions(0);
2729     if (operand_shape.rank() != start_num_dims) {
2730       return InvalidArgument(
2731           "Dynamic update slice start number of dimensions %d (%s) must match "
2732           "rank %d of slice input (%s).",
2733           start_num_dims, ShapeUtil::HumanString(start_indices_shape),
2734           operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
2735     }
2736   } else {
2737     VLOG(2) << StrFormat("updating slice of shape %s with update shape %s",
2738                          ShapeUtil::HumanString(operand_shape),
2739                          ShapeUtil::HumanString(update_shape));
2740 
2741     if (operand_shape.rank() != number_of_indices) {
2742       return InvalidArgument(
2743           "Dynamic update slice start number of dimensions %d must match "
2744           "rank %d of slice input (%s).",
2745           number_of_indices, operand_shape.rank(),
2746           ShapeUtil::HumanString(operand_shape));
2747     }
2748 
2749     if (number_of_indices > 0) {
2750       const Shape& first_index_shape = start_index_shapes[0];
2751       if (!ShapeUtil::IsScalar(first_index_shape)) {
2752         return InvalidArgument(
2753             "Dynamic update slice indices must be scalar, not %s.",
2754             ShapeUtil::HumanString(first_index_shape));
2755       }
2756       if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
2757         return InvalidArgument(
2758             "Dynamic update slice start indices must be of integral type.");
2759       }
2760       for (const Shape& index_shape : start_index_shapes) {
2761         if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
2762           return InvalidArgument(
2763               "Dynamic update slice start indices must all have the same "
2764               "shape, got mismatching indices with shapes %s and %s.",
2765               ShapeUtil::HumanString(first_index_shape),
2766               ShapeUtil::HumanString(index_shape));
2767         }
2768       }
2769     }
2770   }
2771 
2772   if (update_shape.rank() != operand_shape.rank()) {
2773     return InvalidArgument(
2774         "Dynamic update slice update rank does not match argument rank: "
2775         "%d vs %d.",
2776         update_shape.rank(), operand_shape.rank());
2777   }
2778 
2779   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
2780                                                      update_shape)) {
2781     return InvalidArgument(
2782         "Dynamic update slice update element type does not match argument. "
2783         "operand.element_type: %s vs update.element_type: %s.",
2784         PrimitiveType_Name(operand_shape.element_type()),
2785         PrimitiveType_Name(update_shape.element_type()));
2786   }
2787 
2788   for (int64_t dim = 0; dim < operand_shape.rank(); ++dim) {
2789     const int64_t input_dim_size = operand_shape.dimensions(dim);
2790     const int64_t update_dim_size = update_shape.dimensions(dim);
2791     if (update_dim_size < 0) {
2792       return InvalidArgument(
2793           "Size index %d to dynamic update slice must be >= 0.",
2794           update_dim_size);
2795     }
2796     if (update_dim_size > input_dim_size) {
2797       return InvalidArgument(
2798           "Update dim size %d greater than dynamic slice dimension: %d.",
2799           update_dim_size, input_dim_size);
2800     }
2801     VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size);
2802   }
2803 
2804   auto result_shape = operand_shape;
2805 
2806   // If any of the operand shape is dynamic, the result dimension is also
2807   // dynamic.
2808   // If update shape is dynamic, only propagate dynamic dimension to result if
2809   // the update is a full update (update_shape[i] == operand_shape[i]).
2810   for (int64_t i = 0; i < update_shape.rank(); ++i) {
2811     if (operand_shape.is_dynamic_dimension(i)) {
2812       result_shape.set_dynamic_dimension(i, true);
2813     }
2814 
2815     if (update_shape.is_dynamic_dimension(i) &&
2816         update_shape.dimensions(i) == operand_shape.dimensions(i)) {
2817       // When update/replace a full dimension, propagate dynamic dimension to
2818       // the result.
2819       result_shape.set_dynamic_dimension(i, true);
2820     }
2821   }
2822 
2823   return result_shape;
2824 }
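// Illustrative sketch (hypothetical shapes): the result always has the
// operand's shape (modulo dynamic-dimension propagation); updating f32[10,10]
// with an f32[3,4] update at scalar S32 start indices infers f32[10,10].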
2825 
2826 /*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
2827     const Shape& operand_shape, absl::Span<const int64> dimensions) {
2828   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
2829   if (!AllUnique(dimensions)) {
2830     return InvalidArgument("A dimension number is duplicated in reverse.");
2831   }
2832   for (int64_t dimension : dimensions) {
2833     if (dimension >= operand_shape.rank() || dimension < 0) {
2834       return InvalidArgument(
2835           "One of the reverse dimensions (%d) is out-of-bounds in shape %s.",
2836           dimension, ShapeUtil::HumanString(operand_shape));
2837     }
2838   }
2839   return operand_shape;
2840 }
2841 
2842 /* static */ StatusOr<Shape> ShapeInference::InferGetTupleElementShape(
2843     const Shape& arg, int64_t index) {
2844   if (!arg.IsTuple()) {
2845     return InvalidArgument(
2846         "Cannot infer shape: attempting to index into non-tuple: %s.",
2847         ShapeUtil::HumanString(arg));
2848   }
2849 
2850   if (index < 0 || index >= arg.tuple_shapes_size()) {
2851     return InvalidArgument(
2852         "Cannot infer shape: attempt to index out of tuple bounds: %d "
2853         ">= %d in shape %s.",
2854         index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg));
2855   }
2856 
2857   return arg.tuple_shapes(index);
2858 }
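// Illustrative sketch (hypothetical shapes):
//
//   InferGetTupleElementShape((f32[2], s32[3]), /*index=*/1)  =>  s32[3]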
2859 
2860 /* static */ StatusOr<Shape> ShapeInference::InferWhileShape(
2861     const ProgramShape& condition, const ProgramShape& body,
2862     const Shape& init) {
2863   // Check the number of parameters for given computations.
2864   if (condition.parameters_size() != 1) {
2865     return InvalidArgument("Condition must take 1 argument; got %d.",
2866                            condition.parameters_size());
2867   }
2868   if (body.parameters_size() != 1) {
2869     return InvalidArgument("Body must take 1 argument; got %d.",
2870                            body.parameters_size());
2871   }
2872 
2873   auto shape_string = [&]() {
2874     return StrFormat(
2875         "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition),
2876         ShapeUtil::HumanString(body), ShapeUtil::HumanString(init));
2877   };
2878 
2879   // Check the shapes of computation parameters and return types.
2880   if (!ShapeUtil::Compatible(condition.result(),
2881                              ShapeUtil::MakeShape(PRED, {}))) {
2882     return InvalidArgument("Condition must return a boolean; got %s.",
2883                            shape_string());
2884   }
2885   if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) ||
2886       !ShapeUtil::Compatible(body.result(), body.parameters(0)) ||
2887       !ShapeUtil::Compatible(body.result(), init)) {
2888     return InvalidArgument(
2889         "The parameter of condition and body, the result of the body, and init "
2890         "must all have the same shape; got %s.",
2891         shape_string());
2892   }
2893 
2894   return init;
2895 }
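// Illustrative sketch (hypothetical shapes): for loop state of shape
// T = (f32[2], s32[]), the condition must be (T) -> pred[], the body
// (T) -> T, and the inferred result is T, the shape of init.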
2896 
2897 /* static */ StatusOr<Shape> ShapeInference::InferConditionalShape(
2898     const Shape& branch_index,
2899     absl::Span<const ProgramShape> branch_computations,
2900     absl::Span<const Shape> branch_operands) {
2901   if (!ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(PRED, {})) &&
2902       !ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(S32, {}))) {
2903     return InvalidArgument("branch_index must be bool or int32; got %s.",
2904                            ShapeUtil::HumanString(branch_index));
2905   }
2906   if (branch_index.element_type() == PRED) {
2907     TF_RET_CHECK(2 == branch_computations.size());
2908   } else {
2909     TF_RET_CHECK(!branch_computations.empty());
2910   }
2911   TF_RET_CHECK(branch_computations.size() == branch_operands.size());
2912   Shape result = branch_computations[0].result();
2913   for (int j = 0; j < branch_computations.size(); ++j) {
2914     if (branch_computations[j].parameters_size() != 1) {
2915       return InvalidArgument(
2916           "branch computation %d must take 1 argument; got %d.", j,
2917           branch_computations[j].parameters_size());
2918     }
2919     if (!ShapeUtil::Compatible(branch_computations[j].parameters(0),
2920                                branch_operands[j])) {
2921       auto shape_string = [&]() {
2922         return StrFormat("operand: %s; computation: %s",
2923                          ShapeUtil::HumanString(branch_operands[j]),
2924                          ShapeUtil::HumanString(branch_computations[j]));
2925       };
2926       return InvalidArgument(
2927           "branch operand %d must match the shape of the only parameter of "
2928           "branch computation %d: got %s.",
2929           j, j, shape_string());
2930     }
2931 
2932     if (!ShapeUtil::Compatible(branch_computations[0].result(),
2933                                branch_computations[j].result())) {
2934       auto shape_string = [&]() {
2935         return StrFormat(
2936             "branch 0 computation result: %s; branch %d computation result: %s",
2937             ShapeUtil::HumanString(branch_computations[0].result()), j,
2938             ShapeUtil::HumanString(branch_computations[j].result()));
2939       };
2940       return InvalidArgument(
2941           "the result of branch 0 computation and branch %d computation must "
2942           "have the same shape: got %s.",
2943           j, shape_string());
2944     }
2945   }
2946   // For each subshape, if any of the branches is dynamic, the result is
2947   // dynamic:
2948   //
2949   //   true_branch  (s32[<=4])
2950   //   false_branch (s32[4])
2951   //
2952   // Result is s32[<=4].
2953   ShapeUtil::ForEachMutableSubshape(
2954       &result, [&](Shape* subshape, const ShapeIndex& index) {
2955         if (!subshape->IsArray()) {
2956           return;
2957         }
2958         for (int j = 0; j < branch_computations.size(); ++j) {
2959           auto branch_subshape =
2960               ShapeUtil::GetSubshape(branch_computations[j].result(), index);
2961           for (int64 i = 0; i < branch_subshape.rank(); ++i) {
2962             if (branch_subshape.is_dynamic_dimension(i)) {
2963               subshape->set_dynamic_dimension(i, true);
2964             }
2965           }
2966         }
2967       });
2968 
2969   return result;
2970 }
2971 
2972 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2973     const Shape& operand, absl::Span<const int64> broadcast_sizes) {
2974   TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
2975   for (int64_t size : broadcast_sizes) {
2976     if (size < 0) {
2977       return InvalidArgument("Broadcast with negative dimension size %d.",
2978                              size);
2979     }
2980   }
2981 
2982   std::vector<int64> dimensions(operand.dimensions_size() +
2983                                 broadcast_sizes.size());
2984   std::copy(broadcast_sizes.begin(), broadcast_sizes.end(), dimensions.begin());
2985   std::copy(operand.dimensions().begin(), operand.dimensions().end(),
2986             dimensions.begin() + broadcast_sizes.size());
2987 
2988   Shape result = ShapeUtil::MakeShape(operand.element_type(), dimensions);
2989   for (int64_t i = 0; i < operand.dimensions_size(); ++i) {
2990     result.set_dynamic_dimension(broadcast_sizes.size() + i,
2991                                  operand.is_dynamic_dimension(i));
2992   }
2993   return result;
2994 }
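// Illustrative sketch (hypothetical shapes): the broadcast sizes become the
// most-major dimensions, e.g.
//
//   InferBroadcastShape(f32[2,3], /*broadcast_sizes=*/{4, 5})  =>  f32[4,5,2,3]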
2995 
2996 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2997     const Shape& operand_shape, const Shape& output_shape,
2998     absl::Span<const int64> broadcast_dimensions) {
2999   TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast"));
3000   TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast"));
3001   const int64_t operand_rank = operand_shape.rank();
3002   const int64_t output_rank = output_shape.rank();
3003   if (operand_rank > output_rank) {
3004     return InvalidArgument(
3005         "InDim style broadcast must be to an equal or higher ranked shape; "
3006         "operand rank: %lld; output rank: %lld",
3007         operand_rank, output_rank);
3008   }
3009   if (operand_rank != broadcast_dimensions.size()) {
3010     return InvalidArgument(
3011         "Size of broadcast_dimensions has to match operand's rank; operand "
3012         "rank: %lld, size of broadcast_dimensions %u.",
3013         operand_rank, broadcast_dimensions.size());
3014   }
3015   for (int64_t i = 0; i < operand_rank; i++) {
3016     if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] >= output_rank) {
3017       return InvalidArgument("Broadcast dimension %lld is out of bound",
3018                              broadcast_dimensions[i]);
3019     }
3020     if (operand_shape.dimensions(i) !=
3021             output_shape.dimensions(broadcast_dimensions[i]) &&
3022         operand_shape.dimensions(i) != 1) {
3023       return InvalidArgument(
3024           "Input dimension should be either 1 or equal to the output dimension "
3025           "it is broadcasting into; the %lldth operand dimension is %lld, the "
3026           "%lldth output dimension is %lld.",
3027           i, operand_shape.dimensions(i), broadcast_dimensions[i],
3028           output_shape.dimensions(broadcast_dimensions[i]));
3029     }
3030     if (operand_shape.is_dynamic_dimension(i) !=
3031         output_shape.is_dynamic_dimension(broadcast_dimensions[i])) {
3032       return InvalidArgument(
3033           "Broadcast input and output dynamism mismatch: %s and %s",
3034           operand_shape.ToString(), output_shape.ToString());
3035     }
3036     // Make sure the broadcast dimensions are listed in a strictly increasing
3037     // order.
3038     if (i > 0 && broadcast_dimensions[i - 1] >= broadcast_dimensions[i]) {
3039       return InvalidArgument(
3040           "Broadcast dimensions order is wrong: %d comes after %d.",
3041           broadcast_dimensions[i], broadcast_dimensions.at(i - 1));
3042     }
3043   }
3044 
3045   return output_shape;
3046 }
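// Illustrative sketch (hypothetical shapes): in the InDim style,
// broadcast_dimensions maps each operand dimension to an output dimension,
// e.g.
//
//   InferBroadcastShape(/*operand_shape=*/f32[3], /*output_shape=*/f32[2,3,4],
//                       /*broadcast_dimensions=*/{1})
//     =>  f32[2,3,4]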
3047 
3048 /* static */ StatusOr<Shape> ShapeInference::InferDynamicReshapeShape(
3049     const Shape& operand, absl::Span<const Shape* const> dim_size_shapes,
3050     absl::Span<const int64> new_size_bounds,
3051     const std::vector<bool>& dims_are_dynamic) {
3052   if (new_size_bounds.size() != dims_are_dynamic.size()) {
3053     return InvalidArgument(
3054         "DynamicReshape has to have the same number of elements in new_sizes "
3055         "(%d) and dims_are_dynamic (%d)",
3056         new_size_bounds.size(), dims_are_dynamic.size());
3057   }
3058 
3059   for (const Shape* dim_size_shape : dim_size_shapes) {
3060     if (dim_size_shape->element_type() != S32 || dim_size_shape->rank() != 0) {
3061       return InvalidArgument(
3062           "DynamicReshape's dim size has to be a scalar S32, got (%s).",
3063           dim_size_shape->ToString());
3064     }
3065   }
3066 
3067   Shape inferred_shape = ShapeUtil::MakeShape(
3068       operand.element_type(), new_size_bounds, dims_are_dynamic);
3069   if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
3070     return InvalidArgument(
3071         "Reshape operation has mismatched element counts: from=%d (%s) "
3072         "to=%d (%s).",
3073         ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
3074         ShapeUtil::ElementsIn(inferred_shape),
3075         ShapeUtil::HumanString(inferred_shape));
3076   }
3077   return inferred_shape;
3078 }
3079 
3080 /* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
3081     const Shape& operand, absl::Span<const int64> dimensions,
3082     absl::Span<const int64> new_sizes, int64_t inferred_dimension) {
3083   TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));
3084 
3085   Shape inferred_shape =
3086       ShapeUtil::MakeShape(operand.element_type(), new_sizes);
3087   VLOG(3) << "Reshape inferred shape: "
3088           << ShapeUtil::HumanString(inferred_shape);
3089 
3090   if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
3091     return InvalidArgument(
3092         "Reshape operation has mismatched element counts: from=%d (%s) "
3093         "to=%d (%s).",
3094         ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
3095         ShapeUtil::ElementsIn(inferred_shape),
3096         ShapeUtil::HumanString(inferred_shape));
3097   }
3098 
3099   std::vector<int64> indices(operand.rank());
3100   std::iota(indices.begin(), indices.end(), 0);
3101   if (dimensions.size() != operand.rank() ||
3102       !std::is_permutation(dimensions.begin(), dimensions.end(),
3103                            indices.begin())) {
3104     return InvalidArgument(
3105         "Reshape dimensions [%s] are not a permutation of the operand "
3106         "dimensions (operand shape is %s).",
3107         StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
3108   }
3109 
3110   // Propagate dynamic dimension.
3111   auto common_factors = CommonFactors(operand.dimensions(), new_sizes);
3112   for (int64_t input_dim = 0; input_dim < operand.rank(); ++input_dim) {
3113     if (!operand.is_dynamic_dimension(input_dim)) {
3114       continue;
3115     }
3116 
3117     string reshape_debug_str = absl::StrFormat(
3118         "output: %s, input: %s, input_dim: "
3119         "%lld",
3120         ShapeUtil::HumanString(inferred_shape), ShapeUtil::HumanString(operand),
3121         input_dim);
3122 
3123     int64_t input_dim_start = -1;
3124     int64_t input_dim_end = -1;
3125     int64_t output_dim_start = -1;
3126     int64_t output_dim_end = -1;
3127     // Find common_factors that the input_dim belongs to.
3128     for (int64_t i = 0; i < common_factors.size() - 1; ++i) {
3129       auto start = common_factors[i];
3130       auto end = common_factors[i + 1];
3131       if (input_dim >= start.first && input_dim < end.first) {
3132         input_dim_start = start.first;
3133         input_dim_end = end.first;
3134         output_dim_start = start.second;
3135         output_dim_end = end.second;
3136         break;
3137       }
3138     }
3139     if ((input_dim_end - input_dim_start) > 1 &&
3140         (output_dim_end - output_dim_start) > 1) {
3141       // We don't support the case when a dynamic dimension is both combined
3142       // with and split into other dimensions:
3143       //
3144       //  [x, yz]
3145       //     | Reshape
3146       //  [xy, z]
3147       return Unimplemented(
3148           "Dynamic input dimension to reshape that is both split and "
3149           "combined is not supported: %s",
3150           reshape_debug_str);
3151     }
3152 
3153     for (auto common_factor : common_factors) {
3154       //
3155       // For reshapes like:
3156       //  [<=5]
3157       //    | Reshape
3158       //  [1, 5]
3159       //
3160       //  The input dynamic dimension can go into either output dimension.
3161       //  However, the return value of common factors only returns
3162       //  input: 5
3163       //  output: 5
3164       //
3165       //  We need to expand common factor to include degenerate output
3166       //  dimensions:
3167       //  input: 5
3168       //  output: 1, 5
3169       //
3170       //  so that the logic below can consider both dimensions as
3171       //  candidates.
3172       if (common_factor.first == input_dim_start) {
3173         output_dim_start = std::min(output_dim_start, common_factor.second);
3174       }
3175       if (common_factor.first == input_dim_end) {
3176         output_dim_end = std::max(output_dim_end, common_factor.second);
3177       }
3178     }
3179 
3180     // Calculate output dynamic reshape dimension.
3181     int64_t output_dynamic_dimension = -1;
3182 
3183     if (operand.dimensions(input_dim) == 1 && !new_sizes.empty()) {
3184       // If the dynamic dimension has size 1, it can only be the most-major
3185       // or most-minor dimension.
3186       if (input_dim == 0) {
3187         output_dynamic_dimension = 0;
3188       }
3189       if (input_dim == operand.rank() - 1) {
3190         output_dynamic_dimension = new_sizes.size() - 1;
3191       }
3192 
3193       if (output_dynamic_dimension == -1) {
3194         return Unimplemented(
3195             "Dynamic degenerate dimension that is neither most-minor nor "
3196             "most-major is not supported: %s",
3197             reshape_debug_str);
3198       }
3199     }
3200 
3201     if (output_dynamic_dimension == -1 &&
3202         output_dim_end - output_dim_start == 1) {
3203       // Only one possible output dimension.
3204       output_dynamic_dimension = output_dim_start;
3205     }
3206     if (output_dynamic_dimension == -1 &&
3207         output_dim_end - output_dim_start > 1) {
3208       // Multiple outputs can be dynamic, use "inferred_dimension" to tie-break.
3209       output_dynamic_dimension = inferred_dimension;
3210     }
3211 
3212     if (output_dynamic_dimension != -1) {
3213       // TODO(yunxing): Turn this into a CHECK.
3214       inferred_shape.set_dynamic_dimension(output_dynamic_dimension, true);
3215     } else {
3216       std::vector<int64> output_non_degenerated;
3217       for (int64_t i = output_dim_start; i < output_dim_end; ++i) {
3218         if (new_sizes[i] != 1) {
3219           output_non_degenerated.push_back(i);
3220         }
3221       }
3222       if (output_non_degenerated.size() == 1) {
3223         inferred_shape.set_dynamic_dimension(output_non_degenerated[0], true);
3224       }
3225     }
3226   }
3227 
3228   return inferred_shape;
3229 }
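// Illustrative sketch (hypothetical shapes): element counts must match and
// `dimensions` must be a permutation of the operand dimensions, e.g.
//
//   InferReshapeShape(f32[2,6], /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 4},
//                     /*inferred_dimension=*/-1)
//     =>  f32[3,4]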
3230 
3231 /* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
3232     const Shape& operand, absl::Span<const int64> dimensions) {
3233   TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
3234 
3235   if (dimensions.size() != operand.rank() || !IsPermutation(dimensions)) {
3236     return InvalidArgument(
3237         "Transpose dimensions [%s] are not a permutation of the operand "
3238         "dimensions (operand shape is %s).",
3239         StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
3240   }
3241 
3242   // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
3243   // we need output[i]=input[dimensions[i]] which is
3244   // Permute(Inverse(dimensions),input).
3245   return ShapeUtil::PermuteDimensions(dimensions, operand);
3246 }
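// Illustrative sketch (hypothetical shapes): per the comment above,
// output[i] = input[dimensions[i]], e.g.
//
//   InferTransposeShape(f32[2,3,4], /*dimensions=*/{2, 0, 1})  =>  f32[4,2,3]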
3247 
3248 /* static */ StatusOr<Shape> ShapeInference::InferClampShape(
3249     const Shape& min, const Shape& operand, const Shape& max) {
3250   TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min"));
3251   TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand"));
3252   TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max"));
3253 
3254   if (!ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
3255       !ShapeUtil::CompatibleIgnoringFpPrecision(max, operand)) {
3256     return InvalidArgument(
3257         "Clamp with different shapes: %s, %s, %s.", ShapeUtil::HumanString(min),
3258         ShapeUtil::HumanString(operand), ShapeUtil::HumanString(max));
3259   }
3260   return operand;
3261 }
3262 
3263 /* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
3264     const Shape& pred, const Shape& on_true, const Shape& on_false) {
3265   TF_RETURN_IF_ERROR(ExpectArray(pred, "select pred"));
3266   TF_RETURN_IF_ERROR(ExpectArray(on_true, "select on-true"));
3267   TF_RETURN_IF_ERROR(ExpectArray(on_false, "select on-false"));
3268 
3269   if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) {
3270     return InvalidArgument(
3271         "Operands to select must be the same shape; got %s and %s.",
3272         ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
3273   }
3274   if (pred.element_type() != PRED) {
3275     return InvalidArgument(
3276         "Select's pred operand must have PRED element type; got %s.",
3277         ShapeUtil::HumanString(pred));
3278   }
3279   if (!Shape::Equal()
3280            .IgnoreElementType()
3281            .IgnoreLayout()
3282            .IgnoreDynamicDimension()(pred, on_true)) {
3283     return InvalidArgument(
3284         "Operands to select and predicate must be the same shape; got %s and "
3285         "%s.",
3286         ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(pred));
3287   }
3288 
3289   return ShapeUtil::ChangeElementType(
3290       pred, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
3291 }
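// Illustrative sketch (hypothetical shapes): predicate and operands must agree
// in shape, e.g.
//
//   InferSelectShape(pred[2,3], f32[2,3], f32[2,3])  =>  f32[2,3]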
3292 
3293 /* static */ StatusOr<Shape> ShapeInference::InferTupleSelectShape(
3294     const Shape& pred, const Shape& on_true, const Shape& on_false) {
3295   // Select only defines the top-level buffer, so if it's a tuple, the two
3296   // inputs must match exactly.
3297   if (!ShapeUtil::Compatible(on_true, on_false)) {
3298     return InvalidArgument(
3299         "Operands to tuple-select must be the same shape; got %s and %s.",
3300         ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
3301   }
3302   if (pred.element_type() != PRED) {
3303     return InvalidArgument(
3304         "TupleSelect's pred operand must have PRED element type; got %s.",
3305         ShapeUtil::HumanString(pred));
3306   }
3307   if (!ShapeUtil::IsScalar(pred)) {
3308     return InvalidArgument(
3309         "TupleSelect operation with non-scalar predicate: %s.",
3310         ShapeUtil::HumanString(pred));
3311   }
3312   return on_true;
3313 }
3314 
3315 /* static */ StatusOr<Shape> ShapeInference::InferCallShape(
3316     absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply) {
3317   // The applied function's arity equals the number of arguments.
3318   if (arg_shapes.size() != to_apply.parameters_size()) {
3319     string computation_signature = ShapeUtil::HumanString(to_apply);
3320     string argument_shapes =
3321         StrJoin(arg_shapes, ", ", [](string* out, const Shape* shape) {
3322           absl::StrAppend(out, ShapeUtil::HumanString(*shape));
3323         });
3324     return InvalidArgument(
3325         "Call applied function arity must match number of arguments; got: "
3326         "arity: %d, arguments: %u; computation signature: %s; argument "
3327         "shapes: [%s].",
3328         to_apply.parameters_size(), arg_shapes.size(), computation_signature,
3329         argument_shapes);
3330   }
3331 
3332   // All arguments must be compatible with the program shape.
3333   for (int i = 0; i < arg_shapes.size(); ++i) {
3334     const Shape& arg_shape = *arg_shapes[i];
3335     const Shape& param_shape = to_apply.parameters(i);
3336     if (!ShapeUtil::Compatible(arg_shape, param_shape)) {
3337       return InvalidArgument(
3338           "Call parameter must match argument; got parameter %d shape: %s, "
3339           "argument shape: %s.",
3340           i, ShapeUtil::HumanString(param_shape),
3341           ShapeUtil::HumanString(arg_shape));
3342     }
3343   }
3344 
3345   return to_apply.result();
3346 }
3347 
3348 static Status ValidateGatherDimensionNumbers(
3349     const Shape& input_shape, absl::Span<const int64> start_indices_shape,
3350     const GatherDimensionNumbers& dim_numbers) {
3351   if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
3352     return InvalidArgument(
3353         "Output window dimensions in gather op must be ascending; got: %s.",
3354         StrJoin(dim_numbers.offset_dims(), ", "));
3355   }
3356 
3357   if (absl::c_adjacent_find(dim_numbers.offset_dims()) !=
3358       dim_numbers.offset_dims().end()) {
3359     return InvalidArgument(
3360         "Output window dimensions in gather op must not repeat; got: %s.",
3361         StrJoin(dim_numbers.offset_dims(), ", "));
3362   }
3363 
3364   const int64_t output_offset_dim_count = dim_numbers.offset_dims_size();
3365   const int64_t output_shape_rank =
3366       output_offset_dim_count + start_indices_shape.size() - 1;
3367 
3368   for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) {
3369     int64_t offset_dim = dim_numbers.offset_dims(i);
3370     if (offset_dim < 0 || offset_dim >= output_shape_rank) {
3371       return InvalidArgument(
3372           "Offset dimension %d in gather op is out of bounds; got %d, but "
3373           "should "
3374           "have been in [0,%d).",
3375           i, offset_dim, output_shape_rank);
3376     }
3377   }
3378 
3379   if (dim_numbers.start_index_map_size() !=
3380       start_indices_shape[dim_numbers.index_vector_dim()]) {
3381     return InvalidArgument(
3382         "Gather op has %d elements in start_index_map and the "
3383         "bound of dimension index_vector_dim=%d of start_indices is "
3384         "%d. These two numbers must be equal.",
3385         dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(),
3386         start_indices_shape[dim_numbers.index_vector_dim()]);
3387   }
3388 
3389   for (int i = 0; i < dim_numbers.start_index_map_size(); i++) {
3390     int64_t operand_dim_for_start_index_i = dim_numbers.start_index_map(i);
3391     if (operand_dim_for_start_index_i < 0 ||
3392         operand_dim_for_start_index_i >= input_shape.dimensions_size()) {
3393       return InvalidArgument(
3394           "Invalid start_index_map; domain is [0, %d), got: %d->%d.",
3395           input_shape.dimensions_size(), i, operand_dim_for_start_index_i);
3396     }
3397   }
3398 
3399   std::vector<int64> sorted_start_index_map(
3400       dim_numbers.start_index_map().begin(),
3401       dim_numbers.start_index_map().end());
3402 
3403   absl::c_sort(sorted_start_index_map);
3404 
3405   if (absl::c_adjacent_find(sorted_start_index_map) !=
3406       sorted_start_index_map.end()) {
3407     return InvalidArgument(
3408         "Repeated dimensions are not allowed in start_index_map; "
3409         "got: %s.",
3410         StrJoin(dim_numbers.start_index_map(), ", "));
3411   }
3412 
3413   for (int64_t collapsed_dim : dim_numbers.collapsed_slice_dims()) {
3414     if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
3415       return InvalidArgument(
3416           "Invalid collapsed_slice_dims set in gather op; valid range is [0, "
3417           "%d), got: %d.",
3418           input_shape.dimensions_size(), collapsed_dim);
3419     }
3420   }
3421 
3422   if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) {
3423     return InvalidArgument(
3424         "collapsed_slice_dims in gather op must be sorted; got: %s",
3425         StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
3426   }
3427 
3428   if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
3429       dim_numbers.collapsed_slice_dims().end()) {
3430     return InvalidArgument(
3431         "Repeated dimensions not allowed in collapsed_slice_dims in gather op; "
3432         "got: %s.",
3433         StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
3434   }
3435 
3436   return Status::OK();
3437 }
3438 
3439 /*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
3440     const Shape& input_shape, const Shape& start_indices_shape,
3441     const GatherDimensionNumbers& gather_dim_numbers,
3442     absl::Span<const int64> slice_sizes) {
3443   TF_RETURN_IF_ERROR(
3444       ExpectArray(input_shape, "input tensor operand of gather op"));
3445   TF_RETURN_IF_ERROR(
3446       ExpectArray(start_indices_shape, "gather indices operand of gather op"));
3447 
3448   if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
3449     return InvalidArgument(
3450         "Gather indices parameter must be an integral tensor; got %s.",
3451         ShapeUtil::HumanString(start_indices_shape));
3452   }
3453 
3454   // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
3455   // index_vector_dim is rank(P).  The bounds of this expanded shape are
3456   // stored in expanded_start_indices_shape.
3457 
3458   if (start_indices_shape.dimensions_size() <
3459           gather_dim_numbers.index_vector_dim() ||
3460       gather_dim_numbers.index_vector_dim() < 0) {
3461     return InvalidArgument(
3462         "Gather index leaf dimension must be within [0, rank(start_indices) + "
3463         "1). rank(start_indices) is %d and gather index leaf dimension is "
3464         "%d.",
3465         start_indices_shape.dimensions_size(),
3466         gather_dim_numbers.index_vector_dim());
3467   }
3468 
3469   std::vector<int64> expanded_start_indices_shape;
3470   // Also tracks if an output dimension is dynamic.
3471   std::vector<bool> expanded_start_indices_shape_dynamic_dimensions;
3472   expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size());
3473   expanded_start_indices_shape_dynamic_dimensions.reserve(
3474       start_indices_shape.dimensions_size());
3475   absl::c_copy(start_indices_shape.dimensions(),
3476                std::back_inserter(expanded_start_indices_shape));
3477   absl::c_copy(
3478       start_indices_shape.dynamic_dimensions(),
3479       std::back_inserter(expanded_start_indices_shape_dynamic_dimensions));
3480   if (expanded_start_indices_shape.size() ==
3481       gather_dim_numbers.index_vector_dim()) {
3482     expanded_start_indices_shape.push_back(1);
3483     expanded_start_indices_shape_dynamic_dimensions.push_back(false);
3484   }
3485 
3486   TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
3487       input_shape, expanded_start_indices_shape, gather_dim_numbers));
3488 
3489   if (slice_sizes.size() != input_shape.dimensions_size()) {
3490     return InvalidArgument(
3491         "Gather op must have one slice size for every input dimension; got: "
3492         "len(slice_sizes)=%lu, input_shape.rank=%d.",
3493         slice_sizes.size(), input_shape.dimensions_size());
3494   }
3495 
3496   if (slice_sizes.size() !=
3497       gather_dim_numbers.offset_dims_size() +
3498           gather_dim_numbers.collapsed_slice_dims_size()) {
3499     return InvalidArgument(
3500         "All components of the offset index in a gather op must either be an "
3501         "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
3502         "output_slice_sizes=%s, collapsed_slice_dims=%s.",
3503         slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","),
3504         StrJoin(gather_dim_numbers.collapsed_slice_dims(), ","));
3505   }
3506 
3507   for (int i = 0; i < slice_sizes.size(); i++) {
3508     int64_t slice_size = slice_sizes[i];
3509     int64_t corresponding_input_size = input_shape.dimensions(i);
3510     if (slice_size < 0 || slice_size > corresponding_input_size) {
3511       return InvalidArgument(
3512           "Slice size at index %d in gather op is out of range, must be "
3513           "within [0, %d), got %d.",
3514           i, corresponding_input_size + 1, slice_size);
3515     }
3516   }
3517 
3518   for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) {
3519     if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] > 1) {
3520       return InvalidArgument(
3521           "Gather op can only collapse slice dims with bound 1 or 0, but bound "
3522           "is %d for index %d at position %d.",
3523           slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)],
3524           gather_dim_numbers.collapsed_slice_dims(i), i);
3525     }
3526   }
3527 
3528   int64_t result_rank = gather_dim_numbers.offset_dims_size() +
3529                         (expanded_start_indices_shape.size() - 1);
3530   int64_t offset_dims_seen = 0;
3531   int64_t gather_dims_seen = 0;
3532   std::vector<int64> output_dim_bounds;
3533   output_dim_bounds.reserve(result_rank);
3534 
3535   std::vector<bool> output_dim_is_dynamic;
3536   output_dim_is_dynamic.reserve(result_rank);
3537   for (int64_t i = 0; i < result_rank; i++) {
3538     int64_t current_bound;
3539     bool dim_dynamic = false;
3540     bool is_window_index =
3541         absl::c_binary_search(gather_dim_numbers.offset_dims(), i);
3542     if (is_window_index) {
3543       while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(),
3544                                    offset_dims_seen)) {
3545         offset_dims_seen++;
3546       }
3547       // Gathering an entire dynamic dimension creates dynamic dimension.
3548       //
3549       // e.g.,:
3550       //
3551       // gather(input: [1,<=2,1], slice_sizes={1,2,1})
3552       //
3553       // creates
3554       //
3555       // [<=2, 1]
3556       if (slice_sizes[offset_dims_seen] ==
3557           input_shape.dimensions(offset_dims_seen)) {
3558         dim_dynamic = input_shape.is_dynamic_dimension(offset_dims_seen);
3559       }
3560       current_bound = slice_sizes[offset_dims_seen++];
3561     } else {
3562       if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) {
3563         gather_dims_seen++;
3564       }
3565       // Forward dynamic dimensions from indices.
3566       dim_dynamic =
3567           expanded_start_indices_shape_dynamic_dimensions[gather_dims_seen];
3568 
3569       current_bound = expanded_start_indices_shape[gather_dims_seen++];
3570     }
3571     output_dim_is_dynamic.push_back(dim_dynamic);
3572     output_dim_bounds.push_back(current_bound);
3573   }
3574 
3575   return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds,
3576                               output_dim_is_dynamic);
3577 }
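// Illustrative sketch (hypothetical shapes and dimension numbers): a
// "gather rows" configuration on a 2-D input,
//
//   input:         f32[5,10]
//   start_indices: s32[3,1]
//   dim_numbers:   offset_dims={1}, collapsed_slice_dims={0},
//                  start_index_map={0}, index_vector_dim=1
//   slice_sizes:   {1, 10}
//
// infers the output shape f32[3,10]: one batch dimension from the indices and
// one offset dimension from the uncollapsed slice.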
3578 
3579 namespace {
3580 
3581 Status ValidateScatterDimensionNumbers(
3582     const Shape& operand_shape, absl::Span<const int64> scatter_indices_shape,
3583     const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
3584   // Validate update_window_dims in ScatterDimensionNumbers.
3585   if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {
3586     return InvalidArgument(
3587         "update_window_dims in scatter op must be sorted; got: %s.",
3588         StrJoin(dim_numbers.update_window_dims(), ", "));
3589   }
3590   if (absl::c_adjacent_find(dim_numbers.update_window_dims()) !=
3591       dim_numbers.update_window_dims().end()) {
3592     return InvalidArgument(
3593         "update_window_dims in scatter op must not repeat; got: %s.",
3594         StrJoin(dim_numbers.update_window_dims(), ", "));
3595   }
3596   const int64_t updates_rank = updates_shape.rank();
3597   for (int64_t window_dim : dim_numbers.update_window_dims()) {
3598     if (window_dim < 0 || window_dim >= updates_rank) {
3599       return InvalidArgument(
3600           "Invalid update_window_dims set in scatter op; valid range is [0, "
3601           "%d). got: %d.",
3602           updates_rank, window_dim);
3603     }
3604   }

  // Validate inserted_window_dims in ScatterDimensionNumbers.
  if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) {
    return InvalidArgument(
        "inserted_window_dims in scatter op must be sorted; got: %s.",
        StrJoin(dim_numbers.inserted_window_dims(), ", "));
  }
  if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) !=
      dim_numbers.inserted_window_dims().end()) {
    return InvalidArgument(
        "inserted_window_dims in scatter op must not repeat; got: %s.",
        StrJoin(dim_numbers.inserted_window_dims(), ", "));
  }
  for (int64_t inserted_dim : dim_numbers.inserted_window_dims()) {
    if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) {
      return InvalidArgument(
          "Invalid inserted_window_dims set in scatter op; valid range is [0, "
          "%d), got: %d.",
          operand_shape.dimensions_size(), inserted_dim);
    }
  }

  // Validate window size.
  auto window_size = dim_numbers.update_window_dims_size() +
                     dim_numbers.inserted_window_dims_size();
  if (window_size != operand_shape.rank()) {
    return InvalidArgument(
        "Scatter op has window of size %d; doesn't match operand of rank %d.",
        window_size, operand_shape.rank());
  }
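  // For example (illustrative): an operand of rank 4 with
  // update_window_dims={0, 1} requires exactly two inserted_window_dims,
  // e.g. {2, 3}, so that the window covers every operand dimension.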

  // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers.
  if (dim_numbers.scatter_dims_to_operand_dims_size() !=
      scatter_indices_shape[dim_numbers.index_vector_dim()]) {
    return InvalidArgument(
        "Scatter op has %d elements in scatter_dims_to_operand_dims and the "
        "bound of dimension index_vector_dim=%d of scatter_indices is %d. "
        "These two numbers must be equal.",
        dim_numbers.scatter_dims_to_operand_dims_size(),
        dim_numbers.index_vector_dim(),
        scatter_indices_shape[dim_numbers.index_vector_dim()]);
  }
  for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) {
    int64_t scatter_dim_to_operand_dim =
        dim_numbers.scatter_dims_to_operand_dims(i);
    if (scatter_dim_to_operand_dim < 0 ||
        scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) {
      return InvalidArgument(
          "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), "
          "got: %d->%d.",
          operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim);
    }
  }
  std::vector<int64> sorted_scatter_dims_to_operand_dims(
      dim_numbers.scatter_dims_to_operand_dims().begin(),
      dim_numbers.scatter_dims_to_operand_dims().end());
  absl::c_sort(sorted_scatter_dims_to_operand_dims);
  if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
      sorted_scatter_dims_to_operand_dims.end()) {
    return InvalidArgument(
        "Repeated dimensions not allowed in scatter_dims_to_operand_dims; "
        "got: %s.",
        StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", "));
  }
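  // For example (illustrative): with an operand of rank 3 and an index vector
  // of size 2, scatter_dims_to_operand_dims={0, 2} passes the checks above,
  // while {0, 0} (repeated) or {0, 3} (out of range) would be rejected.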

  return Status::OK();
}

}  // namespace

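// Infers the result shape of a scatter op. After validating the operands and
// dimension numbers, the result shape is simply the operand shape.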
/*static*/ StatusOr<Shape> ShapeInference::InferScatterShape(
    const Shape& operand_shape, const Shape& scatter_indices_shape,
    const Shape& updates_shape, const ProgramShape& to_apply_shape,
    const ScatterDimensionNumbers& scatter_dim_numbers) {
  TF_RETURN_IF_ERROR(
      ExpectArray(operand_shape, "operand tensor of scatter op"));
  TF_RETURN_IF_ERROR(
      ExpectArray(scatter_indices_shape, "scatter indices of scatter op"));
  TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op"));

  if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) {
    return InvalidArgument(
        "Scatter indices parameter must be an integral tensor; got %s.",
        ShapeUtil::HumanString(scatter_indices_shape));
  }

  if (scatter_indices_shape.dimensions_size() <
          scatter_dim_numbers.index_vector_dim() ||
      scatter_dim_numbers.index_vector_dim() < 0) {
    return InvalidArgument(
        "Scatter index leaf dimension must be within [0, rank(scatter_indices)"
        " + 1). rank(scatter_indices) is %d and scatter index leaf dimension "
        "is %d.",
        scatter_indices_shape.dimensions_size(),
        scatter_dim_numbers.index_vector_dim());
  }
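  // For example (illustrative): for scatter_indices of shape [10, 2],
  // index_vector_dim may be 0, 1, or 2.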

  // Check that the update computation has a proper shape for a reduction.
  const Shape init_value_shape =
      ShapeUtil::MakeShape(operand_shape.element_type(), {});
  TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
                                        {updates_shape.element_type()},
                                        /*inputs=*/1));
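  // For example (illustrative): with an F32 operand and F32 updates, to_apply
  // is typically a computation of shape (f32[], f32[]) -> f32[], such as a
  // scalar addition.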
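  // If index_vector_dim is equal to rank(scatter_indices), the indices are
  // treated as if they had a trailing degenerate dimension of size 1.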
  std::vector<int64> expanded_scatter_indices_shape =
      SpanToVector(scatter_indices_shape.dimensions());
  if (expanded_scatter_indices_shape.size() ==
      scatter_dim_numbers.index_vector_dim()) {
    expanded_scatter_indices_shape.push_back(1);
  }

  int64_t expected_updates_rank = expanded_scatter_indices_shape.size() - 1 +
                                  scatter_dim_numbers.update_window_dims_size();
  if (updates_shape.rank() != expected_updates_rank) {
    return InvalidArgument("Updates tensor must be of rank %d; got %d.",
                           expected_updates_rank, updates_shape.rank());
  }
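  // For example (illustrative): scatter_indices of shape [10, 2] with
  // index_vector_dim=1 and update_window_dims={1} implies updates of rank 2.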

  TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers(
      operand_shape, expanded_scatter_indices_shape, updates_shape,
      scatter_dim_numbers));

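  // Check that each update window dimension is no larger than the
  // corresponding (non-inserted) operand dimension.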
  int64_t inserted_dims_seen = 0;
  std::vector<int64> max_update_slice_sizes;
  for (int i = 0; i < operand_shape.dimensions_size(); ++i) {
    if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() &&
        scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) {
      ++inserted_dims_seen;
    } else {
      max_update_slice_sizes.push_back(operand_shape.dimensions(i));
    }
  }
  for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) {
    auto update_window_dim = scatter_dim_numbers.update_window_dims(i);
    if (updates_shape.dimensions(update_window_dim) >
        max_update_slice_sizes[i]) {
      return InvalidArgument(
          "Bounds of the window dimensions of updates must not exceed the "
          "bounds of the corresponding dimensions of operand. For dimension "
          "%d, updates bound is %d, operand bound is %d.",
          update_window_dim, updates_shape.dimensions(update_window_dim),
          max_update_slice_sizes[i]);
    }
  }

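  // Check that the scatter (non-window) dimensions of updates match the
  // corresponding dimensions of scatter_indices.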
  int64_t scatter_dims_seen = 0;
  for (int64_t i = 0; i < updates_shape.rank(); ++i) {
    bool is_update_window_dim =
        absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i);
    if (is_update_window_dim) {
      continue;
    }
    if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) {
      ++scatter_dims_seen;
    }
    if (updates_shape.dimensions(i) !=
        expanded_scatter_indices_shape[scatter_dims_seen]) {
      return InvalidArgument(
          "Bounds of the scatter dimensions of updates must be the same as "
          "the bounds of the corresponding dimensions of scatter indices. For "
          "scatter dimension %d, updates bound is %d, scatter_indices "
          "bound is %d.",
          i, updates_shape.dimensions(i),
          expanded_scatter_indices_shape[scatter_dims_seen]);
    }
    ++scatter_dims_seen;
  }

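  // The result of a scatter op has the same shape as its operand.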
  return operand_shape;
}

}  // namespace xla