• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17 
18 #include <stddef.h>
19 #include <algorithm>
20 #include <numeric>
21 #include <set>
22 #include <string>
23 
24 #include "tensorflow/compiler/xla/shape_util.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/compiler/xla/types.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/compiler/xla/window_util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/stringpiece.h"
32 #include "tensorflow/core/lib/gtl/flatset.h"
33 #include "tensorflow/core/lib/math/math_util.h"
34 #include "tensorflow/core/lib/strings/str_util.h"
35 #include "tensorflow/core/lib/strings/strcat.h"
36 #include "tensorflow/core/lib/strings/stringprintf.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/protobuf.h"
39 
40 using tensorflow::str_util::Join;
41 using tensorflow::strings::Printf;
42 
43 namespace xla {
44 
45 namespace {
46 
47 // Return the UnaryOperation proto enum value associated with the given HLO
48 // opcode.
OpcodeToUnaryOperation(HloOpcode opcode)49 UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) {
50   switch (opcode) {
51     case HloOpcode::kAbs:
52       return UNOP_ABS;
53     case HloOpcode::kCeil:
54       return UNOP_CEIL;
55     case HloOpcode::kCos:
56       return UNOP_COS;
57     case HloOpcode::kExp:
58       return UNOP_EXP;
59     case HloOpcode::kFloor:
60       return UNOP_FLOOR;
61     case HloOpcode::kImag:
62       return UNOP_IMAG;
63     case HloOpcode::kIsFinite:
64       return UNOP_IS_FINITE;
65     case HloOpcode::kLog:
66       return UNOP_LOG;
67     case HloOpcode::kNot:
68       return UNOP_NOT;
69     case HloOpcode::kNegate:
70       return UNOP_NEGATE;
71     case HloOpcode::kReal:
72       return UNOP_REAL;
73     case HloOpcode::kRoundNearestAfz:
74       return UNOP_ROUND_NEAREST_AFZ;
75     case HloOpcode::kSign:
76       return UNOP_SIGN;
77     case HloOpcode::kSin:
78       return UNOP_SIN;
79     case HloOpcode::kSort:
80       return UNOP_SORT;
81     case HloOpcode::kTanh:
82       return UNOP_TANH;
83     default:
84       LOG(FATAL) << "Unhandled opcode for conversion to unary operation: "
85                  << opcode;
86   }
87 }
88 
89 // Return the BinaryOperation proto enum value associated with the given HLO
90 // opcode.
OpcodeToBinaryOperation(HloOpcode opcode)91 BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) {
92   switch (opcode) {
93     case HloOpcode::kAtan2:
94       return BINOP_ATAN2;
95     case HloOpcode::kComplex:
96       return BINOP_COMPLEX;
97     case HloOpcode::kMultiply:
98       return BINOP_MUL;
99     case HloOpcode::kAdd:
100       return BINOP_ADD;
101     case HloOpcode::kSubtract:
102       return BINOP_SUB;
103     case HloOpcode::kDivide:
104       return BINOP_DIV;
105     case HloOpcode::kEq:
106       return BINOP_EQ;
107     case HloOpcode::kGe:
108       return BINOP_GE;
109     case HloOpcode::kGt:
110       return BINOP_GT;
111     case HloOpcode::kLe:
112       return BINOP_LE;
113     case HloOpcode::kLt:
114       return BINOP_LT;
115     case HloOpcode::kNe:
116       return BINOP_NE;
117     case HloOpcode::kMaximum:
118       return BINOP_MAX;
119     case HloOpcode::kMinimum:
120       return BINOP_MIN;
121     case HloOpcode::kPower:
122       return BINOP_POW;
123     case HloOpcode::kRemainder:
124       return BINOP_REM;
125     case HloOpcode::kOr:
126       return BINOP_OR;
127     case HloOpcode::kAnd:
128       return BINOP_AND;
129     case HloOpcode::kShiftLeft:
130       return BINOP_SHIFT_LEFT;
131     case HloOpcode::kShiftRightArithmetic:
132       return BINOP_SHIFT_RIGHT_ARITHMETIC;
133     case HloOpcode::kShiftRightLogical:
134       return BINOP_SHIFT_RIGHT_LOGICAL;
135     default:
136       LOG(FATAL) << "unhandled opcode " << opcode;
137   }
138 }
139 
140 // Return the TernaryOperation proto enum value associated with the given HLO
141 // opcode.
OpcodeToTernaryOperation(HloOpcode opcode)142 TernaryOperation OpcodeToTernaryOperation(HloOpcode opcode) {
143   switch (opcode) {
144     case HloOpcode::kClamp:
145       return TRIOP_CLAMP;
146     case HloOpcode::kSelect:
147       return TRIOP_SELECT;
148     default:
149       LOG(FATAL) << "unhandled opcode " << opcode;
150   }
151 }
152 
153 // Return the VariadicOperation proto enum value associated with the given HLO
154 // opcode.
OpcodeToVariadicOperation(HloOpcode opcode)155 VariadicOperation OpcodeToVariadicOperation(HloOpcode opcode) {
156   switch (opcode) {
157     case HloOpcode::kTuple:
158       return VAROP_TUPLE;
159     default:
160       LOG(FATAL) << "unhandled opcode " << opcode;
161   }
162 }
163 
164 // Returns true if no element is present in slice more than once.
AllUnique(tensorflow::gtl::ArraySlice<int64> slice)165 bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
166   return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
167 }
168 
ExpectNotTupleOrOpaque(const Shape & shape,tensorflow::StringPiece op_type)169 tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape,
170                                           tensorflow::StringPiece op_type) {
171   if (ShapeUtil::IsTuple(shape)) {
172     return InvalidArgument("Expected non-tuple argument for %s. Got: %s",
173                            op_type.ToString().c_str(),
174                            ShapeUtil::HumanString(shape).c_str());
175   } else if (ShapeUtil::IsOpaque(shape)) {
176     return InvalidArgument("Expected non-opaque argument for %s. Got: %s",
177                            op_type.ToString().c_str(),
178                            ShapeUtil::HumanString(shape).c_str());
179   } else {
180     return tensorflow::Status::OK();
181   }
182 }
183 
VerifyReducerShape(const ProgramShape & reducer_shape,const Shape & init_value_shape,const PrimitiveType & input_element_type)184 tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape,
185                                       const Shape& init_value_shape,
186                                       const PrimitiveType& input_element_type) {
187   if (reducer_shape.parameters_size() != 2) {
188     return InvalidArgument(
189         "Reduction function must take 2 parameters, but "
190         "takes %d parameter(s).",
191         reducer_shape.parameters_size());
192   }
193 
194   const Shape& accumulator_shape = reducer_shape.result();
195   if (ShapeUtil::Rank(accumulator_shape) != 0) {
196     return Unimplemented(
197         "Reduction function currently must have rank-0 result.");
198   }
199 
200   // Check that the accumulator can be passed in as the first argument.
201   // Note: comparing here and below with Compatible since we don't care about
202   // layout in scalars - see b/26668201 for a longer-term vision.
203   if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(0))) {
204     return InvalidArgument(
205         "Reduction function's first parameter shape differs from the "
206         "result shape: %s vs %s",
207         ShapeUtil::HumanString(reducer_shape.parameters(0)).c_str(),
208         ShapeUtil::HumanString(accumulator_shape).c_str());
209   }
210 
211   // Check that init_value's shape is suitable for reducer_shape.
212   if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
213                                                 init_value_shape)) {
214     return InvalidArgument(
215         "Reduction function's accumulator shape differs from the "
216         "init_value shape: %s vs %s",
217         ShapeUtil::HumanString(accumulator_shape).c_str(),
218         ShapeUtil::HumanString(init_value_shape).c_str());
219   }
220 
221   // Check that the inputs can be passed in as the second argument.
222   const Shape& input_element_shape =
223       ShapeUtil::MakeShape(input_element_type, {});
224   if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape,
225                                                 reducer_shape.parameters(1))) {
226     return InvalidArgument(
227         "Reduction function's second parameter shape differs from the "
228         "input type element type: %s vs %s",
229         ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(),
230         ShapeUtil::HumanString(input_element_shape).c_str());
231   }
232 
233   // Currently the accumulator and inputs must be the same type,
234   // though that restriction could be relaxed.
235   if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
236                                                 reducer_shape.parameters(1))) {
237     return InvalidArgument(
238         "Reduction function's second parameter shape currently must "
239         "match the result shape. Got %s vs %s",
240         ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(),
241         ShapeUtil::HumanString(accumulator_shape).c_str());
242   }
243 
244   return tensorflow::Status::OK();
245 }
246 
InferWindowOutputShape(const Shape & base_shape,const Window & window,PrimitiveType element_type,bool allow_negative_padding)247 StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
248                                        const Window& window,
249                                        PrimitiveType element_type,
250                                        bool allow_negative_padding) {
251   if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) {
252     return InvalidArgument(
253         "Window has dimension %d but base shape has dimension %lld.",
254         window.dimensions_size(), ShapeUtil::Rank(base_shape));
255   }
256 
257   std::vector<int64> output_dimensions(window.dimensions_size());
258   for (int64 i = 0; i < window.dimensions_size(); ++i) {
259     const auto& dim = window.dimensions(i);
260     if (dim.size() <= 0) {
261       return InvalidArgument("Window has a non-positive dimension. Window: %s",
262                              window.DebugString().c_str());
263     }
264     if (dim.stride() <= 0) {
265       return InvalidArgument("Window has a non-positive stride. Window: %s",
266                              window.DebugString().c_str());
267     }
268     if (!allow_negative_padding && dim.padding_low() < 0) {
269       return InvalidArgument("Window has a negative low padding. Window: %s",
270                              window.DebugString().c_str());
271     }
272     if (!allow_negative_padding && dim.padding_high() < 0) {
273       return InvalidArgument("Window has a negative high padding. Window: %s",
274                              window.DebugString().c_str());
275     }
276     if (dim.base_dilation() < 1) {
277       return InvalidArgument(
278           "Window has a non-positive base area dilation factor. Window: %s",
279           window.DebugString().c_str());
280     }
281     if (dim.window_dilation() < 1) {
282       return InvalidArgument(
283           "Window has a non-positive window dilation factor. Window: %s",
284           window.DebugString().c_str());
285     }
286 
287     const int64 dilated_base = window_util::DilatedBound(
288         ShapeUtil::GetDimension(base_shape, i), dim.base_dilation());
289     const int64 padded_dilated_base =
290         dim.padding_low() + dilated_base + dim.padding_high();
291     const int64 dilated_window =
292         window_util::DilatedBound(dim.size(), dim.window_dilation());
293 
294     output_dimensions[i] = window_util::StridedBound(
295         padded_dilated_base, dilated_window, dim.stride());
296   }
297 
298   return ShapeUtil::MakeShape(element_type, output_dimensions);
299 }
300 
301 }  // namespace
302 
InferUnaryOpShape(HloOpcode opcode,const HloInstruction * operand)303 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
304     HloOpcode opcode, const HloInstruction* operand) {
305   // There is no copy operation at the proto level, so handle copy explicitly.
306   if (opcode == HloOpcode::kCopy) {
307     return operand->shape();
308   }
309 
310   return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), operand->shape());
311 }
312 
InferUnaryOpShape(UnaryOperation operation,const Shape & arg)313 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
314     UnaryOperation operation, const Shape& arg) {
315   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of unary operation"));
316 
317   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(arg));
318   switch (operation) {
319     case UNOP_FLOOR:
320     case UNOP_CEIL:
321       if (!ShapeUtil::ElementIsFloating(arg)) {
322         return InvalidArgument(
323             "expected element type in shape to be floating for floor/ceil "
324             "operation; got %s",
325             PrimitiveType_Name(arg.element_type()).c_str());
326       }
327       return arg;
328     case UNOP_COS:
329     case UNOP_SIN:
330     case UNOP_EXP:
331     case UNOP_LOG:
332     case UNOP_TANH:
333       if (!ShapeUtil::ElementIsFloating(arg) &&
334           !ShapeUtil::ElementIsComplex(arg)) {
335         return InvalidArgument(
336             "expected element type in shape to be floating or complex for "
337             "sin/cos/exp/log/tanh operation; got %s",
338             PrimitiveType_Name(arg.element_type()).c_str());
339       }
340       return arg;
341     case UNOP_REAL:
342     case UNOP_IMAG:
343       if (!ShapeUtil::ElementIsComplex(arg)) {
344         return InvalidArgument(
345             "expected element type in shape to be complex for real/imag "
346             "operation; got %s",
347             PrimitiveType_Name(arg.element_type()).c_str());
348       }
349       return ShapeUtil::ChangeElementType(arg, F32);
350     case UNOP_ABS:
351       if (ShapeUtil::ElementIsComplex(arg)) {
352         return ShapeUtil::ChangeElementType(
353             arg, primitive_util::ComplexComponentType(arg.element_type()));
354       }
355       return arg;
356     case UNOP_NEGATE:
357     case UNOP_ROUND_NEAREST_AFZ:
358     case UNOP_SIGN:
359     case UNOP_SORT:
360       return arg;
361 
362     case UNOP_NOT:
363       if (arg.element_type() != PRED &&
364           !primitive_util::IsIntegralType(arg.element_type())) {
365         return InvalidArgument(
366             "expected pred or an integral element type in argument to not "
367             "operation; got %s",
368             PrimitiveType_Name(arg.element_type()).c_str());
369       }
370       return arg;
371 
372     case UNOP_IS_FINITE:
373       if (!ShapeUtil::ElementIsFloating(arg)) {
374         return InvalidArgument(
375             "expected element type in shape to be floating point for IsFinite "
376             "operation; got %s",
377             PrimitiveType_Name(arg.element_type()).c_str());
378       }
379       return ShapeUtil::ChangeElementType(arg, PRED);
380 
381     default:
382       return InvalidArgument(
383           "Unknown operation for unary shape inference: \"%s\".",
384           UnaryOperation_Name(operation).c_str());
385   }
386 }
387 
InferConcatOpShape(tensorflow::gtl::ArraySlice<const Shape * > arg_shapes,const int64 dimension)388 /* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
389     tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
390     const int64 dimension) {
391   if (arg_shapes.empty()) {
392     return InvalidArgument("Concatenate expects at least one argument");
393   }
394   if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) {
395     return InvalidArgument("dimension to concatenate along out of bounds: %lld",
396                            dimension);
397   }
398   const Shape* arg_shape = nullptr;
399   PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
400   for (const Shape* shape : arg_shapes) {
401     TF_RETURN_IF_ERROR(
402         ExpectNotTupleOrOpaque(*shape, "operand of concatenation"));
403     if (!arg_shape) {
404       arg_shape = shape;
405       element_type = arg_shape->element_type();
406       continue;
407     }
408     if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
409       return InvalidArgument(
410           "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld "
411           "(%s)",
412           ShapeUtil::Rank(*arg_shape),
413           ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape),
414           ShapeUtil::HumanString(*shape).c_str());
415     }
416     if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
417       return InvalidArgument(
418           "cannot concatenate arrays with different element types: %s vs %s",
419           PrimitiveType_Name(arg_shape->element_type()).c_str(),
420           PrimitiveType_Name(shape->element_type()).c_str());
421     }
422     for (int64 dimension_number = 0;
423          dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) {
424       if (arg_shape->dimensions(dimension_number) !=
425           shape->dimensions(dimension_number)) {
426         if (dimension_number == dimension) {
427           continue;  // It's okay to differ in the dimension we're
428                      // concatenating.
429         }
430         return InvalidArgument(
431             "cannot concatenate arrays that differ in dimensions other than "
432             "the one being concatenated (the other array dimensions must be "
433             "the same): %s vs %s in dimension %lld",
434             ShapeUtil::HumanString(*arg_shape).c_str(),
435             ShapeUtil::HumanString(*shape).c_str(), dimension);
436       }
437     }
438     element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
439   }
440 
441   std::vector<int64> new_dimensions(arg_shape->dimensions().begin(),
442                                     arg_shape->dimensions().end());
443   for (size_t i = 1; i < arg_shapes.size(); ++i) {
444     new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
445   }
446   return ShapeUtil::MakeShape(element_type, new_dimensions);
447 }
448 
InferConvertShape(const Shape & operand_shape,PrimitiveType new_element_type)449 /* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
450     const Shape& operand_shape, PrimitiveType new_element_type) {
451   auto old_element_type = operand_shape.element_type();
452   if (primitive_util::IsComplexType(old_element_type) &&
453       !primitive_util::IsComplexType(new_element_type)) {
454     return Unimplemented(
455         "Unsupported conversion from complex to real type: %s => %s",
456         ShapeUtil::HumanString(operand_shape).c_str(),
457         PrimitiveType_Name(new_element_type).c_str());
458   }
459   if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
460     // Note: we may want to support tuple conversions via this operation in the
461     // future, by recursing into the tuple elements to check all sub-conversions
462     // are valid. For now we just reject them, though.
463     return InvalidArgument(
464         "cannot convert from or to tuple type; requested conversion: %s => %s",
465         ShapeUtil::HumanString(operand_shape).c_str(),
466         PrimitiveType_Name(new_element_type).c_str());
467   }
468 
469   return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
470 }
471 
InferBitcastConvertShape(const Shape & operand_shape,PrimitiveType new_element_type)472 /* static */ StatusOr<Shape> ShapeInference::InferBitcastConvertShape(
473     const Shape& operand_shape, PrimitiveType new_element_type) {
474   auto old_element_type = operand_shape.element_type();
475   if (primitive_util::IsComplexType(old_element_type) !=
476       primitive_util::IsComplexType(new_element_type)) {
477     return Unimplemented(
478         "Unsupported conversion between real and complex types: %s => %s",
479         ShapeUtil::HumanString(operand_shape).c_str(),
480         PrimitiveType_Name(new_element_type).c_str());
481   }
482   if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
483     // Note: we may want to support tuple conversions via this operation in the
484     // future, by recursing into the tuple elements to check all sub-conversions
485     // are valid. For now we just reject them, though.
486     return InvalidArgument(
487         "cannot convert from or to tuple type; requested conversion: %s => %s",
488         ShapeUtil::HumanString(operand_shape).c_str(),
489         PrimitiveType_Name(new_element_type).c_str());
490   }
491   if (primitive_util::BitWidth(old_element_type) !=
492       primitive_util::BitWidth(new_element_type)) {
493     return InvalidArgument(
494         "cannot bitcast types with different bit-widths: %s => %s",
495         PrimitiveType_Name(old_element_type).c_str(),
496         PrimitiveType_Name(new_element_type).c_str());
497   }
498 
499   return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
500 }
501 
InferReducePrecisionShape(const Shape & operand_shape,const int exponent_bits,const int mantissa_bits)502 /* static */ StatusOr<Shape> ShapeInference::InferReducePrecisionShape(
503     const Shape& operand_shape, const int exponent_bits,
504     const int mantissa_bits) {
505   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
506     return InvalidArgument(
507         "expected element type in shape to be floating point for "
508         "ReducePrecision operation; got %s",
509         PrimitiveType_Name(operand_shape.element_type()).c_str());
510   }
511   if (exponent_bits < 1) {
512     // One exponent bit is necessary to distinguish 0 from infinity.  Having
513     // no exponent bits doesn't produce a sensible number, so we require at
514     // least one.
515     return InvalidArgument("expected exponent_bits >= 1; got %d",
516                            exponent_bits);
517   }
518   if (mantissa_bits < 0) {
519     // A number with no mantissa bits is still meaningful, however.
520     return InvalidArgument("expected non-negative mantissa_bits; got %d",
521                            mantissa_bits);
522   }
523   return operand_shape;
524 }
525 
InferPadShape(const Shape & operand_shape,const Shape & padding_value_shape,const PaddingConfig & padding_config)526 /* static */ StatusOr<Shape> ShapeInference::InferPadShape(
527     const Shape& operand_shape, const Shape& padding_value_shape,
528     const PaddingConfig& padding_config) {
529   if (ShapeUtil::IsTuple(operand_shape)) {
530     return InvalidArgument(
531         "pad operation does not support tuple-shape operands");
532   }
533   if (!ShapeUtil::IsScalar(padding_value_shape)) {
534     return InvalidArgument(
535         "pad operation does not support non-scalar padding values");
536   }
537   if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) {
538     return InvalidArgument(
539         "The rank of the operand and the padding configuration do not match: "
540         "%s vs %s",
541         ShapeUtil::HumanString(operand_shape).c_str(),
542         padding_config.ShortDebugString().c_str());
543   }
544   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
545                                                      padding_value_shape)) {
546     return InvalidArgument(
547         "the element types of the operands to pad do not match");
548   }
549   std::vector<int64> dimensions(ShapeUtil::Rank(operand_shape));
550   for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
551     dimensions[i] = operand_shape.dimensions(i) +
552                     padding_config.dimensions(i).edge_padding_low() +
553                     padding_config.dimensions(i).edge_padding_high() +
554                     std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
555                         padding_config.dimensions(i).interior_padding();
556   }
557   return ShapeUtil::MakeShape(
558       ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
559       dimensions);
560 }
561 
562 // Current DotDimensionNumbers Requirements:
563 //
564 // Contracting Dimensions:
565 // *) Exactly one contracting dimension on both lhs and rhs.
566 // *) Contracting dimension size must be the same on both lhs and rhs.
567 // *) Contracting dimension numbers do not need to be the same (i.e. transposes
568 //    are passed on to emitter implementations).
569 //
570 // Batch Dimensions:
571 // *) Same number of batch dimensions on both lhs and rhs.
572 // *) Same batch dimension numbers (and sizes) on both lhs and rhs.
573 // *) Batch dimension numbers must be ordered before contracting and
574 //    non-contracting/non-batch dimension numbers.
575 //
576 // Non-Contracting-Non-Batch Dimensions:
577 // *) Can be 0 (matrix-vector) or 1 (matrix-matrix).
578 //
579 
580 namespace {
581 
ValidateDotDimensionNumbers(const Shape & lhs,const Shape & rhs,const DotDimensionNumbers & dimension_numbers)582 Status ValidateDotDimensionNumbers(
583     const Shape& lhs, const Shape& rhs,
584     const DotDimensionNumbers& dimension_numbers) {
585   // Check that dimension numbers are in range.
586   auto dims_in_range =
587       [](const int64 rank, tensorflow::gtl::ArraySlice<int64> contracting_dims,
588          tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
589     auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
590     return std::all_of(contracting_dims.begin(), contracting_dims.end(),
591                        in_range) &&
592            std::all_of(batch_dims.begin(), batch_dims.end(), in_range);
593   };
594 
595   tensorflow::gtl::ArraySlice<int64> lhs_contracting_dimensions =
596       AsInt64Slice(dimension_numbers.lhs_contracting_dimensions());
597   tensorflow::gtl::ArraySlice<int64> rhs_contracting_dimensions =
598       AsInt64Slice(dimension_numbers.rhs_contracting_dimensions());
599   tensorflow::gtl::ArraySlice<int64> lhs_batch_dimensions =
600       AsInt64Slice(dimension_numbers.lhs_batch_dimensions());
601   tensorflow::gtl::ArraySlice<int64> rhs_batch_dimensions =
602       AsInt64Slice(dimension_numbers.rhs_batch_dimensions());
603 
604   if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions,
605                      lhs_batch_dimensions) ||
606       !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions,
607                      rhs_batch_dimensions)) {
608     return InvalidArgument("A dimension number is out of range in dot: %s",
609                            dimension_numbers.DebugString().c_str());
610   }
611 
612   // Check that dimension numbers are unique.
613   auto dims_unique = [](tensorflow::gtl::ArraySlice<int64> contracting_dims,
614                         tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
615     tensorflow::gtl::FlatSet<int64> dim_set;
616     auto is_unique = [&dim_set](int64 i) -> bool {
617       return dim_set.insert(i).second;
618     };
619     return std::all_of(contracting_dims.begin(), contracting_dims.end(),
620                        is_unique) &&
621            std::all_of(batch_dims.begin(), batch_dims.end(), is_unique);
622   };
623 
624   if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
625       !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) {
626     return InvalidArgument("A dimension number is not unique in dot: %s",
627                            dimension_numbers.DebugString().c_str());
628   }
629 
630   // Check that the count of non-contracting-non-batch dimensions is in {0, 1}.
631   const int64 lhs_non_contracting_non_batch_dims =
632       ShapeUtil::Rank(lhs) -
633       dimension_numbers.lhs_contracting_dimensions_size() -
634       dimension_numbers.lhs_batch_dimensions_size();
635   const int64 rhs_non_contracting_non_batch_dims =
636       ShapeUtil::Rank(rhs) -
637       dimension_numbers.rhs_contracting_dimensions_size() -
638       dimension_numbers.rhs_batch_dimensions_size();
639   if (lhs_non_contracting_non_batch_dims < 0 ||
640       lhs_non_contracting_non_batch_dims > 1 ||
641       rhs_non_contracting_non_batch_dims < 0 ||
642       rhs_non_contracting_non_batch_dims > 1) {
643     return InvalidArgument(
644         "batch and contracting dimension number mismatch "
645         "with rank ");
646   }
647 
648   // Check that batch dimension numbers are ordered before all others, and
649   // that they are monotonically increasing.
650   std::vector<int64> batch_dim_numbers(lhs_batch_dimensions.size());
651   std::iota(batch_dim_numbers.begin(), batch_dim_numbers.end(), 0);
652   if (!std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(),
653                   lhs_batch_dimensions.begin()) ||
654       !std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(),
655                   rhs_batch_dimensions.begin())) {
656     return InvalidArgument(
657         "batch dimension numbers must precede non-batch dimensions and be"
658         "monotonically increasing.");
659   }
660 
661   return Status::OK();
662 }
663 
664 }  // namespace
665 
InferDotOpShape(const Shape & lhs,const Shape & rhs,const DotDimensionNumbers & dimension_numbers)666 /* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
667     const Shape& lhs, const Shape& rhs,
668     const DotDimensionNumbers& dimension_numbers) {
669   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot"));
670   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot"));
671 
672   auto fail = [lhs, rhs](const string& addendum) -> Status {
673     string message = tensorflow::strings::Printf(
674         "cannot infer shape for dot operation: %s <dot> %s",
675         ShapeUtil::HumanString(lhs).c_str(),
676         ShapeUtil::HumanString(rhs).c_str());
677     if (!addendum.empty()) {
678       message += ": " + addendum;
679     }
680     return InvalidArgument("%s", message.c_str());
681   };
682 
683   // Check if both element types are the same.
684   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
685     return fail("element types do not match");
686   }
687 
688   if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) {
689     return fail("dot only supports rank 1 or above.");
690   }
691 
692   // Validate basic properties of dot dimension numbers.
693   TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers));
694 
695   // Check that there is only one contracting dimension for both lhs and rhs.
696   if (dimension_numbers.lhs_contracting_dimensions_size() !=
697           dimension_numbers.rhs_contracting_dimensions_size() ||
698       dimension_numbers.lhs_contracting_dimensions_size() != 1) {
699     return fail("must specify one contracting dimension for both lhs and rhs.");
700   }
701 
702   // Check that contracting dimension sizes match.
703   const int64 lhs_contracting_dimension =
704       dimension_numbers.lhs_contracting_dimensions(0);
705   const int64 rhs_contracting_dimension =
706       dimension_numbers.rhs_contracting_dimensions(0);
707   if (lhs.dimensions(lhs_contracting_dimension) !=
708       rhs.dimensions(rhs_contracting_dimension)) {
709     return fail("contracting dimension sizes do not match.");
710   }
711 
712   // Check that number of batch dimensions match.
713   if (dimension_numbers.lhs_batch_dimensions_size() !=
714       dimension_numbers.rhs_batch_dimensions_size()) {
715     return fail("must the same number of batch dimensions for lhs and rhs.");
716   }
717 
718   // Check that batch dimension numbers and sizes match.
719   for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
720     if (dimension_numbers.lhs_batch_dimensions(i) !=
721             dimension_numbers.rhs_batch_dimensions(i) ||
722         lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
723             rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
724       return fail("batch dimension numbers and sizes must match for lhs/rhs.");
725     }
726   }
727 
728   // The ranks of lhs and rhs are decremented by 1 respectively due to the
729   // contraction, and added for the rank of the result. When an input tensor is
730   // a scalar, its contribution to the rank of the result is 0.
731   // Generate the result dimensions in order, rhs dimensions followed by lhs
732   // dimensions except the contracted and batch dimensions.
733   std::vector<int64> dimensions;
734   std::unordered_set<int64> rhs_batch_dims(
735       dimension_numbers.rhs_batch_dimensions().begin(),
736       dimension_numbers.rhs_batch_dimensions().end());
737   for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) {
738     if (i != lhs_contracting_dimension) {
739       dimensions.push_back(lhs.dimensions(i));
740     }
741   }
742   for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) {
743     if (i != rhs_contracting_dimension && rhs_batch_dims.count(i) == 0) {
744       dimensions.push_back(rhs.dimensions(i));
745     }
746   }
747   Shape result = ShapeUtil::MakeShape(
748       ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions);
749 
750   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
751   VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
752   return result;
753 }
754 
755 /* static */ StatusOr<Shape>
InferDegenerateDimensionBroadcastShape(BinaryOperation operation,const Shape & lhs,const Shape & rhs)756 ShapeInference::InferDegenerateDimensionBroadcastShape(
757     BinaryOperation operation, const Shape& lhs, const Shape& rhs) {
758   TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs));
759 
760   // The shapes have to be compatible. That is, if some dimension d has a
761   // different size in the two shapes, one of them has to be 1 (a "degenerate"
762   // dimension). In that case, the output shape has the non-1 dimension size
763   // from the lhs/rhs pair in every index.
764   std::vector<int64> output_dimensions(ShapeUtil::Rank(lhs));
765   for (int64 i = 0; i < ShapeUtil::Rank(lhs); ++i) {
766     if (lhs.dimensions(i) == rhs.dimensions(i)) {
767       output_dimensions[i] = lhs.dimensions(i);
768     } else if (lhs.dimensions(i) == 1) {
769       output_dimensions[i] = rhs.dimensions(i);
770     } else if (rhs.dimensions(i) == 1) {
771       output_dimensions[i] = lhs.dimensions(i);
772     } else {
773       return InvalidArgument("binary op %s with incompatible shapes: %s and %s",
774                              BinaryOperation_Name(operation).c_str(),
775                              ShapeUtil::HumanString(lhs).c_str(),
776                              ShapeUtil::HumanString(rhs).c_str());
777     }
778   }
779   return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
780                               output_dimensions);
781 }
782 
InferInDimBroadcastShape(BinaryOperation operation,const Shape & smaller_shape,const Shape & larger_shape,tensorflow::gtl::ArraySlice<int64> broadcast_dimensions)783 /* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
784     BinaryOperation operation, const Shape& smaller_shape,
785     const Shape& larger_shape,
786     tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
787   if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
788     // Reject "magic" inference for binops on different shapes, requiring
789     // the user to provide an explicit broadcast dimension in this case.
790     // See b/25177275 for more details.
791     return InvalidArgument("automatic shape inference not supported: %s and %s",
792                            ShapeUtil::HumanString(smaller_shape).c_str(),
793                            ShapeUtil::HumanString(larger_shape).c_str());
794   } else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) {
795     return InvalidArgument(
796         "size of broadcast_dimensions has to match lower-rank operand's "
797         "rank; "
798         " lower-rank operand's rank is %lld, size of broadcast_dimensions is "
799         "%zu",
800         ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size());
801   }
802 
803   // broadcast_dimensions is a sequence of dimensions; its length is equal to
804   // the rank of the lower-rank operand. The lower-rank operand's dimensions
805   // have to be compatible with the higher-rank operand's dimensions at indices
806   // specified by broadcast_dimensions. Here compatible means the dimension
807   // sizes are equal or in one of the shapes the dimension size is
808   // one. Examples:
809   //
810   // smaller_shape   larger_shape   broadcast_dimensions   output_shape
811   //   []              [2, 3]          {}                    [2, 3]
812   //   [3]             [4, 3]          {1}                   [4, 3]
813   //   [2, 3]          [2, 3, 4]       {0, 1}                [2, 3, 4]
814   //   [2, 1]          [2, 3, 4]       {0, 2}                [2, 3, 1]
815   //   [2, 3]          [2, 1, 4]       {0, 1}                [2, 3, 4]
816   //
817   // The column output_shape may not be the final shape of the XLA
818   // operation. After the "InDim" broadcasting implemented in this function
819   // expands the rank, degenerate-dimension broadcasting (implemented in
820   // InferDegenerateDimensionBroadcastShape) broadcasts dimensions of size one
821   // up to match the dimension size of the other operand. For example, consider
822   // the row in the table above with a smaller_shape of [2, 1]. The shape
823   // returned by this function is [2, 3, 1] (output_shape) however, the result
824   // shape of the XLA operation is [2, 3, 4] after degenerate-dimension
825   // broadcasting.
826   //
827   // Invalid broadcasts:
828   //
829   // smaller_shape=[3], larger_shape=[4, 3], broadcast_dimensions={0}
830   // Reason: Dimension zero** of larger_shape (size 4) is not compatible with
831   //   dimension zero of smaller_shape(size 3). **Zero here comes from the value
832   //   in broadcast_dimensions.
833   //
834   // smaller_shape=[2, 1], larger_shape=[2, 3, 4], broadcast_dimensions={1, 2}
835   // Reason: Dimension one of larger_shape (size 3) is not compatible with
836   //   dimension zero of smaller_shape(size 2)
837 
838   // The output shape is initially the larger_shape. Sizes of dimensions
839   // specified in broadcast_dimensions are then changed to match the
840   // corresponding dimension size in smaller_shape.
841   Shape output_shape(larger_shape);
842   output_shape.set_element_type(
843       ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape));
844 
845   for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
846     int64 dimension_to_match = broadcast_dimensions.at(i);
847     if (dimension_to_match < 0) {
848       return InvalidArgument(
849           "broadcast dimension number (%lld) cannot be negative",
850           dimension_to_match);
851     }
852     if (dimension_to_match >= larger_shape.dimensions_size()) {
853       return InvalidArgument(
854           "broadcast dimension number (%lld) too large; higher-rank "
855           "operand has rank %d",
856           dimension_to_match, larger_shape.dimensions_size());
857     }
858     int64 small_dimension_size = smaller_shape.dimensions(i);
859     int64 large_dimension_size = larger_shape.dimensions(dimension_to_match);
860     // Dimension sizes must be compatible: match or be degenerate (degenerate
861     // case is handled by degenerate dimension broadcasting which occurs after
862     // InDim broadcasting).
863     if (small_dimension_size != large_dimension_size &&
864         small_dimension_size != 1 && large_dimension_size != 1) {
865       return InvalidArgument(
866           "broadcast dimension %d mismatch: %lld != %lld; %s and %s", i,
867           small_dimension_size, large_dimension_size,
868           ShapeUtil::HumanString(smaller_shape).c_str(),
869           ShapeUtil::HumanString(larger_shape).c_str());
870     }
871     // Make sure the broadcast dimensions are listed in a strictly increasing
872     // order.
873     if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) {
874       return InvalidArgument(
875           "broadcast dimensions order is wrong: %lld comes after %lld",
876           dimension_to_match, broadcast_dimensions.at(i - 1));
877     }
878 
879     output_shape.set_dimensions(dimension_to_match, small_dimension_size);
880   }
881 
882   return output_shape;
883 }
884 
InferElementwiseBinaryOpShape(BinaryOperation operation,const Shape & lhs,const Shape & rhs,tensorflow::gtl::ArraySlice<int64> broadcast_dimensions)885 /* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
886     BinaryOperation operation, const Shape& lhs, const Shape& rhs,
887     tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
888   TF_RETURN_IF_ERROR(
889       ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation"));
890   TF_RETURN_IF_ERROR(
891       ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation"));
892 
893   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
894     return InvalidArgument(
895         "binary op %s with different element types: %s and %s",
896         BinaryOperation_Name(operation).c_str(),
897         ShapeUtil::HumanString(lhs).c_str(),
898         ShapeUtil::HumanString(rhs).c_str());
899   }
900 
901   if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) {
902     std::vector<int64> identity_dims(ShapeUtil::Rank(lhs));
903     std::iota(identity_dims.begin(), identity_dims.end(), 0);
904     if (!broadcast_dimensions.empty() &&
905         broadcast_dimensions != identity_dims) {
906       return InvalidArgument(
907           "broadcast dimensions field must either be not set or be the "
908           "identity on binary operations with operands of the same rank");
909     }
910   }
911 
912   if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
913     // If the shapes are the same other than layout, the output shape is the
914     // same (elementwise op).
915     return ShapeUtil::ChangeElementType(
916         lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
917   }
918 
919   if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) {
920     return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
921   } else {
922     // Ranks do not match, so perform InDim broadcasting using
923     // broadcast_dimensions. Scalar broadcasting is a special case of this.
924     const Shape& larger_shape =
925         ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? lhs : rhs;
926     const Shape& smaller_shape =
927         ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs;
928 
929     // After InDim broadcasting, perform degenerate dimensions broadcasting.
930     TF_ASSIGN_OR_RETURN(
931         Shape indim_broadcast_shape,
932         InferInDimBroadcastShape(operation, smaller_shape, larger_shape,
933                                  broadcast_dimensions));
934 
935     return InferDegenerateDimensionBroadcastShape(
936         operation, indim_broadcast_shape, larger_shape);
937   }
938 }
939 
InferBinaryOpShape(HloOpcode opcode,const HloInstruction * lhs,const HloInstruction * rhs)940 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
941     HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) {
942   return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs->shape(),
943                             rhs->shape(), /*broadcast_dimensions=*/{});
944 }
945 
InferBinaryOpShape(BinaryOperation operation,const Shape & lhs,const Shape & rhs,tensorflow::gtl::ArraySlice<int64> broadcast_dimensions)946 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
947     BinaryOperation operation, const Shape& lhs, const Shape& rhs,
948     tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
949   VLOG(2) << tensorflow::strings::Printf(
950       "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
951       BinaryOperation_Name(operation).c_str(),
952       ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(),
953       Join(broadcast_dimensions, ", ").c_str());
954   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
955   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
956 
957   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
958       lhs, tensorflow::strings::StrCat("lhs of binary operation ",
959                                        BinaryOperation_Name(operation))));
960   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
961       rhs, tensorflow::strings::StrCat("rhs of binary operation ",
962                                        BinaryOperation_Name(operation))));
963   switch (operation) {
964     case BINOP_MAX:
965     case BINOP_MIN:
966     case BINOP_SUB:
967     case BINOP_ADD:
968     case BINOP_ATAN2:
969     case BINOP_POW:
970     case BINOP_DIV:
971     case BINOP_REM:
972     case BINOP_MUL:
973     case BINOP_SHIFT_LEFT:
974     case BINOP_SHIFT_RIGHT_ARITHMETIC:
975     case BINOP_SHIFT_RIGHT_LOGICAL:
976       return InferElementwiseBinaryOpShape(operation, lhs, rhs,
977                                            broadcast_dimensions);
978 
979     case BINOP_COMPLEX: {
980       if (!ShapeUtil::ElementIsFloating(lhs)) {
981         return InvalidArgument(
982             "expected element type in shape to be floating for complex compose "
983             "operation; got %s",
984             PrimitiveType_Name(lhs.element_type()).c_str());
985       }
986       TF_ASSIGN_OR_RETURN(const Shape& shape,
987                           InferElementwiseBinaryOpShape(operation, lhs, rhs,
988                                                         broadcast_dimensions));
989       if (lhs.element_type() == F32 && rhs.element_type() == F32) {
990         return ShapeUtil::ChangeElementType(shape, C64);
991       } else {
992         return Unimplemented("complex component type not supported");
993       }
994     }
995     case BINOP_AND:
996     case BINOP_OR:
997       if (lhs.element_type() != PRED &&
998           !primitive_util::IsIntegralType(lhs.element_type())) {
999         return InvalidArgument(
1000             "expected pred or integral type in argument to and/or operation; "
1001             "got %s",
1002             PrimitiveType_Name(lhs.element_type()).c_str());
1003       }
1004       return InferElementwiseBinaryOpShape(operation, lhs, rhs,
1005                                            broadcast_dimensions);
1006     case BINOP_EQ:
1007     case BINOP_GE:
1008     case BINOP_GT:
1009     case BINOP_LE:
1010     case BINOP_LT:
1011     case BINOP_NE: {
1012       TF_ASSIGN_OR_RETURN(const Shape& shape,
1013                           InferElementwiseBinaryOpShape(operation, lhs, rhs,
1014                                                         broadcast_dimensions));
1015       return ShapeUtil::ChangeElementType(shape, PRED);
1016     }
1017     default:
1018       return Unimplemented(
1019           "not yet implemented; infer binary op shape: %s; lhs: %s; rhs: %s",
1020           BinaryOperation_Name(operation).c_str(),
1021           lhs.ShortDebugString().c_str(), rhs.ShortDebugString().c_str());
1022   }
1023 }
1024 
InferTernaryOpShape(HloOpcode opcode,const HloInstruction * lhs,const HloInstruction * rhs,const HloInstruction * ehs)1025 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1026     HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs,
1027     const HloInstruction* ehs) {
1028   return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs->shape(),
1029                              rhs->shape(), ehs->shape());
1030 }
1031 
InferTernaryOpShape(TernaryOperation operation,const Shape & lhs,const Shape & rhs,const Shape & ehs)1032 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1033     TernaryOperation operation, const Shape& lhs, const Shape& rhs,
1034     const Shape& ehs) {
1035   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1036   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1037   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs));
1038   switch (operation) {
1039     case TRIOP_CLAMP:
1040       return InferClampShape(lhs, rhs, ehs);
1041     case TRIOP_SELECT:
1042       return InferSelectShape(lhs, rhs, ehs);
1043     default:
1044       return InvalidArgument("unknown operation %s",
1045                              TernaryOperation_Name(operation).c_str());
1046   }
1047 }
1048 
InferVariadicOpShape(HloOpcode opcode,tensorflow::gtl::ArraySlice<const HloInstruction * > operands)1049 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1050     HloOpcode opcode,
1051     tensorflow::gtl::ArraySlice<const HloInstruction*> operands) {
1052   std::vector<const Shape*> operand_shapes;
1053   for (const HloInstruction* operand : operands) {
1054     operand_shapes.push_back(&operand->shape());
1055   }
1056   return InferVariadicOpShape(OpcodeToVariadicOperation(opcode),
1057                               operand_shapes);
1058 }
1059 
InferVariadicOpShape(VariadicOperation operation,tensorflow::gtl::ArraySlice<const Shape * > operand_shapes)1060 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1061     VariadicOperation operation,
1062     tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
1063   for (const Shape* shape : operand_shapes) {
1064     TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
1065   }
1066   switch (operation) {
1067     case VAROP_TUPLE: {
1068       Shape result = ShapeUtil::MakeTupleShape({});
1069       for (const Shape* shape : operand_shapes) {
1070         ShapeUtil::AppendShapeToTuple(*shape, &result);
1071       }
1072       return result;
1073     }
1074     default:
1075       return InvalidArgument("unknown operation %s",
1076                              VariadicOperation_Name(operation).c_str());
1077   }
1078 }
1079 
InferMapShape(tensorflow::gtl::ArraySlice<const Shape * > arg_shapes,const ProgramShape & to_apply,tensorflow::gtl::ArraySlice<int64> dimensions)1080 /* static */ StatusOr<Shape> ShapeInference::InferMapShape(
1081     tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
1082     const ProgramShape& to_apply,
1083     tensorflow::gtl::ArraySlice<int64> dimensions) {
1084   if (arg_shapes.empty()) {
1085     return InvalidArgument("Map expects at least one argument");
1086   }
1087 
1088   // All arguments must have the same shape.
1089   const Shape* arg_shape = arg_shapes[0];
1090   for (size_t i = 1; i < arg_shapes.size(); ++i) {
1091     TF_RETURN_IF_ERROR(
1092         ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map"));
1093 
1094     if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
1095       continue;
1096     }
1097     if (!ShapeUtil::IsTuple(*arg_shapes[i]) &&
1098         !ShapeUtil::IsTuple(*arg_shape) &&
1099         ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
1100                                                       *arg_shape)) {
1101       if (ShapeUtil::IsScalar(*arg_shapes[i])) {
1102         continue;
1103       }
1104       if (ShapeUtil::IsScalar(*arg_shape)) {
1105         arg_shape = arg_shapes[i];
1106         continue;
1107       }
1108     }
1109 
1110     std::vector<string> pieces;
1111     for (const Shape* shape : arg_shapes) {
1112       pieces.push_back(ShapeUtil::HumanString(*shape));
1113     }
1114     return InvalidArgument(
1115         "Map operation requires all operands to have the same shape; got: "
1116         "%s",
1117         Join(pieces, ", ").c_str());
1118   }
1119 
1120   // Check that dimensions.size == arg_shape.dimensions_size() (we currently
1121   // only support mapping across all dimensions: i.e. scalar map functions).
1122   if (dimensions.size() != arg_shape->dimensions_size()) {
1123     return InvalidArgument(
1124         "Map applied to a subset of dimensions currently not supported: "
1125         "arg_dimension_size: %d, requested_map_dimensions_size: %zu",
1126         arg_shape->dimensions_size(), dimensions.size());
1127   }
1128 
1129   // Check that requested map dimensions numbers are monotonically increasing.
1130   for (int i = 0; i < dimensions.size(); ++i) {
1131     if (dimensions[i] != i) {
1132       return InvalidArgument(
1133           "Map requires monotonically increasing dimension numbers, found: %s ",
1134           Join(dimensions, ", ").c_str());
1135     }
1136   }
1137 
1138   // The applied function's arity equals the number of arguments.
1139   if (arg_shapes.size() != to_apply.parameters_size()) {
1140     return InvalidArgument(
1141         "Map applied function arity must match number of arguments; got: "
1142         "arity: %d, arguments: %zu",
1143         to_apply.parameters_size(), arg_shapes.size());
1144   }
1145 
1146   // The parameters should all be scalars, and the output too.
1147   const Shape& output_shape = to_apply.result();
1148   if (!ShapeUtil::IsScalar(output_shape)) {
1149     return InvalidArgument(
1150         "mapped computation's result has to be a scalar; "
1151         "got: %s",
1152         ShapeUtil::HumanString(output_shape).c_str());
1153   }
1154 
1155   for (int i = 0; i < to_apply.parameters_size(); ++i) {
1156     const Shape& parameter_shape = to_apply.parameters(i);
1157 
1158     if (!ShapeUtil::IsScalar(parameter_shape)) {
1159       return InvalidArgument(
1160           "mapped computation's parameter has to be a scalar; "
1161           "got parameter %d shape: %s",
1162           i, ShapeUtil::HumanString(parameter_shape).c_str());
1163     }
1164 
1165     if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
1166                                                        *arg_shape)) {
1167       return InvalidArgument(
1168           "mapped computation's parameter type has to match argument element "
1169           "type; got parameter %d shape: %s, argument shape: %s",
1170           i, ShapeUtil::HumanString(parameter_shape).c_str(),
1171           ShapeUtil::HumanString(*arg_shape).c_str());
1172     }
1173   }
1174 
1175   return ShapeUtil::MakeShape(output_shape.element_type(),
1176                               AsInt64Slice(arg_shape->dimensions()));
1177 }
1178 
InferBatchNormTrainingShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & offset_shape,int64 feature_index)1179 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormTrainingShape(
1180     const Shape& operand_shape, const Shape& scale_shape,
1181     const Shape& offset_shape, int64 feature_index) {
1182   TF_RETURN_IF_ERROR(
1183       ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm training"));
1184   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
1185       offset_shape, "offset input of batch norm training"));
1186   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
1187       scale_shape, "scale input of batch norm training"));
1188 
1189   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
1190                tensorflow::Status::OK());
1191   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) ==
1192                tensorflow::Status::OK());
1193   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
1194                tensorflow::Status::OK());
1195 
1196   if (feature_index >= ShapeUtil::Rank(operand_shape)) {
1197     return InvalidArgument(
1198         "Expected feature_index of batch-norm-training to be "
1199         "smaller than the rank of operand_shape; "
1200         "got feature_index %lld, and rank %lld",
1201         feature_index, ShapeUtil::Rank(operand_shape));
1202   }
1203 
1204   if (feature_index < 0) {
1205     return InvalidArgument(
1206         "Expected feature_index of batch-norm-training to "
1207         "be a non-negative number, got %lld",
1208         feature_index);
1209   }
1210 
1211   if (ShapeUtil::Rank(operand_shape) < 1) {
1212     return InvalidArgument(
1213         "Expected the rank of operand to "
1214         "batch-norm-training to be at least 1; got %lld",
1215         ShapeUtil::Rank(operand_shape));
1216   }
1217 
1218   if (ShapeUtil::Rank(offset_shape) != 1) {
1219     return InvalidArgument(
1220         "Offset input of batch-norm-training must have"
1221         " rank 1, but has rank %lld.",
1222         ShapeUtil::Rank(offset_shape));
1223   }
1224 
1225   if (ShapeUtil::Rank(scale_shape) != 1) {
1226     return InvalidArgument(
1227         "Scale input of batch-norm-training must have"
1228         " rank 1, but has rank %lld.",
1229         ShapeUtil::Rank(scale_shape));
1230   }
1231 
1232   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1233     return InvalidArgument(
1234         "The operand to batch-norm-training must have a floating point "
1235         "element type, but the shape is %s",
1236         PrimitiveType_Name(operand_shape.element_type()).c_str());
1237   }
1238 
1239   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1240                                                      operand_shape)) {
1241     return InvalidArgument(
1242         "The inputs should have the same element type for batch-norm-training, "
1243         "but the shape of offset factor is %s "
1244         "and the shape of operand is %s",
1245         PrimitiveType_Name(offset_shape.element_type()).c_str(),
1246         PrimitiveType_Name(operand_shape.element_type()).c_str());
1247   }
1248 
1249   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1250                                                      operand_shape)) {
1251     return InvalidArgument(
1252         "The inputs should have the same element type for batch-norm-training, "
1253         "but the shape of scale factor is %s "
1254         "and the shape of operand is %s",
1255         PrimitiveType_Name(scale_shape.element_type()).c_str(),
1256         PrimitiveType_Name(operand_shape.element_type()).c_str());
1257   }
1258 
1259   const int64 feature_count = operand_shape.dimensions(feature_index);
1260   Shape output_shape_for_mean_and_var =
1261       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1262 
1263   if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1264     return InvalidArgument(
1265         "The size of offset factor should be the same as feature count,"
1266         "but the size of offset factor is %lld "
1267         "and the feature count is %lld",
1268         ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1269   }
1270 
1271   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1272     return InvalidArgument(
1273         "The size of scale factor should be the same as feature count,"
1274         "but the size of scale factor is %lld "
1275         "and the feature count is %lld",
1276         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1277   }
1278 
1279   return ShapeUtil::MakeTupleShape({operand_shape,
1280                                     output_shape_for_mean_and_var,
1281                                     output_shape_for_mean_and_var});
1282 }
1283 
InferBatchNormInferenceShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & offset_shape,const Shape & mean_shape,const Shape & variance_shape,int64 feature_index)1284 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
1285     const Shape& operand_shape, const Shape& scale_shape,
1286     const Shape& offset_shape, const Shape& mean_shape,
1287     const Shape& variance_shape, int64 feature_index) {
1288   TF_RETURN_IF_ERROR(
1289       ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference"));
1290   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
1291       offset_shape, "offset input of batch norm inference"));
1292   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
1293       scale_shape, "scale input of batch norm inference"));
1294 
1295   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
1296                tensorflow::Status::OK());
1297   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) ==
1298                tensorflow::Status::OK());
1299   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
1300                tensorflow::Status::OK());
1301   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape) ==
1302                tensorflow::Status::OK());
1303   TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) ==
1304                tensorflow::Status::OK());
1305 
1306   if (feature_index >= ShapeUtil::Rank(operand_shape)) {
1307     return InvalidArgument(
1308         "Expected feature_index of batch-norm-inference to be "
1309         "smaller than the rank of operand_shape; "
1310         "got feature_index %lld, and rank %lld",
1311         feature_index, ShapeUtil::Rank(operand_shape));
1312   }
1313 
1314   if (feature_index < 0) {
1315     return InvalidArgument(
1316         "Expected feature_index of batch-norm-inference to "
1317         "be a non-negative number, got %lld",
1318         feature_index);
1319   }
1320 
1321   if (ShapeUtil::Rank(operand_shape) < 1) {
1322     return InvalidArgument(
1323         "Expected the rank of operand to "
1324         "batch-norm-inference to be at least 1; got %lld",
1325         ShapeUtil::Rank(operand_shape));
1326   }
1327 
1328   if (ShapeUtil::Rank(offset_shape) != 1) {
1329     return InvalidArgument(
1330         "Offset input of batch-norm-inference must have"
1331         " rank 1, but has rank %lld.",
1332         ShapeUtil::Rank(offset_shape));
1333   }
1334 
1335   if (ShapeUtil::Rank(scale_shape) != 1) {
1336     return InvalidArgument(
1337         "Scale input of batch-norm-inference must have"
1338         " rank 1, but has rank %lld.",
1339         ShapeUtil::Rank(scale_shape));
1340   }
1341 
1342   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1343     return InvalidArgument(
1344         "The operand to batch-norm-inference must have a floating point "
1345         "element type, but the shape is %s",
1346         PrimitiveType_Name(operand_shape.element_type()).c_str());
1347   }
1348 
1349   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1350                                                      operand_shape)) {
1351     return InvalidArgument(
1352         "The inputs should have the same element type for "
1353         "batch-norm-inference, "
1354         "but the shape of offset factor is %s "
1355         "and the shape of operand is %s",
1356         PrimitiveType_Name(offset_shape.element_type()).c_str(),
1357         PrimitiveType_Name(operand_shape.element_type()).c_str());
1358   }
1359 
1360   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1361                                                      operand_shape)) {
1362     return InvalidArgument(
1363         "The inputs should have the same element type for "
1364         "batch-norm-inference, "
1365         "but the shape of scale factor is %s "
1366         "and the shape of operand is %s",
1367         PrimitiveType_Name(scale_shape.element_type()).c_str(),
1368         PrimitiveType_Name(operand_shape.element_type()).c_str());
1369   }
1370 
1371   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1372                                                      operand_shape)) {
1373     return InvalidArgument(
1374         "The inputs should have the same element type for "
1375         "batch-norm-inference, "
1376         "but the shape of mean is %s "
1377         "and the shape of operand is %s",
1378         PrimitiveType_Name(mean_shape.element_type()).c_str(),
1379         PrimitiveType_Name(operand_shape.element_type()).c_str());
1380   }
1381 
1382   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
1383                                                      operand_shape)) {
1384     return InvalidArgument(
1385         "The inputs should have the same element type for "
1386         "batch-norm-inference, "
1387         "but the shape of variance is %s "
1388         "and the shape of operand is %s",
1389         PrimitiveType_Name(mean_shape.element_type()).c_str(),
1390         PrimitiveType_Name(variance_shape.element_type()).c_str());
1391   }
1392 
1393   const int64 feature_count = operand_shape.dimensions(feature_index);
1394   Shape output_shape_for_mean_and_var =
1395       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1396 
1397   if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1398     return InvalidArgument(
1399         "The size of offset factor should be the same as feature count,"
1400         "but the size of offset factor is %lld "
1401         "and the feature count is %lld",
1402         ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1403   }
1404 
1405   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1406     return InvalidArgument(
1407         "The size of scale factor should be the same as feature count,"
1408         "but the size of scale factor is %lld "
1409         "and the feature count is %lld",
1410         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1411   }
1412 
1413   if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1414     return InvalidArgument(
1415         "The size of mean should be the same as feature count,"
1416         "but the size of mean is %lld "
1417         "and the feature count is %lld",
1418         ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1419   }
1420 
1421   if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
1422     return InvalidArgument(
1423         "The size of variance should be the same as feature count,"
1424         "but the size of variance is %lld "
1425         "and the feature count is %lld",
1426         ShapeUtil::GetDimension(variance_shape, 0), feature_count);
1427   }
1428 
1429   return operand_shape;
1430 }
1431 
InferBatchNormGradShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & mean_shape,const Shape & var_shape,const Shape & output_grad_shape,int64 feature_index)1432 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
1433     const Shape& operand_shape, const Shape& scale_shape,
1434     const Shape& mean_shape, const Shape& var_shape,
1435     const Shape& output_grad_shape, int64 feature_index) {
1436   TF_RETURN_IF_ERROR(
1437       ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm grad"));
1438   TF_RETURN_IF_ERROR(
1439       ExpectNotTupleOrOpaque(scale_shape, "scale input of batch norm grad"));
1440   TF_RETURN_IF_ERROR(
1441       ExpectNotTupleOrOpaque(mean_shape, "mean input of batch norm grad"));
1442   TF_RETURN_IF_ERROR(
1443       ExpectNotTupleOrOpaque(var_shape, "var input of batch norm grad"));
1444   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
1445       output_grad_shape, "output_grad input of batch norm grad"));
1446 
1447   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1448   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1449   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1450   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(var_shape));
1451   TF_RETURN_IF_ERROR(
1452       ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape));
1453 
1454   if (feature_index >= ShapeUtil::Rank(operand_shape)) {
1455     return InvalidArgument(
1456         "Expected feature_index of batch-norm-grad to be "
1457         "smaller than the rank of operand_shape; "
1458         "got feature_index %lld, and rank %lld",
1459         feature_index, ShapeUtil::Rank(operand_shape));
1460   }
1461 
1462   if (ShapeUtil::Rank(operand_shape) != ShapeUtil::Rank(output_grad_shape)) {
1463     return InvalidArgument(
1464         "Expected operand_shape of batch-norm-grad to have the same rank as"
1465         " output_grad_shape; got rank(oprand_shape) %lld, and"
1466         " rank(output_grad_shape) %lld",
1467         ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape));
1468   }
1469 
1470   if (ShapeUtil::Rank(mean_shape) != 1) {
1471     return InvalidArgument(
1472         "Mean input of batch-norm-grad must have"
1473         " rank 1, but has rank %lld.",
1474         ShapeUtil::Rank(mean_shape));
1475   }
1476 
1477   if (ShapeUtil::Rank(scale_shape) != 1) {
1478     return InvalidArgument(
1479         "Scale input of batch-norm-grad must have"
1480         " rank 1, but has rank %lld.",
1481         ShapeUtil::Rank(scale_shape));
1482   }
1483 
1484   if (ShapeUtil::Rank(var_shape) != 1) {
1485     return InvalidArgument(
1486         "Var input of batch-norm-grad must have"
1487         " rank 1, but has rank %lld.",
1488         ShapeUtil::Rank(var_shape));
1489   }
1490 
1491   if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1492     return InvalidArgument(
1493         "The operand to batch-norm-grad must have a floating point "
1494         "element type, but the shape is %s",
1495         PrimitiveType_Name(operand_shape.element_type()).c_str());
1496   }
1497 
1498   if (!ShapeUtil::ElementIsFloating(output_grad_shape)) {
1499     return InvalidArgument(
1500         "The output_grad to batch-norm-grad must have a floating point "
1501         "element type, but the shape is %s",
1502         PrimitiveType_Name(output_grad_shape.element_type()).c_str());
1503   }
1504 
1505   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
1506                                                      operand_shape)) {
1507     return InvalidArgument(
1508         "The inputs should have the same element type for batch-norm-grad, "
1509         "but the element type of output_grad is %s "
1510         "and the element type of operand is %s",
1511         PrimitiveType_Name(output_grad_shape.element_type()).c_str(),
1512         PrimitiveType_Name(operand_shape.element_type()).c_str());
1513   }
1514 
1515   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1516                                                      operand_shape)) {
1517     return InvalidArgument(
1518         "The inputs should have the same element type for batch-norm-grad, "
1519         "but the element type of scale factor is %s "
1520         "and the element type of operand is %s",
1521         PrimitiveType_Name(scale_shape.element_type()).c_str(),
1522         PrimitiveType_Name(operand_shape.element_type()).c_str());
1523   }
1524 
1525   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1526                                                      operand_shape)) {
1527     return InvalidArgument(
1528         "The inputs should have the same element type for batch-norm-grad, "
1529         "but the element type of mean is %s "
1530         "and the element type of operand is %s",
1531         PrimitiveType_Name(mean_shape.element_type()).c_str(),
1532         PrimitiveType_Name(operand_shape.element_type()).c_str());
1533   }
1534 
1535   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
1536                                                      operand_shape)) {
1537     return InvalidArgument(
1538         "The inputs should have the same element type for batch-norm-grad, "
1539         "but the element type of mean is %s "
1540         "and the element type of operand is %s",
1541         PrimitiveType_Name(mean_shape.element_type()).c_str(),
1542         PrimitiveType_Name(operand_shape.element_type()).c_str());
1543   }
1544 
1545   const int64 feature_count = operand_shape.dimensions(feature_index);
1546 
1547   Shape feature_shape =
1548       ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1549 
1550   if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1551     return InvalidArgument(
1552         "The size of mean should be the same as feature count,"
1553         "but the size of offset factor is %lld "
1554         "and the feature count is %lld",
1555         ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1556   }
1557 
1558   if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1559     return InvalidArgument(
1560         "The size of scale factor should be the same as feature count,"
1561         "but the size of scale factor is %lld "
1562         "and the feature count is %lld",
1563         ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1564   }
1565 
1566   if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) {
1567     return InvalidArgument(
1568         "The size of variance should be the same as feature count,"
1569         "but the size of variance is %lld "
1570         "and the feature count is %lld",
1571         ShapeUtil::GetDimension(var_shape, 0), feature_count);
1572   }
1573 
1574   // Verify operand_shape and output_grad_shape have same bounds.
1575   for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
1576     if (ShapeUtil::GetDimension(operand_shape, i) !=
1577         ShapeUtil::GetDimension(output_grad_shape, i)) {
1578       return InvalidArgument(
1579           "The bounds of operand shape should be the same as output_grad's,"
1580           "but the bound of operand_shape at dimension %lld is %lld "
1581           "and the bound of output_grad_shape is %lld",
1582           i, ShapeUtil::GetDimension(operand_shape, i),
1583           ShapeUtil::GetDimension(output_grad_shape, i));
1584     }
1585   }
1586 
1587   return ShapeUtil::MakeTupleShape(
1588       {operand_shape, feature_shape, feature_shape});
1589 }
1590 
InferConvolveShape(const Shape & lhs,const Shape & rhs,const Window & window,const ConvolutionDimensionNumbers & dnums)1591 /* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
1592     const Shape& lhs, const Shape& rhs, const Window& window,
1593     const ConvolutionDimensionNumbers& dnums) {
1594   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution"));
1595   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution"));
1596 
1597   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
1598     return InvalidArgument(
1599         "Convolution with different element types: %s and %s",
1600         ShapeUtil::HumanString(lhs).c_str(),
1601         ShapeUtil::HumanString(rhs).c_str());
1602   }
1603   if (dnums.input_spatial_dimensions_size() !=
1604       dnums.kernel_spatial_dimensions_size()) {
1605     return InvalidArgument(
1606         "Both arguments to convolution must have same number of dimensions.\n"
1607         "Window: %s",
1608         window.DebugString().c_str());
1609   }
1610 
1611   const int num_spatial_dims = dnums.input_spatial_dimensions_size();
1612   if (window.dimensions_size() != num_spatial_dims) {
1613     return InvalidArgument(
1614         "Window must have same number of dimensions as dimension numbers.\n"
1615         "Window: %s\nDimension numbers: %s",
1616         window.DebugString().c_str(), dnums.DebugString().c_str());
1617   }
1618 
1619   const int num_dims = num_spatial_dims + 2;
1620   if (ShapeUtil::Rank(lhs) != num_dims) {
1621     return InvalidArgument(
1622         "The LHS argument to a convolution should have rank %d.\n"
1623         "lhs: %s",
1624         num_dims, ShapeUtil::HumanString(lhs).c_str());
1625   }
1626   if (ShapeUtil::Rank(rhs) != num_dims) {
1627     return InvalidArgument(
1628         "The RHS argument to a convolution should have rank %d.\n"
1629         "lhs: %s",
1630         num_dims, ShapeUtil::HumanString(lhs).c_str());
1631   }
1632   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1633   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1634 
1635   // Verifies that the input and window dimensions are a permutation of
1636   // the dimension numbers.
1637   std::vector<int64> input_dnums(num_dims);
1638   input_dnums[0] = dnums.input_batch_dimension();
1639   input_dnums[1] = dnums.input_feature_dimension();
1640   std::copy(dnums.input_spatial_dimensions().begin(),
1641             dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
1642   std::sort(input_dnums.begin(), input_dnums.end());
1643 
1644   std::vector<int64> window_dnums(num_dims);
1645   window_dnums[0] = dnums.kernel_input_feature_dimension();
1646   window_dnums[1] = dnums.kernel_output_feature_dimension();
1647   std::copy(dnums.kernel_spatial_dimensions().begin(),
1648             dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
1649   std::sort(window_dnums.begin(), window_dnums.end());
1650 
1651   std::vector<int64> output_dnums(num_dims);
1652   output_dnums[0] = dnums.output_batch_dimension();
1653   output_dnums[1] = dnums.output_feature_dimension();
1654   std::copy(dnums.output_spatial_dimensions().begin(),
1655             dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
1656   std::sort(output_dnums.begin(), output_dnums.end());
1657 
1658   std::vector<int64> expected_dnums(num_dims);
1659   std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
1660 
1661   const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
1662   if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) ||
1663       !std::all_of(window_dnums.begin(), window_dnums.end(), in_range) ||
1664       !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) {
1665     return InvalidArgument(
1666         "A dimension number is out of range in convolution: %s",
1667         dnums.DebugString().c_str());
1668   }
1669 
1670   if (input_dnums != expected_dnums) {
1671     return InvalidArgument(
1672         "Input dimensions of convolution must contain each dimension exactly "
1673         "once: %s",
1674         dnums.DebugString().c_str());
1675   }
1676   if (window_dnums != expected_dnums) {
1677     return InvalidArgument(
1678         "Window dimensions of convolution must contain each dimension exactly "
1679         "once: %s",
1680         dnums.DebugString().c_str());
1681   }
1682   if (output_dnums != expected_dnums) {
1683     return InvalidArgument(
1684         "Output dimensions of convolution must contain each dimension exactly "
1685         "once: %s",
1686         dnums.DebugString().c_str());
1687   }
1688 
1689   std::vector<int64> input_spatial_dims(num_spatial_dims);
1690   for (int i = 0; i < num_spatial_dims; ++i) {
1691     input_spatial_dims[i] = lhs.dimensions(dnums.input_spatial_dimensions(i));
1692   }
1693   const int64 input_features = lhs.dimensions(dnums.input_feature_dimension());
1694   const int64 input_batch = lhs.dimensions(dnums.input_batch_dimension());
1695 
1696   std::vector<int64> kernel_spatial_dims(num_spatial_dims);
1697   for (int i = 0; i < num_spatial_dims; ++i) {
1698     kernel_spatial_dims[i] = rhs.dimensions(dnums.kernel_spatial_dimensions(i));
1699   }
1700   const int64 kernel_input_features =
1701       rhs.dimensions(dnums.kernel_input_feature_dimension());
1702   const int64 kernel_output_features =
1703       rhs.dimensions(dnums.kernel_output_feature_dimension());
1704 
1705   if (input_features != kernel_input_features) {
1706     return InvalidArgument(
1707         "Expected LHS feature dimension (value %lld) to match RHS "
1708         "input feature dimension (value %lld); got <conv>(%s, %s)\n"
1709         "Dimension numbers: {%s}",
1710         input_features, kernel_input_features,
1711         ShapeUtil::HumanString(lhs).c_str(),
1712         ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str());
1713   }
1714   std::vector<int64> window_dims(num_spatial_dims);
1715   for (int i = 0; i < num_spatial_dims; ++i) {
1716     window_dims[i] = window.dimensions(i).size();
1717   }
1718   if (kernel_spatial_dims != window_dims) {
1719     return InvalidArgument(
1720         "Window dimensions do not match RHS shape:\n\t"
1721         "RHS shape: %s\n\t"
1722         "Window: {%s}\n\t"
1723         "Dimension numbers: {%s}",
1724         ShapeUtil::HumanString(rhs).c_str(), window.ShortDebugString().c_str(),
1725         dnums.ShortDebugString().c_str());
1726   }
1727 
1728   Shape base_shape =
1729       ShapeUtil::MakeShape(lhs.element_type(), input_spatial_dims);
1730   TF_ASSIGN_OR_RETURN(
1731       Shape window_output_shape,
1732       InferWindowOutputShape(base_shape, window, lhs.element_type(),
1733                              /*allow_negative_padding=*/true));
1734 
1735   std::vector<int64> dimensions(num_dims);
1736   dimensions[dnums.output_batch_dimension()] = input_batch;
1737   dimensions[dnums.output_feature_dimension()] = kernel_output_features;
1738   for (int i = 0; i < num_spatial_dims; ++i) {
1739     dimensions[dnums.output_spatial_dimensions(i)] =
1740         window_output_shape.dimensions(i);
1741   }
1742   return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
1743                               dimensions);
1744 }
1745 
InferFftShape(const Shape & in,const FftType fft_type,const tensorflow::gtl::ArraySlice<int64> fft_length)1746 /* static */ StatusOr<Shape> ShapeInference::InferFftShape(
1747     const Shape& in, const FftType fft_type,
1748     const tensorflow::gtl::ArraySlice<int64> fft_length) {
1749   const int64 fft_rank = fft_length.size();
1750   if (fft_rank < 1 || fft_rank > 3) {
1751     return InvalidArgument("FFT only supports ranks 1-3, but got %lld",
1752                            fft_rank);
1753   }
1754 #define RET_CHECK_RANK(x)                              \
1755   if (x.dimensions_size() < fft_rank) {                \
1756     return InvalidArgument(                            \
1757         "FFT of rank %lld requires input of at least " \
1758         "same rank; got input of rank %d",             \
1759         fft_rank, x.dimensions_size());                \
1760   }
1761   switch (fft_type) {
1762     case FFT:
1763     case IFFT:
1764       if (in.element_type() != C64) {
1765         return InvalidArgument("%s requires C64 input type, found %s",
1766                                FftType_Name(fft_type).c_str(),
1767                                PrimitiveType_Name(in.element_type()).c_str());
1768       }
1769       RET_CHECK_RANK(in);
1770       return in;
1771     case RFFT: {
1772       if (in.element_type() != F32) {
1773         return InvalidArgument("RFFT requires F32 input type, found %s",
1774                                PrimitiveType_Name(in.element_type()).c_str());
1775       }
1776       RET_CHECK_RANK(in);
1777       for (int i = 0; i < fft_rank; i++) {
1778         if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1779             fft_length[i]) {
1780           return InvalidArgument(
1781               "RFFT requires innermost dimensions match fft_length but "
1782               "dimension %lld is %lld and should be %lld",
1783               in.dimensions_size() - fft_rank + i,
1784               in.dimensions(in.dimensions_size() - fft_rank + i),
1785               fft_length[i]);
1786         }
1787       }
1788       Shape result = ShapeUtil::ChangeElementType(in, C64);
1789       result.set_dimensions(result.dimensions_size() - 1,
1790                             fft_length[fft_rank - 1] / 2 + 1);
1791       return result;
1792     }
1793     case IRFFT: {
1794       if (in.element_type() != C64) {
1795         return InvalidArgument("IRFFT requires C64 input type, found %s",
1796                                PrimitiveType_Name(in.element_type()).c_str());
1797       }
1798       RET_CHECK_RANK(in);
1799       Shape result = ShapeUtil::ComplexComponentShape(in);
1800       for (int i = 0; i < fft_rank - 1; i++) {
1801         if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1802             fft_length[i]) {
1803           return InvalidArgument(
1804               "IRFFT requires all but one innermost dimensions match "
1805               "fft_length, but dimension %lld is %lld and should be %lld",
1806               in.dimensions_size() - fft_rank + i,
1807               in.dimensions(in.dimensions_size() - fft_rank + i),
1808               fft_length[i]);
1809         }
1810       }
1811       if (in.dimensions(in.dimensions_size() - 1) !=
1812           fft_length[fft_rank - 1] / 2 + 1) {
1813         return InvalidArgument(
1814             "IRFFT requires innermost dimension matches fft_length/2+1, but "
1815             "dimension %d is %lld and should be %lld",
1816             in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1),
1817             fft_length[fft_rank - 1] / 2 + 1);
1818       }
1819       result.set_dimensions(result.dimensions_size() - 1,
1820                             fft_length[fft_rank - 1]);
1821       return result;
1822     }
1823     default:
1824       LOG(FATAL) << "Unexpected fft_type: " << fft_type;
1825   }
1826 #undef RET_CHECK_RANK
1827 }
1828 
InferCrossReplicaSumShape(tensorflow::gtl::ArraySlice<const Shape * > operand_shapes)1829 /* static */ StatusOr<Shape> ShapeInference::InferCrossReplicaSumShape(
1830     tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
1831   for (const Shape* operand_shape : operand_shapes) {
1832     TF_RETURN_IF_ERROR(
1833         ExpectNotTupleOrOpaque(*operand_shape, "operand of cross replica sum"));
1834   }
1835   if (operand_shapes.size() == 1) {
1836     return *operand_shapes[0];
1837   }
1838   std::vector<Shape> operand_shape_values;
1839   for (const Shape* operand_shape : operand_shapes) {
1840     operand_shape_values.push_back(*operand_shape);
1841   }
1842   return ShapeUtil::MakeTupleShape(operand_shape_values);
1843 }
1844 
InferReduceShape(const Shape & arg,const Shape & init_value,tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,const ProgramShape & to_apply)1845 /* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
1846     const Shape& arg, const Shape& init_value,
1847     tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
1848     const ProgramShape& to_apply) {
1849   // Check that the dimension to reduce are in-bounds for the given shape.
1850   for (int64 dimension : dimensions_to_reduce) {
1851     if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) {
1852       return InvalidArgument(
1853           "attempting to reduce out-of-bounds dimension %lld in shape %s",
1854           dimension, ShapeUtil::HumanString(arg).c_str());
1855     }
1856   }
1857   TF_RETURN_IF_ERROR(
1858       VerifyReducerShape(to_apply, init_value, arg.element_type()));
1859 
1860   std::set<int64> dimensions_to_reduce_set(dimensions_to_reduce.begin(),
1861                                            dimensions_to_reduce.end());
1862   std::vector<int64> new_dimensions;
1863   for (int i = 0; i < ShapeUtil::Rank(arg); ++i) {
1864     if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) {
1865       new_dimensions.push_back(arg.dimensions(i));
1866     }
1867   }
1868 
1869   return ShapeUtil::MakeShape(to_apply.result().element_type(), new_dimensions);
1870 }
1871 
InferReduceWindowShape(const Shape & operand_shape,const Shape & init_value_shape,const Window & window,const ProgramShape & to_apply_shape)1872 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
1873     const Shape& operand_shape, const Shape& init_value_shape,
1874     const Window& window, const ProgramShape& to_apply_shape) {
1875   TF_RETURN_IF_ERROR(
1876       ExpectNotTupleOrOpaque(operand_shape, "operand of reduce-window"));
1877   TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape,
1878                                         operand_shape.element_type()));
1879   return InferWindowOutputShape(operand_shape, window,
1880                                 init_value_shape.element_type(),
1881                                 /*allow_negative_padding=*/false);
1882 }
1883 
InferSelectAndScatterShape(const Shape & operand_shape,const ProgramShape & select_shape,const Window & window,const Shape & source_shape,const Shape & init_value_shape,const ProgramShape & scatter_shape)1884 /* static */ StatusOr<Shape> ShapeInference::InferSelectAndScatterShape(
1885     const Shape& operand_shape, const ProgramShape& select_shape,
1886     const Window& window, const Shape& source_shape,
1887     const Shape& init_value_shape, const ProgramShape& scatter_shape) {
1888   TF_RETURN_IF_ERROR(
1889       ExpectNotTupleOrOpaque(operand_shape, "operand of select-and-scatter"));
1890 
1891   // Check if the select function has a proper shape of (T,T) -> PRED.
1892   if (select_shape.parameters_size() != 2) {
1893     return InvalidArgument(
1894         "select function must take 2 parameters, but "
1895         "takes %d parameter(s).",
1896         select_shape.parameters_size());
1897   }
1898   const Shape& select_result_shape = select_shape.result();
1899   if (!ShapeUtil::Compatible(select_result_shape,
1900                              ShapeUtil::MakeShape(PRED, {}))) {
1901     return Unimplemented("select function must have rank-0 PRED result.");
1902   }
1903   const Shape& operand_element_shape =
1904       ShapeUtil::MakeShape(operand_shape.element_type(), {});
1905   if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
1906                                                 select_shape.parameters(0))) {
1907     return InvalidArgument(
1908         "select function's first parameter shape currently must "
1909         "match the operand element shape. Got %s vs %s",
1910         ShapeUtil::HumanString(select_shape.parameters(0)).c_str(),
1911         ShapeUtil::HumanString(operand_element_shape).c_str());
1912   }
1913   if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
1914                                                 select_shape.parameters(1))) {
1915     return InvalidArgument(
1916         "select function's second parameter shape currently must "
1917         "match the operand element shape. Got %s vs %s",
1918         ShapeUtil::HumanString(select_shape.parameters(1)).c_str(),
1919         ShapeUtil::HumanString(operand_element_shape).c_str());
1920   }
1921 
1922   // Check if the scatter function has a proper shape as a reduction.
1923   TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, init_value_shape,
1924                                         source_shape.element_type()));
1925 
1926   // Check if the result shape of window operation matches the source shape.
1927   TF_ASSIGN_OR_RETURN(const Shape& window_result_shape,
1928                       InferWindowOutputShape(operand_shape, window,
1929                                              operand_shape.element_type(),
1930                                              /*allow_negative_padding=*/false));
1931   if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape,
1932                                                 window_result_shape)) {
1933     return InvalidArgument(
1934         "source shape does not match the shape of window-reduced operand: "
1935         "source(%s), window-reduced operand(%s)",
1936         ShapeUtil::HumanString(source_shape).c_str(),
1937         ShapeUtil::HumanString(window_result_shape).c_str());
1938   }
1939   return operand_shape;
1940 }
1941 
InferSliceShape(const Shape & arg,tensorflow::gtl::ArraySlice<int64> starts,tensorflow::gtl::ArraySlice<int64> limits,tensorflow::gtl::ArraySlice<int64> strides)1942 /* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
1943     const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
1944     tensorflow::gtl::ArraySlice<int64> limits,
1945     tensorflow::gtl::ArraySlice<int64> strides) {
1946   auto error = [&](const string& message) {
1947     return InvalidArgument(
1948         "%s in slice operation; argument shape: %s; starts: {%s}; limits: "
1949         "{%s}; strides: {%s}",
1950         message.c_str(), ShapeUtil::HumanString(arg).c_str(),
1951         Join(starts, ",").c_str(), Join(limits, ",").c_str(),
1952         Join(strides, ",").c_str());
1953   };
1954   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice"));
1955   VLOG(2) << tensorflow::strings::Printf(
1956       "slicing shape %s starts={%s} limits={%s}",
1957       ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(),
1958       Join(limits, ", ").c_str());
1959 
1960   if (starts.size() != limits.size()) {
1961     return error(Printf("slice start and limit sizes differ: %zu vs %zu",
1962                         starts.size(), limits.size()));
1963   }
1964 
1965   if (starts.size() != strides.size()) {
1966     return error(Printf("slice start and strides sizes differ: %zu vs %zu",
1967                         starts.size(), strides.size()));
1968   }
1969 
1970   if (starts.size() != ShapeUtil::Rank(arg)) {
1971     return InvalidArgument(
1972         "slice index count does not match argument rank: %zu vs %lld",
1973         starts.size(), ShapeUtil::Rank(arg));
1974   }
1975 
1976   std::vector<int64> sizes;
1977   for (int64 dimension = 0; dimension < starts.size(); ++dimension) {
1978     int64 start_index = starts[dimension];
1979     int64 limit_index = limits[dimension];
1980     int64 stride = strides[dimension];
1981     if (start_index < 0) {
1982       return InvalidArgument("negative start index to slice: %lld",
1983                              start_index);
1984     }
1985     if (limit_index > arg.dimensions(dimension)) {
1986       return error(
1987           Printf("limit index (%lld) must be less than or equal to dimension "
1988                  "size (%lld)",
1989                  limit_index, arg.dimensions(dimension)));
1990     }
1991     VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension,
1992                                            start_index);
1993     VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension,
1994                                            limit_index);
1995     if (start_index > limit_index) {
1996       return error(
1997           Printf("limit index (%lld) must be greater or equal to "
1998                  "start index (%lld) in slice with positive stride",
1999                  limit_index, start_index));
2000     }
2001     if (stride <= 0) {
2002       return InvalidArgument("stride (%lld) must be positive", stride);
2003     }
2004     sizes.push_back((limit_index - start_index + stride - 1) / stride);
2005   }
2006 
2007   return ShapeUtil::MakeShape(arg.element_type(), sizes);
2008 }
2009 
InferDynamicSliceShape(const Shape & operand_shape,const Shape & start_indices_shape,tensorflow::gtl::ArraySlice<int64> slice_sizes)2010 /* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
2011     const Shape& operand_shape, const Shape& start_indices_shape,
2012     tensorflow::gtl::ArraySlice<int64> slice_sizes) {
2013   TF_RETURN_IF_ERROR(
2014       ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic slice"));
2015   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(start_indices_shape,
2016                                             "start indices of dynamic slice"));
2017 
2018   VLOG(2) << tensorflow::strings::Printf(
2019       "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
2020       ShapeUtil::HumanString(operand_shape).c_str(),
2021       ShapeUtil::HumanString(start_indices_shape).c_str(),
2022       Join(slice_sizes, ", ").c_str());
2023 
2024   if (ShapeUtil::Rank(start_indices_shape) != 1) {
2025     return InvalidArgument(
2026         "dynamic slice start indices of rank %lld must be rank1.",
2027         ShapeUtil::Rank(start_indices_shape));
2028   }
2029 
2030   if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2031     return InvalidArgument(
2032         "dynamic slice start indices must be of integral type.");
2033   }
2034 
2035   const int64 start_num_dims = start_indices_shape.dimensions(0);
2036   if (ShapeUtil::Rank(operand_shape) != start_num_dims) {
2037     return InvalidArgument(
2038         "dynamic slice start number of dimensions %lld (%s) must match rank "
2039         "%lld of slice input (%s)",
2040         start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(),
2041         ShapeUtil::Rank(operand_shape),
2042         ShapeUtil::HumanString(operand_shape).c_str());
2043   }
2044 
2045   if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) {
2046     return InvalidArgument(
2047         "dynamic slice index count does not match argument rank: %zu vs %lld",
2048         slice_sizes.size(), ShapeUtil::Rank(operand_shape));
2049   }
2050 
2051   for (int64 dim = 0; dim < slice_sizes.size(); ++dim) {
2052     const int64 input_dim_size = operand_shape.dimensions(dim);
2053     const int64 slice_dim_size = slice_sizes[dim];
2054     if (slice_dim_size < 0) {
2055       return InvalidArgument("negative size index to dynamic slice: %lld",
2056                              slice_dim_size);
2057     }
2058     if (slice_dim_size > input_dim_size) {
2059       return InvalidArgument(
2060           "slice dim size %lld greater than dynamic slice dimension: %lld",
2061           slice_dim_size, input_dim_size);
2062     }
2063     VLOG(2) << tensorflow::strings::Printf("slice_sizes[%lld] = %lld", dim,
2064                                            slice_dim_size);
2065   }
2066 
2067   return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
2068 }
2069 
InferDynamicUpdateSliceShape(const Shape & operand_shape,const Shape & update_shape,const Shape & start_indices_shape)2070 /* static */ StatusOr<Shape> ShapeInference::InferDynamicUpdateSliceShape(
2071     const Shape& operand_shape, const Shape& update_shape,
2072     const Shape& start_indices_shape) {
2073   TF_RETURN_IF_ERROR(
2074       ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic update slice"));
2075   TF_RETURN_IF_ERROR(
2076       ExpectNotTupleOrOpaque(update_shape, "update of dynamic update slice"));
2077   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
2078       start_indices_shape, "start indices of dynamic update slice"));
2079 
2080   VLOG(2) << tensorflow::strings::Printf(
2081       "updating slice of shape %s at dynamic start_indices %s with update "
2082       "shape %s",
2083       ShapeUtil::HumanString(operand_shape).c_str(),
2084       ShapeUtil::HumanString(start_indices_shape).c_str(),
2085       ShapeUtil::HumanString(update_shape).c_str());
2086 
2087   if (ShapeUtil::Rank(start_indices_shape) != 1) {
2088     return InvalidArgument(
2089         "dynamic update slice start indices of rank %lld must be rank1.",
2090         ShapeUtil::Rank(start_indices_shape));
2091   }
2092 
2093   if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2094     return InvalidArgument(
2095         "dynamic update slice start indices must be of integral type.");
2096   }
2097 
2098   const int64 start_num_dims = start_indices_shape.dimensions(0);
2099   if (ShapeUtil::Rank(operand_shape) != start_num_dims) {
2100     return InvalidArgument(
2101         "dynamic slice start number of dimensions %lld (%s) must match rank "
2102         "%lld of slice input (%s)",
2103         start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(),
2104         ShapeUtil::Rank(operand_shape),
2105         ShapeUtil::HumanString(operand_shape).c_str());
2106   }
2107 
2108   if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) {
2109     return InvalidArgument(
2110         "dynamic update slice update rank does not match argument rank: "
2111         "%lld vs %lld",
2112         ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape));
2113   }
2114 
2115   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
2116                                                      update_shape)) {
2117     return InvalidArgument(
2118         "dynamic update slice update element type does not match argument. "
2119         "operand.element_type: %s vs update.element_type: %s",
2120         PrimitiveType_Name(operand_shape.element_type()).c_str(),
2121         PrimitiveType_Name(update_shape.element_type()).c_str());
2122   }
2123 
2124   for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) {
2125     const int64 input_dim_size = operand_shape.dimensions(dim);
2126     const int64 update_dim_size = update_shape.dimensions(dim);
2127     if (update_dim_size < 0) {
2128       return InvalidArgument(
2129           "size index %lld to dynamic update slice must be >= 0",
2130           update_dim_size);
2131     }
2132     if (update_dim_size > input_dim_size) {
2133       return InvalidArgument(
2134           "update dim size %lld greater than dynamic slice dimension: %lld",
2135           update_dim_size, input_dim_size);
2136     }
2137     VLOG(2) << tensorflow::strings::Printf("update_sizes[%lld] = %lld", dim,
2138                                            update_dim_size);
2139   }
2140 
2141   return operand_shape;
2142 }
2143 
InferReverseShape(const Shape & operand_shape,tensorflow::gtl::ArraySlice<int64> dimensions)2144 /*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
2145     const Shape& operand_shape, tensorflow::gtl::ArraySlice<int64> dimensions) {
2146   TF_RETURN_IF_ERROR(
2147       ExpectNotTupleOrOpaque(operand_shape, "operand of reverse"));
2148   if (!AllUnique(dimensions)) {
2149     return InvalidArgument("a dimension number is duplicated in reverse");
2150   }
2151   for (int64 dimension : dimensions) {
2152     if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) {
2153       return InvalidArgument(
2154           "one of the reverse dimensions (%lld) is out-of-bounds in shape %s",
2155           dimension, ShapeUtil::HumanString(operand_shape).c_str());
2156     }
2157   }
2158   return operand_shape;
2159 }
2160 
InferGetTupleElementShape(const Shape & arg,int64 index)2161 /* static */ StatusOr<Shape> ShapeInference::InferGetTupleElementShape(
2162     const Shape& arg, int64 index) {
2163   if (!ShapeUtil::IsTuple(arg)) {
2164     return InvalidArgument(
2165         "cannot infer shape: attempting to index into non-tuple: %s",
2166         ShapeUtil::HumanString(arg).c_str());
2167   }
2168 
2169   if (index >= arg.tuple_shapes_size()) {
2170     return InvalidArgument(
2171         "cannot infer shape: attempt to index out of tuple bounds: %lld "
2172         ">= %d in shape %s",
2173         index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg).c_str());
2174   }
2175 
2176   return arg.tuple_shapes(index);
2177 }
2178 
InferWhileShape(const ProgramShape & condition,const ProgramShape & body,const Shape & init)2179 /* static */ StatusOr<Shape> ShapeInference::InferWhileShape(
2180     const ProgramShape& condition, const ProgramShape& body,
2181     const Shape& init) {
2182   // Check the number of parameters for given computations.
2183   if (condition.parameters_size() != 1) {
2184     return InvalidArgument("condition must take 1 arguments; got %d",
2185                            condition.parameters_size());
2186   }
2187   if (body.parameters_size() != 1) {
2188     return InvalidArgument("body must take 1 arguments; got %d",
2189                            body.parameters_size());
2190   }
2191 
2192   auto shape_string = [&]() {
2193     return tensorflow::strings::Printf(
2194         "condition: %s; body: %s; init: %s",
2195         ShapeUtil::HumanString(condition).c_str(),
2196         ShapeUtil::HumanString(body).c_str(),
2197         ShapeUtil::HumanString(init).c_str());
2198   };
2199 
2200   // Check the shapes of computation parameters and return types.
2201   if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) {
2202     return InvalidArgument("condition must return a boolean; got %s",
2203                            shape_string().c_str());
2204   }
2205   if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) ||
2206       !ShapeUtil::Compatible(body.result(), body.parameters(0)) ||
2207       !ShapeUtil::Compatible(body.result(), init)) {
2208     return InvalidArgument(
2209         "the parameter of condition and body, the result of the body, and init "
2210         "must all have the same shape; got %s",
2211         shape_string().c_str());
2212   }
2213 
2214   return init;
2215 }
2216 
InferConditionalShape(const Shape & predicate,const Shape & true_operand,const Shape & false_operand,const ProgramShape & true_computation,const ProgramShape & false_computation)2217 /* static */ StatusOr<Shape> ShapeInference::InferConditionalShape(
2218     const Shape& predicate, const Shape& true_operand,
2219     const Shape& false_operand, const ProgramShape& true_computation,
2220     const ProgramShape& false_computation) {
2221   if (!ShapeUtil::ShapeIs(predicate, PRED, {})) {
2222     return InvalidArgument("predicate must be a boolean; got %s.",
2223                            ShapeUtil::HumanString(predicate).c_str());
2224   }
2225 
2226   if (true_computation.parameters_size() != 1) {
2227     return InvalidArgument("true_computation must take 1 argument; got %d.",
2228                            true_computation.parameters_size());
2229   }
2230   if (!ShapeUtil::Compatible(true_computation.parameters(0), true_operand)) {
2231     auto true_shape_string = [&]() {
2232       return tensorflow::strings::Printf(
2233           "true_operand: %s; true_computation: %s",
2234           ShapeUtil::HumanString(true_operand).c_str(),
2235           ShapeUtil::HumanString(true_computation).c_str());
2236     };
2237     return InvalidArgument(
2238         "true_operand must match the shape of the only parameter of "
2239         "true_computation: got %s.",
2240         true_shape_string().c_str());
2241   }
2242 
2243   if (false_computation.parameters_size() != 1) {
2244     return InvalidArgument("false_computation must take 1 argument; got %d.",
2245                            false_computation.parameters_size());
2246   }
2247   if (!ShapeUtil::Compatible(false_computation.parameters(0), false_operand)) {
2248     auto false_shape_string = [&]() {
2249       return tensorflow::strings::Printf(
2250           "false_operand: %s; false_computation: %s",
2251           ShapeUtil::HumanString(false_operand).c_str(),
2252           ShapeUtil::HumanString(false_computation).c_str());
2253     };
2254     return InvalidArgument(
2255         "false_operand must match the shape of the only parameter of "
2256         "false_computation: got %s.",
2257         false_shape_string().c_str());
2258   }
2259   if (!ShapeUtil::Compatible(true_computation.result(),
2260                              false_computation.result())) {
2261     auto shape_string = [&]() {
2262       return tensorflow::strings::Printf(
2263           "true_computation result: %s; false_computation result: %s.",
2264           ShapeUtil::HumanString(true_computation.result()).c_str(),
2265           ShapeUtil::HumanString(false_computation.result()).c_str());
2266     };
2267     return InvalidArgument(
2268         "the result of true_computation and false_computation must have the "
2269         "same shape: got %s.",
2270         shape_string().c_str());
2271   }
2272   return true_computation.result();
2273 }
2274 
InferBroadcastShape(const Shape & operand,tensorflow::gtl::ArraySlice<int64> broadcast_sizes)2275 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2276     const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
2277   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast"));
2278   for (int64 size : broadcast_sizes) {
2279     if (size < 0) {
2280       return InvalidArgument("Broadcast with negative dimension size %lld.",
2281                              size);
2282     }
2283   }
2284 
2285   std::vector<int64> dimensions(operand.dimensions_size() +
2286                                 broadcast_sizes.size());
2287   std::copy(broadcast_sizes.begin(), broadcast_sizes.end(), dimensions.begin());
2288   std::copy(operand.dimensions().begin(), operand.dimensions().end(),
2289             dimensions.begin() + broadcast_sizes.size());
2290   return ShapeUtil::MakeShape(operand.element_type(), dimensions);
2291 }
2292 
InferReshapeShape(const Shape & operand,tensorflow::gtl::ArraySlice<int64> dimensions,tensorflow::gtl::ArraySlice<int64> new_sizes)2293 /* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
2294     const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
2295     tensorflow::gtl::ArraySlice<int64> new_sizes) {
2296   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "reshape"));
2297 
2298   Shape inferred_shape =
2299       ShapeUtil::MakeShape(operand.element_type(), new_sizes);
2300   VLOG(3) << "Reshape inferred shape: "
2301           << ShapeUtil::HumanString(inferred_shape);
2302 
2303   if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
2304     return InvalidArgument(
2305         "reshape operation has mismatched element counts: from=%lld (%s) "
2306         "to=%lld (%s)",
2307         ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(),
2308         ShapeUtil::ElementsIn(inferred_shape),
2309         ShapeUtil::HumanString(inferred_shape).c_str());
2310   }
2311 
2312   std::vector<int64> indices(ShapeUtil::Rank(operand));
2313   std::iota(indices.begin(), indices.end(), 0);
2314   if (dimensions.size() != ShapeUtil::Rank(operand) ||
2315       !std::is_permutation(dimensions.begin(), dimensions.end(),
2316                            indices.begin())) {
2317     return InvalidArgument(
2318         "Reshape dimensions [%s] are not a permutation of the operand "
2319         "dimensions (operand shape is %s).",
2320         Join(dimensions, ",").c_str(), ShapeUtil::HumanString(operand).c_str());
2321   }
2322 
2323   return inferred_shape;
2324 }
2325 
InferTransposeShape(const Shape & operand,tensorflow::gtl::ArraySlice<int64> dimensions)2326 /* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
2327     const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
2328   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "transpose"));
2329 
2330   std::vector<int64> indices(ShapeUtil::Rank(operand));
2331   std::iota(indices.begin(), indices.end(), 0);
2332   if (dimensions.size() != ShapeUtil::Rank(operand) ||
2333       !std::is_permutation(dimensions.begin(), dimensions.end(),
2334                            indices.begin())) {
2335     return InvalidArgument(
2336         "Transpose dimensions not a permutation of the operand dimensions.");
2337   }
2338 
2339   // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
2340   // we need output[i]=input[dimensions[i]] which is
2341   // Permute(Inverse(dimensions),input).
2342   return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
2343 }
2344 
2345 // TODO(b/36794510): Make broadcast semantics more consistent, by supporting
2346 // "degenerate" cases, as with binary elementwise ops.
InferClampShape(const Shape & min,const Shape & operand,const Shape & max)2347 /* static */ StatusOr<Shape> ShapeInference::InferClampShape(
2348     const Shape& min, const Shape& operand, const Shape& max) {
2349   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min"));
2350   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand"));
2351   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max"));
2352   if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) ||
2353       !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) {
2354     return InvalidArgument("clamp op with different operand types: %s, %s, %s",
2355                            ShapeUtil::HumanString(min).c_str(),
2356                            ShapeUtil::HumanString(operand).c_str(),
2357                            ShapeUtil::HumanString(max).c_str());
2358   }
2359   if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
2360         ShapeUtil::IsScalar(min)) &&
2361        (ShapeUtil::CompatibleIgnoringFpPrecision(max, operand) ||
2362         ShapeUtil::IsScalar(max)))) {
2363     return operand;
2364   }
2365   if (ShapeUtil::IsScalar(operand)) {
2366     if (ShapeUtil::CompatibleIgnoringFpPrecision(min, max)) {
2367       return ShapeUtil::ChangeElementType(min, operand.element_type());
2368     } else if (ShapeUtil::IsScalar(min)) {
2369       return ShapeUtil::ChangeElementType(max, operand.element_type());
2370     } else if (ShapeUtil::IsScalar(max)) {
2371       return ShapeUtil::ChangeElementType(min, operand.element_type());
2372     }
2373   }
2374   return Unimplemented(
2375       "not yet implemented: %s, %s <clamp> %s", min.ShortDebugString().c_str(),
2376       max.ShortDebugString().c_str(), operand.ShortDebugString().c_str());
2377 }
2378 
2379 // TODO(b/36794510): Make broadcast semantics more consistent, by supporting
2380 // "degenerate" cases, as with binary elementwise ops, as well as scalar
2381 // broadcast from all operands, not just the predicate.
InferSelectShape(const Shape & pred,const Shape & on_true,const Shape & on_false)2382 /* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
2383     const Shape& pred, const Shape& on_true, const Shape& on_false) {
2384   bool compatible;
2385   if (ShapeUtil::IsTuple(on_true)) {
2386     // Select only defines the top-level buffer, so if it's a tuple, the two
2387     // input must match exactly.
2388     compatible = ShapeUtil::Compatible(on_true, on_false);
2389   } else {
2390     compatible = ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false);
2391   }
2392   if (!compatible) {
2393     return InvalidArgument(
2394         "operands to select must be the same shape; got %s and %s",
2395         ShapeUtil::HumanString(on_true).c_str(),
2396         ShapeUtil::HumanString(on_false).c_str());
2397   }
2398   if (pred.element_type() != PRED) {
2399     return InvalidArgument(
2400         "select's pred operand must have PRED element type; got %s",
2401         ShapeUtil::HumanString(pred).c_str());
2402   }
2403   if (ShapeUtil::SameDimensions(pred, on_true) || ShapeUtil::Rank(pred) == 0) {
2404     // By this stage we know that pred's element type is PRED. Therefore, this
2405     // check restricts pred to be a PRED scalar, or a PRED array with the same
2406     // dimensions as on_true and on_false.
2407     return ShapeUtil::ChangeElementType(
2408         on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
2409   } else {
2410     return Unimplemented(
2411         "select operation with non-scalar predicate with dimensionality "
2412         " different from the other operands: %s",
2413         ShapeUtil::HumanString(pred).c_str());
2414   }
2415 }
2416 
InferCallShape(tensorflow::gtl::ArraySlice<const Shape * > arg_shapes,const ProgramShape & to_apply)2417 /* static */ StatusOr<Shape> ShapeInference::InferCallShape(
2418     tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
2419     const ProgramShape& to_apply) {
2420   // The applied function's arity equals the number of arguments.
2421   if (arg_shapes.size() != to_apply.parameters_size()) {
2422     string computation_signature = ShapeUtil::HumanString(to_apply);
2423     string argument_shapes =
2424         Join(arg_shapes, ", ", [](string* out, const Shape* shape) {
2425           tensorflow::strings::StrAppend(out, ShapeUtil::HumanString(*shape));
2426         });
2427     return InvalidArgument(
2428         "Call applied function arity must match number of arguments; got: "
2429         "arity: %d, arguments: %zu; computation signature: %s; argument "
2430         "shapes: [%s]",
2431         to_apply.parameters_size(), arg_shapes.size(),
2432         computation_signature.c_str(), argument_shapes.c_str());
2433   }
2434 
2435   // All arguments must be compatible with the program shape.
2436   for (int i = 0; i < arg_shapes.size(); ++i) {
2437     const Shape& arg_shape = *arg_shapes[i];
2438     const Shape& param_shape = to_apply.parameters(i);
2439     if (!ShapeUtil::Compatible(arg_shape, param_shape)) {
2440       return InvalidArgument(
2441           "Call parameter must match argument; got parameter %d shape: %s, "
2442           "argument shape: %s",
2443           i, ShapeUtil::HumanString(param_shape).c_str(),
2444           ShapeUtil::HumanString(arg_shape).c_str());
2445     }
2446   }
2447 
2448   return to_apply.result();
2449 }
2450 
ValidateGatherDimensionNumbers(const Shape & input_shape,tensorflow::gtl::ArraySlice<int64> gather_indices_shape,const GatherDimensionNumbers & dim_numbers)2451 static Status ValidateGatherDimensionNumbers(
2452     const Shape& input_shape,
2453     tensorflow::gtl::ArraySlice<int64> gather_indices_shape,
2454     const GatherDimensionNumbers& dim_numbers) {
2455   if (!c_is_sorted(dim_numbers.output_window_dims())) {
2456     return InvalidArgument(
2457         "Output window dimensions in gather op must be ascending; got: %s",
2458         Join(dim_numbers.output_window_dims(), ", ").c_str());
2459   }
2460 
2461   if (c_adjacent_find(dim_numbers.output_window_dims()) !=
2462       dim_numbers.output_window_dims().end()) {
2463     return InvalidArgument(
2464         "Output window dimensions in gather op must not repeat; got: %s",
2465         Join(dim_numbers.output_window_dims(), ", ").c_str());
2466   }
2467 
2468   const int64 output_window_dim_count = dim_numbers.output_window_dims_size();
2469   const int64 output_shape_rank =
2470       output_window_dim_count + gather_indices_shape.size();
2471 
2472   for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) {
2473     int64 window_index = dim_numbers.output_window_dims(i);
2474     if (window_index < 0 || window_index >= output_shape_rank) {
2475       return InvalidArgument(
2476           "Window index %d in gather op is out of bounds; got %lld, but should "
2477           "have been in"
2478           "[0,%lld)",
2479           i, window_index, output_shape_rank);
2480     }
2481   }
2482 
2483   if (dim_numbers.gather_dims_to_operand_dims_size() !=
2484       gather_indices_shape.back()) {
2485     return InvalidArgument(
2486         "There must be exactly as many elements in gather_dims_to_operand_dims "
2487         "as there are elements in the last dimension of %%gather_indices; got: "
2488         "%d, expected %lld",
2489         dim_numbers.gather_dims_to_operand_dims_size(),
2490         gather_indices_shape.back());
2491   }
2492 
2493   for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) {
2494     int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i);
2495     if (gather_dim_to_input_dim < 0 ||
2496         gather_dim_to_input_dim >= input_shape.dimensions_size()) {
2497       return InvalidArgument(
2498           "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), "
2499           "got: %d->%lld",
2500           input_shape.dimensions_size(), i, gather_dim_to_input_dim);
2501     }
2502   }
2503 
2504   std::vector<int64> sorted_gather_dims_to_operand_dims(
2505       dim_numbers.gather_dims_to_operand_dims().begin(),
2506       dim_numbers.gather_dims_to_operand_dims().end());
2507 
2508   c_sort(sorted_gather_dims_to_operand_dims);
2509 
2510   if (c_adjacent_find(sorted_gather_dims_to_operand_dims) !=
2511       sorted_gather_dims_to_operand_dims.end()) {
2512     return InvalidArgument(
2513         "Repeated dimensions are not allowed in gather_dims_to_operand_dims; "
2514         "got: %s",
2515         Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str());
2516   }
2517 
2518   for (int64 elided_dim : dim_numbers.elided_window_dims()) {
2519     if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) {
2520       return InvalidArgument(
2521           "Invalid elided_window_dims set in gather op; valid range is [0, "
2522           "%d), got: %lld",
2523           input_shape.dimensions_size(), elided_dim);
2524     }
2525   }
2526 
2527   if (!c_is_sorted(dim_numbers.elided_window_dims())) {
2528     return InvalidArgument(
2529         "elided_window_dims in gather op must be sorted; got: %s",
2530         Join(dim_numbers.elided_window_dims(), ", ").c_str());
2531   }
2532 
2533   if (c_adjacent_find(dim_numbers.elided_window_dims()) !=
2534       dim_numbers.elided_window_dims().end()) {
2535     return InvalidArgument(
2536         "Repeated dimensions not allowed in elided_window_dims in gather op; "
2537         "got: %s",
2538         Join(dim_numbers.elided_window_dims(), ", ").c_str());
2539   }
2540 
2541   return Status::OK();
2542 }
2543 
InferGatherShape(const Shape & input_shape,const Shape & gather_indices_shape,const GatherDimensionNumbers & gather_dim_numbers,tensorflow::gtl::ArraySlice<int64> window_bounds)2544 /*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
2545     const Shape& input_shape, const Shape& gather_indices_shape,
2546     const GatherDimensionNumbers& gather_dim_numbers,
2547     tensorflow::gtl::ArraySlice<int64> window_bounds) {
2548   TF_RETURN_IF_ERROR(
2549       ExpectNotTupleOrOpaque(input_shape, "input tensor operand gather op"));
2550   TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
2551       gather_indices_shape, "gather indices operand of gather op"));
2552 
2553   if (gather_indices_shape.dimensions_size() < 1) {
2554     return InvalidArgument(
2555         "Gather indices parameter must at least of rank 1; got %s",
2556         ShapeUtil::HumanString(gather_indices_shape).c_str());
2557   }
2558 
2559   if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) {
2560     return InvalidArgument(
2561         "Gather indices parameter must be an integral tensor; got %s",
2562         ShapeUtil::HumanString(gather_indices_shape).c_str());
2563   }
2564 
2565   std::vector<int64> expanded_gather_indices_shape;
2566   // We implicitly reshape gather indices of shape P[N] to P[N,1].
2567   expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size());
2568   c_copy(gather_indices_shape.dimensions(),
2569          std::back_inserter(expanded_gather_indices_shape));
2570   if (expanded_gather_indices_shape.size() == 1) {
2571     expanded_gather_indices_shape.push_back(1);
2572   }
2573 
2574   TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
2575       input_shape, expanded_gather_indices_shape, gather_dim_numbers));
2576 
2577   if (window_bounds.size() != input_shape.dimensions_size()) {
2578     return InvalidArgument(
2579         "Gather op must have one window bound for every input dimension; got: "
2580         "len(window_bounds)=%lu, input_shape.rank=%d",
2581         window_bounds.size(), input_shape.dimensions_size());
2582   }
2583 
2584   if (window_bounds.size() !=
2585       gather_dim_numbers.output_window_dims_size() +
2586           gather_dim_numbers.elided_window_dims_size()) {
2587     return InvalidArgument(
2588         "All components of the window index in a gather op must either be a "
2589         "output window index or explicitly elided; got len(window_bounds)=%lu, "
2590         "output_window_bounds=%s, elided_window_bounds=%s",
2591         window_bounds.size(),
2592         Join(gather_dim_numbers.output_window_dims(), ",").c_str(),
2593         Join(gather_dim_numbers.elided_window_dims(), ",").c_str());
2594   }
2595 
2596   for (int i = 0; i < window_bounds.size(); i++) {
2597     int64 window_bound = window_bounds[i];
2598     int64 corresponding_input_bound = input_shape.dimensions(i);
2599     if (window_bound < 0 || window_bound > corresponding_input_bound) {
2600       return InvalidArgument(
2601           "Window bound at index %d in gather op is out of range, must be "
2602           "within "
2603           "[0, %lld), got %lld",
2604           i, corresponding_input_bound + 1, window_bound);
2605     }
2606   }
2607 
2608   for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) {
2609     if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) {
2610       return InvalidArgument(
2611           "Gather op can only elide window indices with bound 1, but bound is "
2612           "%lld for index %lld at position %d",
2613           window_bounds[gather_dim_numbers.elided_window_dims(i)],
2614           gather_dim_numbers.elided_window_dims(i), i);
2615     }
2616   }
2617 
2618   int64 result_rank = gather_dim_numbers.output_window_dims_size() +
2619                       (expanded_gather_indices_shape.size() - 1);
2620   int64 window_dims_seen = 0;
2621   int64 gather_dims_seen = 0;
2622   std::vector<int64> output_dim_bounds;
2623   output_dim_bounds.reserve(result_rank);
2624   for (int64 i = 0; i < result_rank; i++) {
2625     int64 current_bound;
2626     bool is_window_index =
2627         c_binary_search(gather_dim_numbers.output_window_dims(), i);
2628     if (is_window_index) {
2629       while (c_binary_search(gather_dim_numbers.elided_window_dims(),
2630                              window_dims_seen)) {
2631         window_dims_seen++;
2632       }
2633       current_bound = window_bounds[window_dims_seen++];
2634     } else {
2635       current_bound = expanded_gather_indices_shape[gather_dims_seen++];
2636     }
2637 
2638     output_dim_bounds.push_back(current_bound);
2639   }
2640 
2641   return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds);
2642 }
2643 
2644 }  // namespace xla
2645