1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17
18 #include <stddef.h>
19 #include <algorithm>
20 #include <numeric>
21 #include <set>
22 #include <string>
23
24 #include "tensorflow/compiler/xla/shape_util.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/compiler/xla/types.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/compiler/xla/window_util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/stringpiece.h"
32 #include "tensorflow/core/lib/gtl/flatset.h"
33 #include "tensorflow/core/lib/math/math_util.h"
34 #include "tensorflow/core/lib/strings/str_util.h"
35 #include "tensorflow/core/lib/strings/strcat.h"
36 #include "tensorflow/core/lib/strings/stringprintf.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/protobuf.h"
39
40 using tensorflow::str_util::Join;
41 using tensorflow::strings::Printf;
42
43 namespace xla {
44
45 namespace {
46
47 // Return the UnaryOperation proto enum value associated with the given HLO
48 // opcode.
OpcodeToUnaryOperation(HloOpcode opcode)49 UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) {
50 switch (opcode) {
51 case HloOpcode::kAbs:
52 return UNOP_ABS;
53 case HloOpcode::kCeil:
54 return UNOP_CEIL;
55 case HloOpcode::kCos:
56 return UNOP_COS;
57 case HloOpcode::kExp:
58 return UNOP_EXP;
59 case HloOpcode::kFloor:
60 return UNOP_FLOOR;
61 case HloOpcode::kImag:
62 return UNOP_IMAG;
63 case HloOpcode::kIsFinite:
64 return UNOP_IS_FINITE;
65 case HloOpcode::kLog:
66 return UNOP_LOG;
67 case HloOpcode::kNot:
68 return UNOP_NOT;
69 case HloOpcode::kNegate:
70 return UNOP_NEGATE;
71 case HloOpcode::kReal:
72 return UNOP_REAL;
73 case HloOpcode::kRoundNearestAfz:
74 return UNOP_ROUND_NEAREST_AFZ;
75 case HloOpcode::kSign:
76 return UNOP_SIGN;
77 case HloOpcode::kSin:
78 return UNOP_SIN;
79 case HloOpcode::kSort:
80 return UNOP_SORT;
81 case HloOpcode::kTanh:
82 return UNOP_TANH;
83 default:
84 LOG(FATAL) << "Unhandled opcode for conversion to unary operation: "
85 << opcode;
86 }
87 }
88
89 // Return the BinaryOperation proto enum value associated with the given HLO
90 // opcode.
OpcodeToBinaryOperation(HloOpcode opcode)91 BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) {
92 switch (opcode) {
93 case HloOpcode::kAtan2:
94 return BINOP_ATAN2;
95 case HloOpcode::kComplex:
96 return BINOP_COMPLEX;
97 case HloOpcode::kMultiply:
98 return BINOP_MUL;
99 case HloOpcode::kAdd:
100 return BINOP_ADD;
101 case HloOpcode::kSubtract:
102 return BINOP_SUB;
103 case HloOpcode::kDivide:
104 return BINOP_DIV;
105 case HloOpcode::kEq:
106 return BINOP_EQ;
107 case HloOpcode::kGe:
108 return BINOP_GE;
109 case HloOpcode::kGt:
110 return BINOP_GT;
111 case HloOpcode::kLe:
112 return BINOP_LE;
113 case HloOpcode::kLt:
114 return BINOP_LT;
115 case HloOpcode::kNe:
116 return BINOP_NE;
117 case HloOpcode::kMaximum:
118 return BINOP_MAX;
119 case HloOpcode::kMinimum:
120 return BINOP_MIN;
121 case HloOpcode::kPower:
122 return BINOP_POW;
123 case HloOpcode::kRemainder:
124 return BINOP_REM;
125 case HloOpcode::kOr:
126 return BINOP_OR;
127 case HloOpcode::kAnd:
128 return BINOP_AND;
129 case HloOpcode::kShiftLeft:
130 return BINOP_SHIFT_LEFT;
131 case HloOpcode::kShiftRightArithmetic:
132 return BINOP_SHIFT_RIGHT_ARITHMETIC;
133 case HloOpcode::kShiftRightLogical:
134 return BINOP_SHIFT_RIGHT_LOGICAL;
135 default:
136 LOG(FATAL) << "unhandled opcode " << opcode;
137 }
138 }
139
140 // Return the TernaryOperation proto enum value associated with the given HLO
141 // opcode.
OpcodeToTernaryOperation(HloOpcode opcode)142 TernaryOperation OpcodeToTernaryOperation(HloOpcode opcode) {
143 switch (opcode) {
144 case HloOpcode::kClamp:
145 return TRIOP_CLAMP;
146 case HloOpcode::kSelect:
147 return TRIOP_SELECT;
148 default:
149 LOG(FATAL) << "unhandled opcode " << opcode;
150 }
151 }
152
153 // Return the VariadicOperation proto enum value associated with the given HLO
154 // opcode.
OpcodeToVariadicOperation(HloOpcode opcode)155 VariadicOperation OpcodeToVariadicOperation(HloOpcode opcode) {
156 switch (opcode) {
157 case HloOpcode::kTuple:
158 return VAROP_TUPLE;
159 default:
160 LOG(FATAL) << "unhandled opcode " << opcode;
161 }
162 }
163
164 // Returns true if no element is present in slice more than once.
AllUnique(tensorflow::gtl::ArraySlice<int64> slice)165 bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
166 return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
167 }
168
ExpectNotTupleOrOpaque(const Shape & shape,tensorflow::StringPiece op_type)169 tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape,
170 tensorflow::StringPiece op_type) {
171 if (ShapeUtil::IsTuple(shape)) {
172 return InvalidArgument("Expected non-tuple argument for %s. Got: %s",
173 op_type.ToString().c_str(),
174 ShapeUtil::HumanString(shape).c_str());
175 } else if (ShapeUtil::IsOpaque(shape)) {
176 return InvalidArgument("Expected non-opaque argument for %s. Got: %s",
177 op_type.ToString().c_str(),
178 ShapeUtil::HumanString(shape).c_str());
179 } else {
180 return tensorflow::Status::OK();
181 }
182 }
183
VerifyReducerShape(const ProgramShape & reducer_shape,const Shape & init_value_shape,const PrimitiveType & input_element_type)184 tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape,
185 const Shape& init_value_shape,
186 const PrimitiveType& input_element_type) {
187 if (reducer_shape.parameters_size() != 2) {
188 return InvalidArgument(
189 "Reduction function must take 2 parameters, but "
190 "takes %d parameter(s).",
191 reducer_shape.parameters_size());
192 }
193
194 const Shape& accumulator_shape = reducer_shape.result();
195 if (ShapeUtil::Rank(accumulator_shape) != 0) {
196 return Unimplemented(
197 "Reduction function currently must have rank-0 result.");
198 }
199
200 // Check that the accumulator can be passed in as the first argument.
201 // Note: comparing here and below with Compatible since we don't care about
202 // layout in scalars - see b/26668201 for a longer-term vision.
203 if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(0))) {
204 return InvalidArgument(
205 "Reduction function's first parameter shape differs from the "
206 "result shape: %s vs %s",
207 ShapeUtil::HumanString(reducer_shape.parameters(0)).c_str(),
208 ShapeUtil::HumanString(accumulator_shape).c_str());
209 }
210
211 // Check that init_value's shape is suitable for reducer_shape.
212 if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
213 init_value_shape)) {
214 return InvalidArgument(
215 "Reduction function's accumulator shape differs from the "
216 "init_value shape: %s vs %s",
217 ShapeUtil::HumanString(accumulator_shape).c_str(),
218 ShapeUtil::HumanString(init_value_shape).c_str());
219 }
220
221 // Check that the inputs can be passed in as the second argument.
222 const Shape& input_element_shape =
223 ShapeUtil::MakeShape(input_element_type, {});
224 if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape,
225 reducer_shape.parameters(1))) {
226 return InvalidArgument(
227 "Reduction function's second parameter shape differs from the "
228 "input type element type: %s vs %s",
229 ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(),
230 ShapeUtil::HumanString(input_element_shape).c_str());
231 }
232
233 // Currently the accumulator and inputs must be the same type,
234 // though that restriction could be relaxed.
235 if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
236 reducer_shape.parameters(1))) {
237 return InvalidArgument(
238 "Reduction function's second parameter shape currently must "
239 "match the result shape. Got %s vs %s",
240 ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(),
241 ShapeUtil::HumanString(accumulator_shape).c_str());
242 }
243
244 return tensorflow::Status::OK();
245 }
246
InferWindowOutputShape(const Shape & base_shape,const Window & window,PrimitiveType element_type,bool allow_negative_padding)247 StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
248 const Window& window,
249 PrimitiveType element_type,
250 bool allow_negative_padding) {
251 if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) {
252 return InvalidArgument(
253 "Window has dimension %d but base shape has dimension %lld.",
254 window.dimensions_size(), ShapeUtil::Rank(base_shape));
255 }
256
257 std::vector<int64> output_dimensions(window.dimensions_size());
258 for (int64 i = 0; i < window.dimensions_size(); ++i) {
259 const auto& dim = window.dimensions(i);
260 if (dim.size() <= 0) {
261 return InvalidArgument("Window has a non-positive dimension. Window: %s",
262 window.DebugString().c_str());
263 }
264 if (dim.stride() <= 0) {
265 return InvalidArgument("Window has a non-positive stride. Window: %s",
266 window.DebugString().c_str());
267 }
268 if (!allow_negative_padding && dim.padding_low() < 0) {
269 return InvalidArgument("Window has a negative low padding. Window: %s",
270 window.DebugString().c_str());
271 }
272 if (!allow_negative_padding && dim.padding_high() < 0) {
273 return InvalidArgument("Window has a negative high padding. Window: %s",
274 window.DebugString().c_str());
275 }
276 if (dim.base_dilation() < 1) {
277 return InvalidArgument(
278 "Window has a non-positive base area dilation factor. Window: %s",
279 window.DebugString().c_str());
280 }
281 if (dim.window_dilation() < 1) {
282 return InvalidArgument(
283 "Window has a non-positive window dilation factor. Window: %s",
284 window.DebugString().c_str());
285 }
286
287 const int64 dilated_base = window_util::DilatedBound(
288 ShapeUtil::GetDimension(base_shape, i), dim.base_dilation());
289 const int64 padded_dilated_base =
290 dim.padding_low() + dilated_base + dim.padding_high();
291 const int64 dilated_window =
292 window_util::DilatedBound(dim.size(), dim.window_dilation());
293
294 output_dimensions[i] = window_util::StridedBound(
295 padded_dilated_base, dilated_window, dim.stride());
296 }
297
298 return ShapeUtil::MakeShape(element_type, output_dimensions);
299 }
300
301 } // namespace
302
InferUnaryOpShape(HloOpcode opcode,const HloInstruction * operand)303 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
304 HloOpcode opcode, const HloInstruction* operand) {
305 // There is no copy operation at the proto level, so handle copy explicitly.
306 if (opcode == HloOpcode::kCopy) {
307 return operand->shape();
308 }
309
310 return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), operand->shape());
311 }
312
InferUnaryOpShape(UnaryOperation operation,const Shape & arg)313 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
314 UnaryOperation operation, const Shape& arg) {
315 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of unary operation"));
316
317 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(arg));
318 switch (operation) {
319 case UNOP_FLOOR:
320 case UNOP_CEIL:
321 if (!ShapeUtil::ElementIsFloating(arg)) {
322 return InvalidArgument(
323 "expected element type in shape to be floating for floor/ceil "
324 "operation; got %s",
325 PrimitiveType_Name(arg.element_type()).c_str());
326 }
327 return arg;
328 case UNOP_COS:
329 case UNOP_SIN:
330 case UNOP_EXP:
331 case UNOP_LOG:
332 case UNOP_TANH:
333 if (!ShapeUtil::ElementIsFloating(arg) &&
334 !ShapeUtil::ElementIsComplex(arg)) {
335 return InvalidArgument(
336 "expected element type in shape to be floating or complex for "
337 "sin/cos/exp/log/tanh operation; got %s",
338 PrimitiveType_Name(arg.element_type()).c_str());
339 }
340 return arg;
341 case UNOP_REAL:
342 case UNOP_IMAG:
343 if (!ShapeUtil::ElementIsComplex(arg)) {
344 return InvalidArgument(
345 "expected element type in shape to be complex for real/imag "
346 "operation; got %s",
347 PrimitiveType_Name(arg.element_type()).c_str());
348 }
349 return ShapeUtil::ChangeElementType(arg, F32);
350 case UNOP_ABS:
351 if (ShapeUtil::ElementIsComplex(arg)) {
352 return ShapeUtil::ChangeElementType(
353 arg, primitive_util::ComplexComponentType(arg.element_type()));
354 }
355 return arg;
356 case UNOP_NEGATE:
357 case UNOP_ROUND_NEAREST_AFZ:
358 case UNOP_SIGN:
359 case UNOP_SORT:
360 return arg;
361
362 case UNOP_NOT:
363 if (arg.element_type() != PRED &&
364 !primitive_util::IsIntegralType(arg.element_type())) {
365 return InvalidArgument(
366 "expected pred or an integral element type in argument to not "
367 "operation; got %s",
368 PrimitiveType_Name(arg.element_type()).c_str());
369 }
370 return arg;
371
372 case UNOP_IS_FINITE:
373 if (!ShapeUtil::ElementIsFloating(arg)) {
374 return InvalidArgument(
375 "expected element type in shape to be floating point for IsFinite "
376 "operation; got %s",
377 PrimitiveType_Name(arg.element_type()).c_str());
378 }
379 return ShapeUtil::ChangeElementType(arg, PRED);
380
381 default:
382 return InvalidArgument(
383 "Unknown operation for unary shape inference: \"%s\".",
384 UnaryOperation_Name(operation).c_str());
385 }
386 }
387
InferConcatOpShape(tensorflow::gtl::ArraySlice<const Shape * > arg_shapes,const int64 dimension)388 /* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
389 tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
390 const int64 dimension) {
391 if (arg_shapes.empty()) {
392 return InvalidArgument("Concatenate expects at least one argument");
393 }
394 if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) {
395 return InvalidArgument("dimension to concatenate along out of bounds: %lld",
396 dimension);
397 }
398 const Shape* arg_shape = nullptr;
399 PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
400 for (const Shape* shape : arg_shapes) {
401 TF_RETURN_IF_ERROR(
402 ExpectNotTupleOrOpaque(*shape, "operand of concatenation"));
403 if (!arg_shape) {
404 arg_shape = shape;
405 element_type = arg_shape->element_type();
406 continue;
407 }
408 if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
409 return InvalidArgument(
410 "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld "
411 "(%s)",
412 ShapeUtil::Rank(*arg_shape),
413 ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape),
414 ShapeUtil::HumanString(*shape).c_str());
415 }
416 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
417 return InvalidArgument(
418 "cannot concatenate arrays with different element types: %s vs %s",
419 PrimitiveType_Name(arg_shape->element_type()).c_str(),
420 PrimitiveType_Name(shape->element_type()).c_str());
421 }
422 for (int64 dimension_number = 0;
423 dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) {
424 if (arg_shape->dimensions(dimension_number) !=
425 shape->dimensions(dimension_number)) {
426 if (dimension_number == dimension) {
427 continue; // It's okay to differ in the dimension we're
428 // concatenating.
429 }
430 return InvalidArgument(
431 "cannot concatenate arrays that differ in dimensions other than "
432 "the one being concatenated (the other array dimensions must be "
433 "the same): %s vs %s in dimension %lld",
434 ShapeUtil::HumanString(*arg_shape).c_str(),
435 ShapeUtil::HumanString(*shape).c_str(), dimension);
436 }
437 }
438 element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
439 }
440
441 std::vector<int64> new_dimensions(arg_shape->dimensions().begin(),
442 arg_shape->dimensions().end());
443 for (size_t i = 1; i < arg_shapes.size(); ++i) {
444 new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
445 }
446 return ShapeUtil::MakeShape(element_type, new_dimensions);
447 }
448
InferConvertShape(const Shape & operand_shape,PrimitiveType new_element_type)449 /* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
450 const Shape& operand_shape, PrimitiveType new_element_type) {
451 auto old_element_type = operand_shape.element_type();
452 if (primitive_util::IsComplexType(old_element_type) &&
453 !primitive_util::IsComplexType(new_element_type)) {
454 return Unimplemented(
455 "Unsupported conversion from complex to real type: %s => %s",
456 ShapeUtil::HumanString(operand_shape).c_str(),
457 PrimitiveType_Name(new_element_type).c_str());
458 }
459 if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
460 // Note: we may want to support tuple conversions via this operation in the
461 // future, by recursing into the tuple elements to check all sub-conversions
462 // are valid. For now we just reject them, though.
463 return InvalidArgument(
464 "cannot convert from or to tuple type; requested conversion: %s => %s",
465 ShapeUtil::HumanString(operand_shape).c_str(),
466 PrimitiveType_Name(new_element_type).c_str());
467 }
468
469 return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
470 }
471
InferBitcastConvertShape(const Shape & operand_shape,PrimitiveType new_element_type)472 /* static */ StatusOr<Shape> ShapeInference::InferBitcastConvertShape(
473 const Shape& operand_shape, PrimitiveType new_element_type) {
474 auto old_element_type = operand_shape.element_type();
475 if (primitive_util::IsComplexType(old_element_type) !=
476 primitive_util::IsComplexType(new_element_type)) {
477 return Unimplemented(
478 "Unsupported conversion between real and complex types: %s => %s",
479 ShapeUtil::HumanString(operand_shape).c_str(),
480 PrimitiveType_Name(new_element_type).c_str());
481 }
482 if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
483 // Note: we may want to support tuple conversions via this operation in the
484 // future, by recursing into the tuple elements to check all sub-conversions
485 // are valid. For now we just reject them, though.
486 return InvalidArgument(
487 "cannot convert from or to tuple type; requested conversion: %s => %s",
488 ShapeUtil::HumanString(operand_shape).c_str(),
489 PrimitiveType_Name(new_element_type).c_str());
490 }
491 if (primitive_util::BitWidth(old_element_type) !=
492 primitive_util::BitWidth(new_element_type)) {
493 return InvalidArgument(
494 "cannot bitcast types with different bit-widths: %s => %s",
495 PrimitiveType_Name(old_element_type).c_str(),
496 PrimitiveType_Name(new_element_type).c_str());
497 }
498
499 return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
500 }
501
InferReducePrecisionShape(const Shape & operand_shape,const int exponent_bits,const int mantissa_bits)502 /* static */ StatusOr<Shape> ShapeInference::InferReducePrecisionShape(
503 const Shape& operand_shape, const int exponent_bits,
504 const int mantissa_bits) {
505 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
506 return InvalidArgument(
507 "expected element type in shape to be floating point for "
508 "ReducePrecision operation; got %s",
509 PrimitiveType_Name(operand_shape.element_type()).c_str());
510 }
511 if (exponent_bits < 1) {
512 // One exponent bit is necessary to distinguish 0 from infinity. Having
513 // no exponent bits doesn't produce a sensible number, so we require at
514 // least one.
515 return InvalidArgument("expected exponent_bits >= 1; got %d",
516 exponent_bits);
517 }
518 if (mantissa_bits < 0) {
519 // A number with no mantissa bits is still meaningful, however.
520 return InvalidArgument("expected non-negative mantissa_bits; got %d",
521 mantissa_bits);
522 }
523 return operand_shape;
524 }
525
InferPadShape(const Shape & operand_shape,const Shape & padding_value_shape,const PaddingConfig & padding_config)526 /* static */ StatusOr<Shape> ShapeInference::InferPadShape(
527 const Shape& operand_shape, const Shape& padding_value_shape,
528 const PaddingConfig& padding_config) {
529 if (ShapeUtil::IsTuple(operand_shape)) {
530 return InvalidArgument(
531 "pad operation does not support tuple-shape operands");
532 }
533 if (!ShapeUtil::IsScalar(padding_value_shape)) {
534 return InvalidArgument(
535 "pad operation does not support non-scalar padding values");
536 }
537 if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) {
538 return InvalidArgument(
539 "The rank of the operand and the padding configuration do not match: "
540 "%s vs %s",
541 ShapeUtil::HumanString(operand_shape).c_str(),
542 padding_config.ShortDebugString().c_str());
543 }
544 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
545 padding_value_shape)) {
546 return InvalidArgument(
547 "the element types of the operands to pad do not match");
548 }
549 std::vector<int64> dimensions(ShapeUtil::Rank(operand_shape));
550 for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
551 dimensions[i] = operand_shape.dimensions(i) +
552 padding_config.dimensions(i).edge_padding_low() +
553 padding_config.dimensions(i).edge_padding_high() +
554 std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
555 padding_config.dimensions(i).interior_padding();
556 }
557 return ShapeUtil::MakeShape(
558 ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
559 dimensions);
560 }
561
562 // Current DotDimensionNumbers Requirements:
563 //
564 // Contracting Dimensions:
565 // *) Exactly one contracting dimension on both lhs and rhs.
566 // *) Contracting dimension size must be the same on both lhs and rhs.
567 // *) Contracting dimension numbers do not need to be the same (i.e. transposes
568 // are passed on to emitter implementations).
569 //
// Batch Dimensions:
// *) Same number of batch dimensions on both lhs and rhs.
// *) Same batch dimension numbers (and sizes) on both lhs and rhs.
// *) Batch dimension numbers must be ordered before contracting and
//    non-contracting/non-batch dimension numbers, and be monotonically
//    increasing; i.e. with k batch dimensions they are exactly 0, 1, ..., k-1.
//
// Non-Contracting-Non-Batch Dimensions:
// *) Each operand has either 0 (e.g. a vector) or 1 (e.g. a matrix) such
//    dimension, so lhs x rhs covers matrix-vector and matrix-matrix products.
578 //
579
580 namespace {
581
ValidateDotDimensionNumbers(const Shape & lhs,const Shape & rhs,const DotDimensionNumbers & dimension_numbers)582 Status ValidateDotDimensionNumbers(
583 const Shape& lhs, const Shape& rhs,
584 const DotDimensionNumbers& dimension_numbers) {
585 // Check that dimension numbers are in range.
586 auto dims_in_range =
587 [](const int64 rank, tensorflow::gtl::ArraySlice<int64> contracting_dims,
588 tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
589 auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
590 return std::all_of(contracting_dims.begin(), contracting_dims.end(),
591 in_range) &&
592 std::all_of(batch_dims.begin(), batch_dims.end(), in_range);
593 };
594
595 tensorflow::gtl::ArraySlice<int64> lhs_contracting_dimensions =
596 AsInt64Slice(dimension_numbers.lhs_contracting_dimensions());
597 tensorflow::gtl::ArraySlice<int64> rhs_contracting_dimensions =
598 AsInt64Slice(dimension_numbers.rhs_contracting_dimensions());
599 tensorflow::gtl::ArraySlice<int64> lhs_batch_dimensions =
600 AsInt64Slice(dimension_numbers.lhs_batch_dimensions());
601 tensorflow::gtl::ArraySlice<int64> rhs_batch_dimensions =
602 AsInt64Slice(dimension_numbers.rhs_batch_dimensions());
603
604 if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions,
605 lhs_batch_dimensions) ||
606 !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions,
607 rhs_batch_dimensions)) {
608 return InvalidArgument("A dimension number is out of range in dot: %s",
609 dimension_numbers.DebugString().c_str());
610 }
611
612 // Check that dimension numbers are unique.
613 auto dims_unique = [](tensorflow::gtl::ArraySlice<int64> contracting_dims,
614 tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
615 tensorflow::gtl::FlatSet<int64> dim_set;
616 auto is_unique = [&dim_set](int64 i) -> bool {
617 return dim_set.insert(i).second;
618 };
619 return std::all_of(contracting_dims.begin(), contracting_dims.end(),
620 is_unique) &&
621 std::all_of(batch_dims.begin(), batch_dims.end(), is_unique);
622 };
623
624 if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
625 !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) {
626 return InvalidArgument("A dimension number is not unique in dot: %s",
627 dimension_numbers.DebugString().c_str());
628 }
629
630 // Check that the count of non-contracting-non-batch dimensions is in {0, 1}.
631 const int64 lhs_non_contracting_non_batch_dims =
632 ShapeUtil::Rank(lhs) -
633 dimension_numbers.lhs_contracting_dimensions_size() -
634 dimension_numbers.lhs_batch_dimensions_size();
635 const int64 rhs_non_contracting_non_batch_dims =
636 ShapeUtil::Rank(rhs) -
637 dimension_numbers.rhs_contracting_dimensions_size() -
638 dimension_numbers.rhs_batch_dimensions_size();
639 if (lhs_non_contracting_non_batch_dims < 0 ||
640 lhs_non_contracting_non_batch_dims > 1 ||
641 rhs_non_contracting_non_batch_dims < 0 ||
642 rhs_non_contracting_non_batch_dims > 1) {
643 return InvalidArgument(
644 "batch and contracting dimension number mismatch "
645 "with rank ");
646 }
647
648 // Check that batch dimension numbers are ordered before all others, and
649 // that they are monotonically increasing.
650 std::vector<int64> batch_dim_numbers(lhs_batch_dimensions.size());
651 std::iota(batch_dim_numbers.begin(), batch_dim_numbers.end(), 0);
652 if (!std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(),
653 lhs_batch_dimensions.begin()) ||
654 !std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(),
655 rhs_batch_dimensions.begin())) {
656 return InvalidArgument(
657 "batch dimension numbers must precede non-batch dimensions and be"
658 "monotonically increasing.");
659 }
660
661 return Status::OK();
662 }
663
664 } // namespace
665
InferDotOpShape(const Shape & lhs,const Shape & rhs,const DotDimensionNumbers & dimension_numbers)666 /* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
667 const Shape& lhs, const Shape& rhs,
668 const DotDimensionNumbers& dimension_numbers) {
669 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot"));
670 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot"));
671
672 auto fail = [lhs, rhs](const string& addendum) -> Status {
673 string message = tensorflow::strings::Printf(
674 "cannot infer shape for dot operation: %s <dot> %s",
675 ShapeUtil::HumanString(lhs).c_str(),
676 ShapeUtil::HumanString(rhs).c_str());
677 if (!addendum.empty()) {
678 message += ": " + addendum;
679 }
680 return InvalidArgument("%s", message.c_str());
681 };
682
683 // Check if both element types are the same.
684 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
685 return fail("element types do not match");
686 }
687
688 if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) {
689 return fail("dot only supports rank 1 or above.");
690 }
691
692 // Validate basic properties of dot dimension numbers.
693 TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers));
694
695 // Check that there is only one contracting dimension for both lhs and rhs.
696 if (dimension_numbers.lhs_contracting_dimensions_size() !=
697 dimension_numbers.rhs_contracting_dimensions_size() ||
698 dimension_numbers.lhs_contracting_dimensions_size() != 1) {
699 return fail("must specify one contracting dimension for both lhs and rhs.");
700 }
701
702 // Check that contracting dimension sizes match.
703 const int64 lhs_contracting_dimension =
704 dimension_numbers.lhs_contracting_dimensions(0);
705 const int64 rhs_contracting_dimension =
706 dimension_numbers.rhs_contracting_dimensions(0);
707 if (lhs.dimensions(lhs_contracting_dimension) !=
708 rhs.dimensions(rhs_contracting_dimension)) {
709 return fail("contracting dimension sizes do not match.");
710 }
711
712 // Check that number of batch dimensions match.
713 if (dimension_numbers.lhs_batch_dimensions_size() !=
714 dimension_numbers.rhs_batch_dimensions_size()) {
715 return fail("must the same number of batch dimensions for lhs and rhs.");
716 }
717
718 // Check that batch dimension numbers and sizes match.
719 for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
720 if (dimension_numbers.lhs_batch_dimensions(i) !=
721 dimension_numbers.rhs_batch_dimensions(i) ||
722 lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
723 rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
724 return fail("batch dimension numbers and sizes must match for lhs/rhs.");
725 }
726 }
727
728 // The ranks of lhs and rhs are decremented by 1 respectively due to the
729 // contraction, and added for the rank of the result. When an input tensor is
730 // a scalar, its contribution to the rank of the result is 0.
731 // Generate the result dimensions in order, rhs dimensions followed by lhs
732 // dimensions except the contracted and batch dimensions.
733 std::vector<int64> dimensions;
734 std::unordered_set<int64> rhs_batch_dims(
735 dimension_numbers.rhs_batch_dimensions().begin(),
736 dimension_numbers.rhs_batch_dimensions().end());
737 for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) {
738 if (i != lhs_contracting_dimension) {
739 dimensions.push_back(lhs.dimensions(i));
740 }
741 }
742 for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) {
743 if (i != rhs_contracting_dimension && rhs_batch_dims.count(i) == 0) {
744 dimensions.push_back(rhs.dimensions(i));
745 }
746 }
747 Shape result = ShapeUtil::MakeShape(
748 ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions);
749
750 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
751 VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
752 return result;
753 }
754
755 /* static */ StatusOr<Shape>
InferDegenerateDimensionBroadcastShape(BinaryOperation operation,const Shape & lhs,const Shape & rhs)756 ShapeInference::InferDegenerateDimensionBroadcastShape(
757 BinaryOperation operation, const Shape& lhs, const Shape& rhs) {
758 TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs));
759
760 // The shapes have to be compatible. That is, if some dimension d has a
761 // different size in the two shapes, one of them has to be 1 (a "degenerate"
762 // dimension). In that case, the output shape has the non-1 dimension size
763 // from the lhs/rhs pair in every index.
764 std::vector<int64> output_dimensions(ShapeUtil::Rank(lhs));
765 for (int64 i = 0; i < ShapeUtil::Rank(lhs); ++i) {
766 if (lhs.dimensions(i) == rhs.dimensions(i)) {
767 output_dimensions[i] = lhs.dimensions(i);
768 } else if (lhs.dimensions(i) == 1) {
769 output_dimensions[i] = rhs.dimensions(i);
770 } else if (rhs.dimensions(i) == 1) {
771 output_dimensions[i] = lhs.dimensions(i);
772 } else {
773 return InvalidArgument("binary op %s with incompatible shapes: %s and %s",
774 BinaryOperation_Name(operation).c_str(),
775 ShapeUtil::HumanString(lhs).c_str(),
776 ShapeUtil::HumanString(rhs).c_str());
777 }
778 }
779 return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
780 output_dimensions);
781 }
782
InferInDimBroadcastShape(BinaryOperation operation,const Shape & smaller_shape,const Shape & larger_shape,tensorflow::gtl::ArraySlice<int64> broadcast_dimensions)783 /* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
784 BinaryOperation operation, const Shape& smaller_shape,
785 const Shape& larger_shape,
786 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
787 if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
788 // Reject "magic" inference for binops on different shapes, requiring
789 // the user to provide an explicit broadcast dimension in this case.
790 // See b/25177275 for more details.
791 return InvalidArgument("automatic shape inference not supported: %s and %s",
792 ShapeUtil::HumanString(smaller_shape).c_str(),
793 ShapeUtil::HumanString(larger_shape).c_str());
794 } else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) {
795 return InvalidArgument(
796 "size of broadcast_dimensions has to match lower-rank operand's "
797 "rank; "
798 " lower-rank operand's rank is %lld, size of broadcast_dimensions is "
799 "%zu",
800 ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size());
801 }
802
803 // broadcast_dimensions is a sequence of dimensions; its length is equal to
804 // the rank of the lower-rank operand. The lower-rank operand's dimensions
805 // have to be compatible with the higher-rank operand's dimensions at indices
806 // specified by broadcast_dimensions. Here compatible means the dimension
807 // sizes are equal or in one of the shapes the dimension size is
808 // one. Examples:
809 //
810 // smaller_shape larger_shape broadcast_dimensions output_shape
811 // [] [2, 3] {} [2, 3]
812 // [3] [4, 3] {1} [4, 3]
813 // [2, 3] [2, 3, 4] {0, 1} [2, 3, 4]
814 // [2, 1] [2, 3, 4] {0, 2} [2, 3, 1]
815 // [2, 3] [2, 1, 4] {0, 1} [2, 3, 4]
816 //
817 // The column output_shape may not be the final shape of the XLA
818 // operation. After the "InDim" broadcasting implemented in this function
819 // expands the rank, degenerate-dimension broadcasting (implemented in
820 // InferDegenerateDimensionBroadcastShape) broadcasts dimensions of size one
821 // up to match the dimension size of the other operand. For example, consider
822 // the row in the table above with a smaller_shape of [2, 1]. The shape
823 // returned by this function is [2, 3, 1] (output_shape) however, the result
824 // shape of the XLA operation is [2, 3, 4] after degenerate-dimension
825 // broadcasting.
826 //
827 // Invalid broadcasts:
828 //
829 // smaller_shape=[3], larger_shape=[4, 3], broadcast_dimensions={0}
830 // Reason: Dimension zero** of larger_shape (size 4) is not compatible with
831 // dimension zero of smaller_shape(size 3). **Zero here comes from the value
832 // in broadcast_dimensions.
833 //
834 // smaller_shape=[2, 1], larger_shape=[2, 3, 4], broadcast_dimensions={1, 2}
835 // Reason: Dimension one of larger_shape (size 3) is not compatible with
836 // dimension zero of smaller_shape(size 2)
837
838 // The output shape is initially the larger_shape. Sizes of dimensions
839 // specified in broadcast_dimensions are then changed to match the
840 // corresponding dimension size in smaller_shape.
841 Shape output_shape(larger_shape);
842 output_shape.set_element_type(
843 ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape));
844
845 for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
846 int64 dimension_to_match = broadcast_dimensions.at(i);
847 if (dimension_to_match < 0) {
848 return InvalidArgument(
849 "broadcast dimension number (%lld) cannot be negative",
850 dimension_to_match);
851 }
852 if (dimension_to_match >= larger_shape.dimensions_size()) {
853 return InvalidArgument(
854 "broadcast dimension number (%lld) too large; higher-rank "
855 "operand has rank %d",
856 dimension_to_match, larger_shape.dimensions_size());
857 }
858 int64 small_dimension_size = smaller_shape.dimensions(i);
859 int64 large_dimension_size = larger_shape.dimensions(dimension_to_match);
860 // Dimension sizes must be compatible: match or be degenerate (degenerate
861 // case is handled by degenerate dimension broadcasting which occurs after
862 // InDim broadcasting).
863 if (small_dimension_size != large_dimension_size &&
864 small_dimension_size != 1 && large_dimension_size != 1) {
865 return InvalidArgument(
866 "broadcast dimension %d mismatch: %lld != %lld; %s and %s", i,
867 small_dimension_size, large_dimension_size,
868 ShapeUtil::HumanString(smaller_shape).c_str(),
869 ShapeUtil::HumanString(larger_shape).c_str());
870 }
871 // Make sure the broadcast dimensions are listed in a strictly increasing
872 // order.
873 if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) {
874 return InvalidArgument(
875 "broadcast dimensions order is wrong: %lld comes after %lld",
876 dimension_to_match, broadcast_dimensions.at(i - 1));
877 }
878
879 output_shape.set_dimensions(dimension_to_match, small_dimension_size);
880 }
881
882 return output_shape;
883 }
884
InferElementwiseBinaryOpShape(BinaryOperation operation,const Shape & lhs,const Shape & rhs,tensorflow::gtl::ArraySlice<int64> broadcast_dimensions)885 /* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
886 BinaryOperation operation, const Shape& lhs, const Shape& rhs,
887 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
888 TF_RETURN_IF_ERROR(
889 ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation"));
890 TF_RETURN_IF_ERROR(
891 ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation"));
892
893 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
894 return InvalidArgument(
895 "binary op %s with different element types: %s and %s",
896 BinaryOperation_Name(operation).c_str(),
897 ShapeUtil::HumanString(lhs).c_str(),
898 ShapeUtil::HumanString(rhs).c_str());
899 }
900
901 if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) {
902 std::vector<int64> identity_dims(ShapeUtil::Rank(lhs));
903 std::iota(identity_dims.begin(), identity_dims.end(), 0);
904 if (!broadcast_dimensions.empty() &&
905 broadcast_dimensions != identity_dims) {
906 return InvalidArgument(
907 "broadcast dimensions field must either be not set or be the "
908 "identity on binary operations with operands of the same rank");
909 }
910 }
911
912 if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
913 // If the shapes are the same other than layout, the output shape is the
914 // same (elementwise op).
915 return ShapeUtil::ChangeElementType(
916 lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
917 }
918
919 if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) {
920 return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
921 } else {
922 // Ranks do not match, so perform InDim broadcasting using
923 // broadcast_dimensions. Scalar broadcasting is a special case of this.
924 const Shape& larger_shape =
925 ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? lhs : rhs;
926 const Shape& smaller_shape =
927 ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs;
928
929 // After InDim broadcasting, perform degenerate dimensions broadcasting.
930 TF_ASSIGN_OR_RETURN(
931 Shape indim_broadcast_shape,
932 InferInDimBroadcastShape(operation, smaller_shape, larger_shape,
933 broadcast_dimensions));
934
935 return InferDegenerateDimensionBroadcastShape(
936 operation, indim_broadcast_shape, larger_shape);
937 }
938 }
939
InferBinaryOpShape(HloOpcode opcode,const HloInstruction * lhs,const HloInstruction * rhs)940 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
941 HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) {
942 return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs->shape(),
943 rhs->shape(), /*broadcast_dimensions=*/{});
944 }
945
InferBinaryOpShape(BinaryOperation operation,const Shape & lhs,const Shape & rhs,tensorflow::gtl::ArraySlice<int64> broadcast_dimensions)946 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
947 BinaryOperation operation, const Shape& lhs, const Shape& rhs,
948 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
949 VLOG(2) << tensorflow::strings::Printf(
950 "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
951 BinaryOperation_Name(operation).c_str(),
952 ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(),
953 Join(broadcast_dimensions, ", ").c_str());
954 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
955 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
956
957 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
958 lhs, tensorflow::strings::StrCat("lhs of binary operation ",
959 BinaryOperation_Name(operation))));
960 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
961 rhs, tensorflow::strings::StrCat("rhs of binary operation ",
962 BinaryOperation_Name(operation))));
963 switch (operation) {
964 case BINOP_MAX:
965 case BINOP_MIN:
966 case BINOP_SUB:
967 case BINOP_ADD:
968 case BINOP_ATAN2:
969 case BINOP_POW:
970 case BINOP_DIV:
971 case BINOP_REM:
972 case BINOP_MUL:
973 case BINOP_SHIFT_LEFT:
974 case BINOP_SHIFT_RIGHT_ARITHMETIC:
975 case BINOP_SHIFT_RIGHT_LOGICAL:
976 return InferElementwiseBinaryOpShape(operation, lhs, rhs,
977 broadcast_dimensions);
978
979 case BINOP_COMPLEX: {
980 if (!ShapeUtil::ElementIsFloating(lhs)) {
981 return InvalidArgument(
982 "expected element type in shape to be floating for complex compose "
983 "operation; got %s",
984 PrimitiveType_Name(lhs.element_type()).c_str());
985 }
986 TF_ASSIGN_OR_RETURN(const Shape& shape,
987 InferElementwiseBinaryOpShape(operation, lhs, rhs,
988 broadcast_dimensions));
989 if (lhs.element_type() == F32 && rhs.element_type() == F32) {
990 return ShapeUtil::ChangeElementType(shape, C64);
991 } else {
992 return Unimplemented("complex component type not supported");
993 }
994 }
995 case BINOP_AND:
996 case BINOP_OR:
997 if (lhs.element_type() != PRED &&
998 !primitive_util::IsIntegralType(lhs.element_type())) {
999 return InvalidArgument(
1000 "expected pred or integral type in argument to and/or operation; "
1001 "got %s",
1002 PrimitiveType_Name(lhs.element_type()).c_str());
1003 }
1004 return InferElementwiseBinaryOpShape(operation, lhs, rhs,
1005 broadcast_dimensions);
1006 case BINOP_EQ:
1007 case BINOP_GE:
1008 case BINOP_GT:
1009 case BINOP_LE:
1010 case BINOP_LT:
1011 case BINOP_NE: {
1012 TF_ASSIGN_OR_RETURN(const Shape& shape,
1013 InferElementwiseBinaryOpShape(operation, lhs, rhs,
1014 broadcast_dimensions));
1015 return ShapeUtil::ChangeElementType(shape, PRED);
1016 }
1017 default:
1018 return Unimplemented(
1019 "not yet implemented; infer binary op shape: %s; lhs: %s; rhs: %s",
1020 BinaryOperation_Name(operation).c_str(),
1021 lhs.ShortDebugString().c_str(), rhs.ShortDebugString().c_str());
1022 }
1023 }
1024
InferTernaryOpShape(HloOpcode opcode,const HloInstruction * lhs,const HloInstruction * rhs,const HloInstruction * ehs)1025 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1026 HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs,
1027 const HloInstruction* ehs) {
1028 return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs->shape(),
1029 rhs->shape(), ehs->shape());
1030 }
1031
InferTernaryOpShape(TernaryOperation operation,const Shape & lhs,const Shape & rhs,const Shape & ehs)1032 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1033 TernaryOperation operation, const Shape& lhs, const Shape& rhs,
1034 const Shape& ehs) {
1035 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1036 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1037 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs));
1038 switch (operation) {
1039 case TRIOP_CLAMP:
1040 return InferClampShape(lhs, rhs, ehs);
1041 case TRIOP_SELECT:
1042 return InferSelectShape(lhs, rhs, ehs);
1043 default:
1044 return InvalidArgument("unknown operation %s",
1045 TernaryOperation_Name(operation).c_str());
1046 }
1047 }
1048
InferVariadicOpShape(HloOpcode opcode,tensorflow::gtl::ArraySlice<const HloInstruction * > operands)1049 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1050 HloOpcode opcode,
1051 tensorflow::gtl::ArraySlice<const HloInstruction*> operands) {
1052 std::vector<const Shape*> operand_shapes;
1053 for (const HloInstruction* operand : operands) {
1054 operand_shapes.push_back(&operand->shape());
1055 }
1056 return InferVariadicOpShape(OpcodeToVariadicOperation(opcode),
1057 operand_shapes);
1058 }
1059
InferVariadicOpShape(VariadicOperation operation,tensorflow::gtl::ArraySlice<const Shape * > operand_shapes)1060 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1061 VariadicOperation operation,
1062 tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
1063 for (const Shape* shape : operand_shapes) {
1064 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
1065 }
1066 switch (operation) {
1067 case VAROP_TUPLE: {
1068 Shape result = ShapeUtil::MakeTupleShape({});
1069 for (const Shape* shape : operand_shapes) {
1070 ShapeUtil::AppendShapeToTuple(*shape, &result);
1071 }
1072 return result;
1073 }
1074 default:
1075 return InvalidArgument("unknown operation %s",
1076 VariadicOperation_Name(operation).c_str());
1077 }
1078 }
1079
InferMapShape(tensorflow::gtl::ArraySlice<const Shape * > arg_shapes,const ProgramShape & to_apply,tensorflow::gtl::ArraySlice<int64> dimensions)1080 /* static */ StatusOr<Shape> ShapeInference::InferMapShape(
1081 tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
1082 const ProgramShape& to_apply,
1083 tensorflow::gtl::ArraySlice<int64> dimensions) {
1084 if (arg_shapes.empty()) {
1085 return InvalidArgument("Map expects at least one argument");
1086 }
1087
1088 // All arguments must have the same shape.
1089 const Shape* arg_shape = arg_shapes[0];
1090 for (size_t i = 1; i < arg_shapes.size(); ++i) {
1091 TF_RETURN_IF_ERROR(
1092 ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map"));
1093
1094 if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
1095 continue;
1096 }
1097 if (!ShapeUtil::IsTuple(*arg_shapes[i]) &&
1098 !ShapeUtil::IsTuple(*arg_shape) &&
1099 ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
1100 *arg_shape)) {
1101 if (ShapeUtil::IsScalar(*arg_shapes[i])) {
1102 continue;
1103 }
1104 if (ShapeUtil::IsScalar(*arg_shape)) {
1105 arg_shape = arg_shapes[i];
1106 continue;
1107 }
1108 }
1109
1110 std::vector<string> pieces;
1111 for (const Shape* shape : arg_shapes) {
1112 pieces.push_back(ShapeUtil::HumanString(*shape));
1113 }
1114 return InvalidArgument(
1115 "Map operation requires all operands to have the same shape; got: "
1116 "%s",
1117 Join(pieces, ", ").c_str());
1118 }
1119
1120 // Check that dimensions.size == arg_shape.dimensions_size() (we currently
1121 // only support mapping across all dimensions: i.e. scalar map functions).
1122 if (dimensions.size() != arg_shape->dimensions_size()) {
1123 return InvalidArgument(
1124 "Map applied to a subset of dimensions currently not supported: "
1125 "arg_dimension_size: %d, requested_map_dimensions_size: %zu",
1126 arg_shape->dimensions_size(), dimensions.size());
1127 }
1128
1129 // Check that requested map dimensions numbers are monotonically increasing.
1130 for (int i = 0; i < dimensions.size(); ++i) {
1131 if (dimensions[i] != i) {
1132 return InvalidArgument(
1133 "Map requires monotonically increasing dimension numbers, found: %s ",
1134 Join(dimensions, ", ").c_str());
1135 }
1136 }
1137
1138 // The applied function's arity equals the number of arguments.
1139 if (arg_shapes.size() != to_apply.parameters_size()) {
1140 return InvalidArgument(
1141 "Map applied function arity must match number of arguments; got: "
1142 "arity: %d, arguments: %zu",
1143 to_apply.parameters_size(), arg_shapes.size());
1144 }
1145
1146 // The parameters should all be scalars, and the output too.
1147 const Shape& output_shape = to_apply.result();
1148 if (!ShapeUtil::IsScalar(output_shape)) {
1149 return InvalidArgument(
1150 "mapped computation's result has to be a scalar; "
1151 "got: %s",
1152 ShapeUtil::HumanString(output_shape).c_str());
1153 }
1154
1155 for (int i = 0; i < to_apply.parameters_size(); ++i) {
1156 const Shape& parameter_shape = to_apply.parameters(i);
1157
1158 if (!ShapeUtil::IsScalar(parameter_shape)) {
1159 return InvalidArgument(
1160 "mapped computation's parameter has to be a scalar; "
1161 "got parameter %d shape: %s",
1162 i, ShapeUtil::HumanString(parameter_shape).c_str());
1163 }
1164
1165 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
1166 *arg_shape)) {
1167 return InvalidArgument(
1168 "mapped computation's parameter type has to match argument element "
1169 "type; got parameter %d shape: %s, argument shape: %s",
1170 i, ShapeUtil::HumanString(parameter_shape).c_str(),
1171 ShapeUtil::HumanString(*arg_shape).c_str());
1172 }
1173 }
1174
1175 return ShapeUtil::MakeShape(output_shape.element_type(),
1176 AsInt64Slice(arg_shape->dimensions()));
1177 }
1178
InferBatchNormTrainingShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & offset_shape,int64 feature_index)1179 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormTrainingShape(
1180 const Shape& operand_shape, const Shape& scale_shape,
1181 const Shape& offset_shape, int64 feature_index) {
1182 TF_RETURN_IF_ERROR(
1183 ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm training"));
1184 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
1185 offset_shape, "offset input of batch norm training"));
1186 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
1187 scale_shape, "scale input of batch norm training"));
1188
1189 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
1190 tensorflow::Status::OK());
1191 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) ==
1192 tensorflow::Status::OK());
1193 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
1194 tensorflow::Status::OK());
1195
1196 if (feature_index >= ShapeUtil::Rank(operand_shape)) {
1197 return InvalidArgument(
1198 "Expected feature_index of batch-norm-training to be "
1199 "smaller than the rank of operand_shape; "
1200 "got feature_index %lld, and rank %lld",
1201 feature_index, ShapeUtil::Rank(operand_shape));
1202 }
1203
1204 if (feature_index < 0) {
1205 return InvalidArgument(
1206 "Expected feature_index of batch-norm-training to "
1207 "be a non-negative number, got %lld",
1208 feature_index);
1209 }
1210
1211 if (ShapeUtil::Rank(operand_shape) < 1) {
1212 return InvalidArgument(
1213 "Expected the rank of operand to "
1214 "batch-norm-training to be at least 1; got %lld",
1215 ShapeUtil::Rank(operand_shape));
1216 }
1217
1218 if (ShapeUtil::Rank(offset_shape) != 1) {
1219 return InvalidArgument(
1220 "Offset input of batch-norm-training must have"
1221 " rank 1, but has rank %lld.",
1222 ShapeUtil::Rank(offset_shape));
1223 }
1224
1225 if (ShapeUtil::Rank(scale_shape) != 1) {
1226 return InvalidArgument(
1227 "Scale input of batch-norm-training must have"
1228 " rank 1, but has rank %lld.",
1229 ShapeUtil::Rank(scale_shape));
1230 }
1231
1232 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1233 return InvalidArgument(
1234 "The operand to batch-norm-training must have a floating point "
1235 "element type, but the shape is %s",
1236 PrimitiveType_Name(operand_shape.element_type()).c_str());
1237 }
1238
1239 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1240 operand_shape)) {
1241 return InvalidArgument(
1242 "The inputs should have the same element type for batch-norm-training, "
1243 "but the shape of offset factor is %s "
1244 "and the shape of operand is %s",
1245 PrimitiveType_Name(offset_shape.element_type()).c_str(),
1246 PrimitiveType_Name(operand_shape.element_type()).c_str());
1247 }
1248
1249 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1250 operand_shape)) {
1251 return InvalidArgument(
1252 "The inputs should have the same element type for batch-norm-training, "
1253 "but the shape of scale factor is %s "
1254 "and the shape of operand is %s",
1255 PrimitiveType_Name(scale_shape.element_type()).c_str(),
1256 PrimitiveType_Name(operand_shape.element_type()).c_str());
1257 }
1258
1259 const int64 feature_count = operand_shape.dimensions(feature_index);
1260 Shape output_shape_for_mean_and_var =
1261 ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1262
1263 if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1264 return InvalidArgument(
1265 "The size of offset factor should be the same as feature count,"
1266 "but the size of offset factor is %lld "
1267 "and the feature count is %lld",
1268 ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1269 }
1270
1271 if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1272 return InvalidArgument(
1273 "The size of scale factor should be the same as feature count,"
1274 "but the size of scale factor is %lld "
1275 "and the feature count is %lld",
1276 ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1277 }
1278
1279 return ShapeUtil::MakeTupleShape({operand_shape,
1280 output_shape_for_mean_and_var,
1281 output_shape_for_mean_and_var});
1282 }
1283
InferBatchNormInferenceShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & offset_shape,const Shape & mean_shape,const Shape & variance_shape,int64 feature_index)1284 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
1285 const Shape& operand_shape, const Shape& scale_shape,
1286 const Shape& offset_shape, const Shape& mean_shape,
1287 const Shape& variance_shape, int64 feature_index) {
1288 TF_RETURN_IF_ERROR(
1289 ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference"));
1290 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
1291 offset_shape, "offset input of batch norm inference"));
1292 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
1293 scale_shape, "scale input of batch norm inference"));
1294
1295 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
1296 tensorflow::Status::OK());
1297 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) ==
1298 tensorflow::Status::OK());
1299 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
1300 tensorflow::Status::OK());
1301 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape) ==
1302 tensorflow::Status::OK());
1303 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) ==
1304 tensorflow::Status::OK());
1305
1306 if (feature_index >= ShapeUtil::Rank(operand_shape)) {
1307 return InvalidArgument(
1308 "Expected feature_index of batch-norm-inference to be "
1309 "smaller than the rank of operand_shape; "
1310 "got feature_index %lld, and rank %lld",
1311 feature_index, ShapeUtil::Rank(operand_shape));
1312 }
1313
1314 if (feature_index < 0) {
1315 return InvalidArgument(
1316 "Expected feature_index of batch-norm-inference to "
1317 "be a non-negative number, got %lld",
1318 feature_index);
1319 }
1320
1321 if (ShapeUtil::Rank(operand_shape) < 1) {
1322 return InvalidArgument(
1323 "Expected the rank of operand to "
1324 "batch-norm-inference to be at least 1; got %lld",
1325 ShapeUtil::Rank(operand_shape));
1326 }
1327
1328 if (ShapeUtil::Rank(offset_shape) != 1) {
1329 return InvalidArgument(
1330 "Offset input of batch-norm-inference must have"
1331 " rank 1, but has rank %lld.",
1332 ShapeUtil::Rank(offset_shape));
1333 }
1334
1335 if (ShapeUtil::Rank(scale_shape) != 1) {
1336 return InvalidArgument(
1337 "Scale input of batch-norm-inference must have"
1338 " rank 1, but has rank %lld.",
1339 ShapeUtil::Rank(scale_shape));
1340 }
1341
1342 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1343 return InvalidArgument(
1344 "The operand to batch-norm-inference must have a floating point "
1345 "element type, but the shape is %s",
1346 PrimitiveType_Name(operand_shape.element_type()).c_str());
1347 }
1348
1349 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1350 operand_shape)) {
1351 return InvalidArgument(
1352 "The inputs should have the same element type for "
1353 "batch-norm-inference, "
1354 "but the shape of offset factor is %s "
1355 "and the shape of operand is %s",
1356 PrimitiveType_Name(offset_shape.element_type()).c_str(),
1357 PrimitiveType_Name(operand_shape.element_type()).c_str());
1358 }
1359
1360 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1361 operand_shape)) {
1362 return InvalidArgument(
1363 "The inputs should have the same element type for "
1364 "batch-norm-inference, "
1365 "but the shape of scale factor is %s "
1366 "and the shape of operand is %s",
1367 PrimitiveType_Name(scale_shape.element_type()).c_str(),
1368 PrimitiveType_Name(operand_shape.element_type()).c_str());
1369 }
1370
1371 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1372 operand_shape)) {
1373 return InvalidArgument(
1374 "The inputs should have the same element type for "
1375 "batch-norm-inference, "
1376 "but the shape of mean is %s "
1377 "and the shape of operand is %s",
1378 PrimitiveType_Name(mean_shape.element_type()).c_str(),
1379 PrimitiveType_Name(operand_shape.element_type()).c_str());
1380 }
1381
1382 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
1383 operand_shape)) {
1384 return InvalidArgument(
1385 "The inputs should have the same element type for "
1386 "batch-norm-inference, "
1387 "but the shape of variance is %s "
1388 "and the shape of operand is %s",
1389 PrimitiveType_Name(mean_shape.element_type()).c_str(),
1390 PrimitiveType_Name(variance_shape.element_type()).c_str());
1391 }
1392
1393 const int64 feature_count = operand_shape.dimensions(feature_index);
1394 Shape output_shape_for_mean_and_var =
1395 ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1396
1397 if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1398 return InvalidArgument(
1399 "The size of offset factor should be the same as feature count,"
1400 "but the size of offset factor is %lld "
1401 "and the feature count is %lld",
1402 ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1403 }
1404
1405 if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1406 return InvalidArgument(
1407 "The size of scale factor should be the same as feature count,"
1408 "but the size of scale factor is %lld "
1409 "and the feature count is %lld",
1410 ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1411 }
1412
1413 if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1414 return InvalidArgument(
1415 "The size of mean should be the same as feature count,"
1416 "but the size of mean is %lld "
1417 "and the feature count is %lld",
1418 ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1419 }
1420
1421 if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
1422 return InvalidArgument(
1423 "The size of variance should be the same as feature count,"
1424 "but the size of variance is %lld "
1425 "and the feature count is %lld",
1426 ShapeUtil::GetDimension(variance_shape, 0), feature_count);
1427 }
1428
1429 return operand_shape;
1430 }
1431
InferBatchNormGradShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & mean_shape,const Shape & var_shape,const Shape & output_grad_shape,int64 feature_index)1432 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
1433 const Shape& operand_shape, const Shape& scale_shape,
1434 const Shape& mean_shape, const Shape& var_shape,
1435 const Shape& output_grad_shape, int64 feature_index) {
1436 TF_RETURN_IF_ERROR(
1437 ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm grad"));
1438 TF_RETURN_IF_ERROR(
1439 ExpectNotTupleOrOpaque(scale_shape, "scale input of batch norm grad"));
1440 TF_RETURN_IF_ERROR(
1441 ExpectNotTupleOrOpaque(mean_shape, "mean input of batch norm grad"));
1442 TF_RETURN_IF_ERROR(
1443 ExpectNotTupleOrOpaque(var_shape, "var input of batch norm grad"));
1444 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
1445 output_grad_shape, "output_grad input of batch norm grad"));
1446
1447 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1448 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1449 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1450 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(var_shape));
1451 TF_RETURN_IF_ERROR(
1452 ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape));
1453
1454 if (feature_index >= ShapeUtil::Rank(operand_shape)) {
1455 return InvalidArgument(
1456 "Expected feature_index of batch-norm-grad to be "
1457 "smaller than the rank of operand_shape; "
1458 "got feature_index %lld, and rank %lld",
1459 feature_index, ShapeUtil::Rank(operand_shape));
1460 }
1461
1462 if (ShapeUtil::Rank(operand_shape) != ShapeUtil::Rank(output_grad_shape)) {
1463 return InvalidArgument(
1464 "Expected operand_shape of batch-norm-grad to have the same rank as"
1465 " output_grad_shape; got rank(oprand_shape) %lld, and"
1466 " rank(output_grad_shape) %lld",
1467 ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape));
1468 }
1469
1470 if (ShapeUtil::Rank(mean_shape) != 1) {
1471 return InvalidArgument(
1472 "Mean input of batch-norm-grad must have"
1473 " rank 1, but has rank %lld.",
1474 ShapeUtil::Rank(mean_shape));
1475 }
1476
1477 if (ShapeUtil::Rank(scale_shape) != 1) {
1478 return InvalidArgument(
1479 "Scale input of batch-norm-grad must have"
1480 " rank 1, but has rank %lld.",
1481 ShapeUtil::Rank(scale_shape));
1482 }
1483
1484 if (ShapeUtil::Rank(var_shape) != 1) {
1485 return InvalidArgument(
1486 "Var input of batch-norm-grad must have"
1487 " rank 1, but has rank %lld.",
1488 ShapeUtil::Rank(var_shape));
1489 }
1490
1491 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1492 return InvalidArgument(
1493 "The operand to batch-norm-grad must have a floating point "
1494 "element type, but the shape is %s",
1495 PrimitiveType_Name(operand_shape.element_type()).c_str());
1496 }
1497
1498 if (!ShapeUtil::ElementIsFloating(output_grad_shape)) {
1499 return InvalidArgument(
1500 "The output_grad to batch-norm-grad must have a floating point "
1501 "element type, but the shape is %s",
1502 PrimitiveType_Name(output_grad_shape.element_type()).c_str());
1503 }
1504
1505 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
1506 operand_shape)) {
1507 return InvalidArgument(
1508 "The inputs should have the same element type for batch-norm-grad, "
1509 "but the element type of output_grad is %s "
1510 "and the element type of operand is %s",
1511 PrimitiveType_Name(output_grad_shape.element_type()).c_str(),
1512 PrimitiveType_Name(operand_shape.element_type()).c_str());
1513 }
1514
1515 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1516 operand_shape)) {
1517 return InvalidArgument(
1518 "The inputs should have the same element type for batch-norm-grad, "
1519 "but the element type of scale factor is %s "
1520 "and the element type of operand is %s",
1521 PrimitiveType_Name(scale_shape.element_type()).c_str(),
1522 PrimitiveType_Name(operand_shape.element_type()).c_str());
1523 }
1524
1525 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1526 operand_shape)) {
1527 return InvalidArgument(
1528 "The inputs should have the same element type for batch-norm-grad, "
1529 "but the element type of mean is %s "
1530 "and the element type of operand is %s",
1531 PrimitiveType_Name(mean_shape.element_type()).c_str(),
1532 PrimitiveType_Name(operand_shape.element_type()).c_str());
1533 }
1534
1535 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
1536 operand_shape)) {
1537 return InvalidArgument(
1538 "The inputs should have the same element type for batch-norm-grad, "
1539 "but the element type of mean is %s "
1540 "and the element type of operand is %s",
1541 PrimitiveType_Name(mean_shape.element_type()).c_str(),
1542 PrimitiveType_Name(operand_shape.element_type()).c_str());
1543 }
1544
1545 const int64 feature_count = operand_shape.dimensions(feature_index);
1546
1547 Shape feature_shape =
1548 ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1549
1550 if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1551 return InvalidArgument(
1552 "The size of mean should be the same as feature count,"
1553 "but the size of offset factor is %lld "
1554 "and the feature count is %lld",
1555 ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1556 }
1557
1558 if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1559 return InvalidArgument(
1560 "The size of scale factor should be the same as feature count,"
1561 "but the size of scale factor is %lld "
1562 "and the feature count is %lld",
1563 ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1564 }
1565
1566 if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) {
1567 return InvalidArgument(
1568 "The size of variance should be the same as feature count,"
1569 "but the size of variance is %lld "
1570 "and the feature count is %lld",
1571 ShapeUtil::GetDimension(var_shape, 0), feature_count);
1572 }
1573
1574 // Verify operand_shape and output_grad_shape have same bounds.
1575 for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
1576 if (ShapeUtil::GetDimension(operand_shape, i) !=
1577 ShapeUtil::GetDimension(output_grad_shape, i)) {
1578 return InvalidArgument(
1579 "The bounds of operand shape should be the same as output_grad's,"
1580 "but the bound of operand_shape at dimension %lld is %lld "
1581 "and the bound of output_grad_shape is %lld",
1582 i, ShapeUtil::GetDimension(operand_shape, i),
1583 ShapeUtil::GetDimension(output_grad_shape, i));
1584 }
1585 }
1586
1587 return ShapeUtil::MakeTupleShape(
1588 {operand_shape, feature_shape, feature_shape});
1589 }
1590
InferConvolveShape(const Shape & lhs,const Shape & rhs,const Window & window,const ConvolutionDimensionNumbers & dnums)1591 /* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
1592 const Shape& lhs, const Shape& rhs, const Window& window,
1593 const ConvolutionDimensionNumbers& dnums) {
1594 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution"));
1595 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution"));
1596
1597 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
1598 return InvalidArgument(
1599 "Convolution with different element types: %s and %s",
1600 ShapeUtil::HumanString(lhs).c_str(),
1601 ShapeUtil::HumanString(rhs).c_str());
1602 }
1603 if (dnums.input_spatial_dimensions_size() !=
1604 dnums.kernel_spatial_dimensions_size()) {
1605 return InvalidArgument(
1606 "Both arguments to convolution must have same number of dimensions.\n"
1607 "Window: %s",
1608 window.DebugString().c_str());
1609 }
1610
1611 const int num_spatial_dims = dnums.input_spatial_dimensions_size();
1612 if (window.dimensions_size() != num_spatial_dims) {
1613 return InvalidArgument(
1614 "Window must have same number of dimensions as dimension numbers.\n"
1615 "Window: %s\nDimension numbers: %s",
1616 window.DebugString().c_str(), dnums.DebugString().c_str());
1617 }
1618
1619 const int num_dims = num_spatial_dims + 2;
1620 if (ShapeUtil::Rank(lhs) != num_dims) {
1621 return InvalidArgument(
1622 "The LHS argument to a convolution should have rank %d.\n"
1623 "lhs: %s",
1624 num_dims, ShapeUtil::HumanString(lhs).c_str());
1625 }
1626 if (ShapeUtil::Rank(rhs) != num_dims) {
1627 return InvalidArgument(
1628 "The RHS argument to a convolution should have rank %d.\n"
1629 "lhs: %s",
1630 num_dims, ShapeUtil::HumanString(lhs).c_str());
1631 }
1632 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1633 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1634
1635 // Verifies that the input and window dimensions are a permutation of
1636 // the dimension numbers.
1637 std::vector<int64> input_dnums(num_dims);
1638 input_dnums[0] = dnums.input_batch_dimension();
1639 input_dnums[1] = dnums.input_feature_dimension();
1640 std::copy(dnums.input_spatial_dimensions().begin(),
1641 dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
1642 std::sort(input_dnums.begin(), input_dnums.end());
1643
1644 std::vector<int64> window_dnums(num_dims);
1645 window_dnums[0] = dnums.kernel_input_feature_dimension();
1646 window_dnums[1] = dnums.kernel_output_feature_dimension();
1647 std::copy(dnums.kernel_spatial_dimensions().begin(),
1648 dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
1649 std::sort(window_dnums.begin(), window_dnums.end());
1650
1651 std::vector<int64> output_dnums(num_dims);
1652 output_dnums[0] = dnums.output_batch_dimension();
1653 output_dnums[1] = dnums.output_feature_dimension();
1654 std::copy(dnums.output_spatial_dimensions().begin(),
1655 dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
1656 std::sort(output_dnums.begin(), output_dnums.end());
1657
1658 std::vector<int64> expected_dnums(num_dims);
1659 std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
1660
1661 const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
1662 if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) ||
1663 !std::all_of(window_dnums.begin(), window_dnums.end(), in_range) ||
1664 !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) {
1665 return InvalidArgument(
1666 "A dimension number is out of range in convolution: %s",
1667 dnums.DebugString().c_str());
1668 }
1669
1670 if (input_dnums != expected_dnums) {
1671 return InvalidArgument(
1672 "Input dimensions of convolution must contain each dimension exactly "
1673 "once: %s",
1674 dnums.DebugString().c_str());
1675 }
1676 if (window_dnums != expected_dnums) {
1677 return InvalidArgument(
1678 "Window dimensions of convolution must contain each dimension exactly "
1679 "once: %s",
1680 dnums.DebugString().c_str());
1681 }
1682 if (output_dnums != expected_dnums) {
1683 return InvalidArgument(
1684 "Output dimensions of convolution must contain each dimension exactly "
1685 "once: %s",
1686 dnums.DebugString().c_str());
1687 }
1688
1689 std::vector<int64> input_spatial_dims(num_spatial_dims);
1690 for (int i = 0; i < num_spatial_dims; ++i) {
1691 input_spatial_dims[i] = lhs.dimensions(dnums.input_spatial_dimensions(i));
1692 }
1693 const int64 input_features = lhs.dimensions(dnums.input_feature_dimension());
1694 const int64 input_batch = lhs.dimensions(dnums.input_batch_dimension());
1695
1696 std::vector<int64> kernel_spatial_dims(num_spatial_dims);
1697 for (int i = 0; i < num_spatial_dims; ++i) {
1698 kernel_spatial_dims[i] = rhs.dimensions(dnums.kernel_spatial_dimensions(i));
1699 }
1700 const int64 kernel_input_features =
1701 rhs.dimensions(dnums.kernel_input_feature_dimension());
1702 const int64 kernel_output_features =
1703 rhs.dimensions(dnums.kernel_output_feature_dimension());
1704
1705 if (input_features != kernel_input_features) {
1706 return InvalidArgument(
1707 "Expected LHS feature dimension (value %lld) to match RHS "
1708 "input feature dimension (value %lld); got <conv>(%s, %s)\n"
1709 "Dimension numbers: {%s}",
1710 input_features, kernel_input_features,
1711 ShapeUtil::HumanString(lhs).c_str(),
1712 ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str());
1713 }
1714 std::vector<int64> window_dims(num_spatial_dims);
1715 for (int i = 0; i < num_spatial_dims; ++i) {
1716 window_dims[i] = window.dimensions(i).size();
1717 }
1718 if (kernel_spatial_dims != window_dims) {
1719 return InvalidArgument(
1720 "Window dimensions do not match RHS shape:\n\t"
1721 "RHS shape: %s\n\t"
1722 "Window: {%s}\n\t"
1723 "Dimension numbers: {%s}",
1724 ShapeUtil::HumanString(rhs).c_str(), window.ShortDebugString().c_str(),
1725 dnums.ShortDebugString().c_str());
1726 }
1727
1728 Shape base_shape =
1729 ShapeUtil::MakeShape(lhs.element_type(), input_spatial_dims);
1730 TF_ASSIGN_OR_RETURN(
1731 Shape window_output_shape,
1732 InferWindowOutputShape(base_shape, window, lhs.element_type(),
1733 /*allow_negative_padding=*/true));
1734
1735 std::vector<int64> dimensions(num_dims);
1736 dimensions[dnums.output_batch_dimension()] = input_batch;
1737 dimensions[dnums.output_feature_dimension()] = kernel_output_features;
1738 for (int i = 0; i < num_spatial_dims; ++i) {
1739 dimensions[dnums.output_spatial_dimensions(i)] =
1740 window_output_shape.dimensions(i);
1741 }
1742 return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
1743 dimensions);
1744 }
1745
InferFftShape(const Shape & in,const FftType fft_type,const tensorflow::gtl::ArraySlice<int64> fft_length)1746 /* static */ StatusOr<Shape> ShapeInference::InferFftShape(
1747 const Shape& in, const FftType fft_type,
1748 const tensorflow::gtl::ArraySlice<int64> fft_length) {
1749 const int64 fft_rank = fft_length.size();
1750 if (fft_rank < 1 || fft_rank > 3) {
1751 return InvalidArgument("FFT only supports ranks 1-3, but got %lld",
1752 fft_rank);
1753 }
1754 #define RET_CHECK_RANK(x) \
1755 if (x.dimensions_size() < fft_rank) { \
1756 return InvalidArgument( \
1757 "FFT of rank %lld requires input of at least " \
1758 "same rank; got input of rank %d", \
1759 fft_rank, x.dimensions_size()); \
1760 }
1761 switch (fft_type) {
1762 case FFT:
1763 case IFFT:
1764 if (in.element_type() != C64) {
1765 return InvalidArgument("%s requires C64 input type, found %s",
1766 FftType_Name(fft_type).c_str(),
1767 PrimitiveType_Name(in.element_type()).c_str());
1768 }
1769 RET_CHECK_RANK(in);
1770 return in;
1771 case RFFT: {
1772 if (in.element_type() != F32) {
1773 return InvalidArgument("RFFT requires F32 input type, found %s",
1774 PrimitiveType_Name(in.element_type()).c_str());
1775 }
1776 RET_CHECK_RANK(in);
1777 for (int i = 0; i < fft_rank; i++) {
1778 if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1779 fft_length[i]) {
1780 return InvalidArgument(
1781 "RFFT requires innermost dimensions match fft_length but "
1782 "dimension %lld is %lld and should be %lld",
1783 in.dimensions_size() - fft_rank + i,
1784 in.dimensions(in.dimensions_size() - fft_rank + i),
1785 fft_length[i]);
1786 }
1787 }
1788 Shape result = ShapeUtil::ChangeElementType(in, C64);
1789 result.set_dimensions(result.dimensions_size() - 1,
1790 fft_length[fft_rank - 1] / 2 + 1);
1791 return result;
1792 }
1793 case IRFFT: {
1794 if (in.element_type() != C64) {
1795 return InvalidArgument("IRFFT requires C64 input type, found %s",
1796 PrimitiveType_Name(in.element_type()).c_str());
1797 }
1798 RET_CHECK_RANK(in);
1799 Shape result = ShapeUtil::ComplexComponentShape(in);
1800 for (int i = 0; i < fft_rank - 1; i++) {
1801 if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1802 fft_length[i]) {
1803 return InvalidArgument(
1804 "IRFFT requires all but one innermost dimensions match "
1805 "fft_length, but dimension %lld is %lld and should be %lld",
1806 in.dimensions_size() - fft_rank + i,
1807 in.dimensions(in.dimensions_size() - fft_rank + i),
1808 fft_length[i]);
1809 }
1810 }
1811 if (in.dimensions(in.dimensions_size() - 1) !=
1812 fft_length[fft_rank - 1] / 2 + 1) {
1813 return InvalidArgument(
1814 "IRFFT requires innermost dimension matches fft_length/2+1, but "
1815 "dimension %d is %lld and should be %lld",
1816 in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1),
1817 fft_length[fft_rank - 1] / 2 + 1);
1818 }
1819 result.set_dimensions(result.dimensions_size() - 1,
1820 fft_length[fft_rank - 1]);
1821 return result;
1822 }
1823 default:
1824 LOG(FATAL) << "Unexpected fft_type: " << fft_type;
1825 }
1826 #undef RET_CHECK_RANK
1827 }
1828
InferCrossReplicaSumShape(tensorflow::gtl::ArraySlice<const Shape * > operand_shapes)1829 /* static */ StatusOr<Shape> ShapeInference::InferCrossReplicaSumShape(
1830 tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
1831 for (const Shape* operand_shape : operand_shapes) {
1832 TF_RETURN_IF_ERROR(
1833 ExpectNotTupleOrOpaque(*operand_shape, "operand of cross replica sum"));
1834 }
1835 if (operand_shapes.size() == 1) {
1836 return *operand_shapes[0];
1837 }
1838 std::vector<Shape> operand_shape_values;
1839 for (const Shape* operand_shape : operand_shapes) {
1840 operand_shape_values.push_back(*operand_shape);
1841 }
1842 return ShapeUtil::MakeTupleShape(operand_shape_values);
1843 }
1844
InferReduceShape(const Shape & arg,const Shape & init_value,tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,const ProgramShape & to_apply)1845 /* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
1846 const Shape& arg, const Shape& init_value,
1847 tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
1848 const ProgramShape& to_apply) {
1849 // Check that the dimension to reduce are in-bounds for the given shape.
1850 for (int64 dimension : dimensions_to_reduce) {
1851 if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) {
1852 return InvalidArgument(
1853 "attempting to reduce out-of-bounds dimension %lld in shape %s",
1854 dimension, ShapeUtil::HumanString(arg).c_str());
1855 }
1856 }
1857 TF_RETURN_IF_ERROR(
1858 VerifyReducerShape(to_apply, init_value, arg.element_type()));
1859
1860 std::set<int64> dimensions_to_reduce_set(dimensions_to_reduce.begin(),
1861 dimensions_to_reduce.end());
1862 std::vector<int64> new_dimensions;
1863 for (int i = 0; i < ShapeUtil::Rank(arg); ++i) {
1864 if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) {
1865 new_dimensions.push_back(arg.dimensions(i));
1866 }
1867 }
1868
1869 return ShapeUtil::MakeShape(to_apply.result().element_type(), new_dimensions);
1870 }
1871
InferReduceWindowShape(const Shape & operand_shape,const Shape & init_value_shape,const Window & window,const ProgramShape & to_apply_shape)1872 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
1873 const Shape& operand_shape, const Shape& init_value_shape,
1874 const Window& window, const ProgramShape& to_apply_shape) {
1875 TF_RETURN_IF_ERROR(
1876 ExpectNotTupleOrOpaque(operand_shape, "operand of reduce-window"));
1877 TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape,
1878 operand_shape.element_type()));
1879 return InferWindowOutputShape(operand_shape, window,
1880 init_value_shape.element_type(),
1881 /*allow_negative_padding=*/false);
1882 }
1883
InferSelectAndScatterShape(const Shape & operand_shape,const ProgramShape & select_shape,const Window & window,const Shape & source_shape,const Shape & init_value_shape,const ProgramShape & scatter_shape)1884 /* static */ StatusOr<Shape> ShapeInference::InferSelectAndScatterShape(
1885 const Shape& operand_shape, const ProgramShape& select_shape,
1886 const Window& window, const Shape& source_shape,
1887 const Shape& init_value_shape, const ProgramShape& scatter_shape) {
1888 TF_RETURN_IF_ERROR(
1889 ExpectNotTupleOrOpaque(operand_shape, "operand of select-and-scatter"));
1890
1891 // Check if the select function has a proper shape of (T,T) -> PRED.
1892 if (select_shape.parameters_size() != 2) {
1893 return InvalidArgument(
1894 "select function must take 2 parameters, but "
1895 "takes %d parameter(s).",
1896 select_shape.parameters_size());
1897 }
1898 const Shape& select_result_shape = select_shape.result();
1899 if (!ShapeUtil::Compatible(select_result_shape,
1900 ShapeUtil::MakeShape(PRED, {}))) {
1901 return Unimplemented("select function must have rank-0 PRED result.");
1902 }
1903 const Shape& operand_element_shape =
1904 ShapeUtil::MakeShape(operand_shape.element_type(), {});
1905 if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
1906 select_shape.parameters(0))) {
1907 return InvalidArgument(
1908 "select function's first parameter shape currently must "
1909 "match the operand element shape. Got %s vs %s",
1910 ShapeUtil::HumanString(select_shape.parameters(0)).c_str(),
1911 ShapeUtil::HumanString(operand_element_shape).c_str());
1912 }
1913 if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
1914 select_shape.parameters(1))) {
1915 return InvalidArgument(
1916 "select function's second parameter shape currently must "
1917 "match the operand element shape. Got %s vs %s",
1918 ShapeUtil::HumanString(select_shape.parameters(1)).c_str(),
1919 ShapeUtil::HumanString(operand_element_shape).c_str());
1920 }
1921
1922 // Check if the scatter function has a proper shape as a reduction.
1923 TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, init_value_shape,
1924 source_shape.element_type()));
1925
1926 // Check if the result shape of window operation matches the source shape.
1927 TF_ASSIGN_OR_RETURN(const Shape& window_result_shape,
1928 InferWindowOutputShape(operand_shape, window,
1929 operand_shape.element_type(),
1930 /*allow_negative_padding=*/false));
1931 if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape,
1932 window_result_shape)) {
1933 return InvalidArgument(
1934 "source shape does not match the shape of window-reduced operand: "
1935 "source(%s), window-reduced operand(%s)",
1936 ShapeUtil::HumanString(source_shape).c_str(),
1937 ShapeUtil::HumanString(window_result_shape).c_str());
1938 }
1939 return operand_shape;
1940 }
1941
InferSliceShape(const Shape & arg,tensorflow::gtl::ArraySlice<int64> starts,tensorflow::gtl::ArraySlice<int64> limits,tensorflow::gtl::ArraySlice<int64> strides)1942 /* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
1943 const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
1944 tensorflow::gtl::ArraySlice<int64> limits,
1945 tensorflow::gtl::ArraySlice<int64> strides) {
1946 auto error = [&](const string& message) {
1947 return InvalidArgument(
1948 "%s in slice operation; argument shape: %s; starts: {%s}; limits: "
1949 "{%s}; strides: {%s}",
1950 message.c_str(), ShapeUtil::HumanString(arg).c_str(),
1951 Join(starts, ",").c_str(), Join(limits, ",").c_str(),
1952 Join(strides, ",").c_str());
1953 };
1954 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice"));
1955 VLOG(2) << tensorflow::strings::Printf(
1956 "slicing shape %s starts={%s} limits={%s}",
1957 ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(),
1958 Join(limits, ", ").c_str());
1959
1960 if (starts.size() != limits.size()) {
1961 return error(Printf("slice start and limit sizes differ: %zu vs %zu",
1962 starts.size(), limits.size()));
1963 }
1964
1965 if (starts.size() != strides.size()) {
1966 return error(Printf("slice start and strides sizes differ: %zu vs %zu",
1967 starts.size(), strides.size()));
1968 }
1969
1970 if (starts.size() != ShapeUtil::Rank(arg)) {
1971 return InvalidArgument(
1972 "slice index count does not match argument rank: %zu vs %lld",
1973 starts.size(), ShapeUtil::Rank(arg));
1974 }
1975
1976 std::vector<int64> sizes;
1977 for (int64 dimension = 0; dimension < starts.size(); ++dimension) {
1978 int64 start_index = starts[dimension];
1979 int64 limit_index = limits[dimension];
1980 int64 stride = strides[dimension];
1981 if (start_index < 0) {
1982 return InvalidArgument("negative start index to slice: %lld",
1983 start_index);
1984 }
1985 if (limit_index > arg.dimensions(dimension)) {
1986 return error(
1987 Printf("limit index (%lld) must be less than or equal to dimension "
1988 "size (%lld)",
1989 limit_index, arg.dimensions(dimension)));
1990 }
1991 VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension,
1992 start_index);
1993 VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension,
1994 limit_index);
1995 if (start_index > limit_index) {
1996 return error(
1997 Printf("limit index (%lld) must be greater or equal to "
1998 "start index (%lld) in slice with positive stride",
1999 limit_index, start_index));
2000 }
2001 if (stride <= 0) {
2002 return InvalidArgument("stride (%lld) must be positive", stride);
2003 }
2004 sizes.push_back((limit_index - start_index + stride - 1) / stride);
2005 }
2006
2007 return ShapeUtil::MakeShape(arg.element_type(), sizes);
2008 }
2009
InferDynamicSliceShape(const Shape & operand_shape,const Shape & start_indices_shape,tensorflow::gtl::ArraySlice<int64> slice_sizes)2010 /* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
2011 const Shape& operand_shape, const Shape& start_indices_shape,
2012 tensorflow::gtl::ArraySlice<int64> slice_sizes) {
2013 TF_RETURN_IF_ERROR(
2014 ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic slice"));
2015 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(start_indices_shape,
2016 "start indices of dynamic slice"));
2017
2018 VLOG(2) << tensorflow::strings::Printf(
2019 "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
2020 ShapeUtil::HumanString(operand_shape).c_str(),
2021 ShapeUtil::HumanString(start_indices_shape).c_str(),
2022 Join(slice_sizes, ", ").c_str());
2023
2024 if (ShapeUtil::Rank(start_indices_shape) != 1) {
2025 return InvalidArgument(
2026 "dynamic slice start indices of rank %lld must be rank1.",
2027 ShapeUtil::Rank(start_indices_shape));
2028 }
2029
2030 if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2031 return InvalidArgument(
2032 "dynamic slice start indices must be of integral type.");
2033 }
2034
2035 const int64 start_num_dims = start_indices_shape.dimensions(0);
2036 if (ShapeUtil::Rank(operand_shape) != start_num_dims) {
2037 return InvalidArgument(
2038 "dynamic slice start number of dimensions %lld (%s) must match rank "
2039 "%lld of slice input (%s)",
2040 start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(),
2041 ShapeUtil::Rank(operand_shape),
2042 ShapeUtil::HumanString(operand_shape).c_str());
2043 }
2044
2045 if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) {
2046 return InvalidArgument(
2047 "dynamic slice index count does not match argument rank: %zu vs %lld",
2048 slice_sizes.size(), ShapeUtil::Rank(operand_shape));
2049 }
2050
2051 for (int64 dim = 0; dim < slice_sizes.size(); ++dim) {
2052 const int64 input_dim_size = operand_shape.dimensions(dim);
2053 const int64 slice_dim_size = slice_sizes[dim];
2054 if (slice_dim_size < 0) {
2055 return InvalidArgument("negative size index to dynamic slice: %lld",
2056 slice_dim_size);
2057 }
2058 if (slice_dim_size > input_dim_size) {
2059 return InvalidArgument(
2060 "slice dim size %lld greater than dynamic slice dimension: %lld",
2061 slice_dim_size, input_dim_size);
2062 }
2063 VLOG(2) << tensorflow::strings::Printf("slice_sizes[%lld] = %lld", dim,
2064 slice_dim_size);
2065 }
2066
2067 return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
2068 }
2069
InferDynamicUpdateSliceShape(const Shape & operand_shape,const Shape & update_shape,const Shape & start_indices_shape)2070 /* static */ StatusOr<Shape> ShapeInference::InferDynamicUpdateSliceShape(
2071 const Shape& operand_shape, const Shape& update_shape,
2072 const Shape& start_indices_shape) {
2073 TF_RETURN_IF_ERROR(
2074 ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic update slice"));
2075 TF_RETURN_IF_ERROR(
2076 ExpectNotTupleOrOpaque(update_shape, "update of dynamic update slice"));
2077 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
2078 start_indices_shape, "start indices of dynamic update slice"));
2079
2080 VLOG(2) << tensorflow::strings::Printf(
2081 "updating slice of shape %s at dynamic start_indices %s with update "
2082 "shape %s",
2083 ShapeUtil::HumanString(operand_shape).c_str(),
2084 ShapeUtil::HumanString(start_indices_shape).c_str(),
2085 ShapeUtil::HumanString(update_shape).c_str());
2086
2087 if (ShapeUtil::Rank(start_indices_shape) != 1) {
2088 return InvalidArgument(
2089 "dynamic update slice start indices of rank %lld must be rank1.",
2090 ShapeUtil::Rank(start_indices_shape));
2091 }
2092
2093 if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2094 return InvalidArgument(
2095 "dynamic update slice start indices must be of integral type.");
2096 }
2097
2098 const int64 start_num_dims = start_indices_shape.dimensions(0);
2099 if (ShapeUtil::Rank(operand_shape) != start_num_dims) {
2100 return InvalidArgument(
2101 "dynamic slice start number of dimensions %lld (%s) must match rank "
2102 "%lld of slice input (%s)",
2103 start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(),
2104 ShapeUtil::Rank(operand_shape),
2105 ShapeUtil::HumanString(operand_shape).c_str());
2106 }
2107
2108 if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) {
2109 return InvalidArgument(
2110 "dynamic update slice update rank does not match argument rank: "
2111 "%lld vs %lld",
2112 ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape));
2113 }
2114
2115 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
2116 update_shape)) {
2117 return InvalidArgument(
2118 "dynamic update slice update element type does not match argument. "
2119 "operand.element_type: %s vs update.element_type: %s",
2120 PrimitiveType_Name(operand_shape.element_type()).c_str(),
2121 PrimitiveType_Name(update_shape.element_type()).c_str());
2122 }
2123
2124 for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) {
2125 const int64 input_dim_size = operand_shape.dimensions(dim);
2126 const int64 update_dim_size = update_shape.dimensions(dim);
2127 if (update_dim_size < 0) {
2128 return InvalidArgument(
2129 "size index %lld to dynamic update slice must be >= 0",
2130 update_dim_size);
2131 }
2132 if (update_dim_size > input_dim_size) {
2133 return InvalidArgument(
2134 "update dim size %lld greater than dynamic slice dimension: %lld",
2135 update_dim_size, input_dim_size);
2136 }
2137 VLOG(2) << tensorflow::strings::Printf("update_sizes[%lld] = %lld", dim,
2138 update_dim_size);
2139 }
2140
2141 return operand_shape;
2142 }
2143
InferReverseShape(const Shape & operand_shape,tensorflow::gtl::ArraySlice<int64> dimensions)2144 /*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
2145 const Shape& operand_shape, tensorflow::gtl::ArraySlice<int64> dimensions) {
2146 TF_RETURN_IF_ERROR(
2147 ExpectNotTupleOrOpaque(operand_shape, "operand of reverse"));
2148 if (!AllUnique(dimensions)) {
2149 return InvalidArgument("a dimension number is duplicated in reverse");
2150 }
2151 for (int64 dimension : dimensions) {
2152 if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) {
2153 return InvalidArgument(
2154 "one of the reverse dimensions (%lld) is out-of-bounds in shape %s",
2155 dimension, ShapeUtil::HumanString(operand_shape).c_str());
2156 }
2157 }
2158 return operand_shape;
2159 }
2160
InferGetTupleElementShape(const Shape & arg,int64 index)2161 /* static */ StatusOr<Shape> ShapeInference::InferGetTupleElementShape(
2162 const Shape& arg, int64 index) {
2163 if (!ShapeUtil::IsTuple(arg)) {
2164 return InvalidArgument(
2165 "cannot infer shape: attempting to index into non-tuple: %s",
2166 ShapeUtil::HumanString(arg).c_str());
2167 }
2168
2169 if (index >= arg.tuple_shapes_size()) {
2170 return InvalidArgument(
2171 "cannot infer shape: attempt to index out of tuple bounds: %lld "
2172 ">= %d in shape %s",
2173 index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg).c_str());
2174 }
2175
2176 return arg.tuple_shapes(index);
2177 }
2178
InferWhileShape(const ProgramShape & condition,const ProgramShape & body,const Shape & init)2179 /* static */ StatusOr<Shape> ShapeInference::InferWhileShape(
2180 const ProgramShape& condition, const ProgramShape& body,
2181 const Shape& init) {
2182 // Check the number of parameters for given computations.
2183 if (condition.parameters_size() != 1) {
2184 return InvalidArgument("condition must take 1 arguments; got %d",
2185 condition.parameters_size());
2186 }
2187 if (body.parameters_size() != 1) {
2188 return InvalidArgument("body must take 1 arguments; got %d",
2189 body.parameters_size());
2190 }
2191
2192 auto shape_string = [&]() {
2193 return tensorflow::strings::Printf(
2194 "condition: %s; body: %s; init: %s",
2195 ShapeUtil::HumanString(condition).c_str(),
2196 ShapeUtil::HumanString(body).c_str(),
2197 ShapeUtil::HumanString(init).c_str());
2198 };
2199
2200 // Check the shapes of computation parameters and return types.
2201 if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) {
2202 return InvalidArgument("condition must return a boolean; got %s",
2203 shape_string().c_str());
2204 }
2205 if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) ||
2206 !ShapeUtil::Compatible(body.result(), body.parameters(0)) ||
2207 !ShapeUtil::Compatible(body.result(), init)) {
2208 return InvalidArgument(
2209 "the parameter of condition and body, the result of the body, and init "
2210 "must all have the same shape; got %s",
2211 shape_string().c_str());
2212 }
2213
2214 return init;
2215 }
2216
InferConditionalShape(const Shape & predicate,const Shape & true_operand,const Shape & false_operand,const ProgramShape & true_computation,const ProgramShape & false_computation)2217 /* static */ StatusOr<Shape> ShapeInference::InferConditionalShape(
2218 const Shape& predicate, const Shape& true_operand,
2219 const Shape& false_operand, const ProgramShape& true_computation,
2220 const ProgramShape& false_computation) {
2221 if (!ShapeUtil::ShapeIs(predicate, PRED, {})) {
2222 return InvalidArgument("predicate must be a boolean; got %s.",
2223 ShapeUtil::HumanString(predicate).c_str());
2224 }
2225
2226 if (true_computation.parameters_size() != 1) {
2227 return InvalidArgument("true_computation must take 1 argument; got %d.",
2228 true_computation.parameters_size());
2229 }
2230 if (!ShapeUtil::Compatible(true_computation.parameters(0), true_operand)) {
2231 auto true_shape_string = [&]() {
2232 return tensorflow::strings::Printf(
2233 "true_operand: %s; true_computation: %s",
2234 ShapeUtil::HumanString(true_operand).c_str(),
2235 ShapeUtil::HumanString(true_computation).c_str());
2236 };
2237 return InvalidArgument(
2238 "true_operand must match the shape of the only parameter of "
2239 "true_computation: got %s.",
2240 true_shape_string().c_str());
2241 }
2242
2243 if (false_computation.parameters_size() != 1) {
2244 return InvalidArgument("false_computation must take 1 argument; got %d.",
2245 false_computation.parameters_size());
2246 }
2247 if (!ShapeUtil::Compatible(false_computation.parameters(0), false_operand)) {
2248 auto false_shape_string = [&]() {
2249 return tensorflow::strings::Printf(
2250 "false_operand: %s; false_computation: %s",
2251 ShapeUtil::HumanString(false_operand).c_str(),
2252 ShapeUtil::HumanString(false_computation).c_str());
2253 };
2254 return InvalidArgument(
2255 "false_operand must match the shape of the only parameter of "
2256 "false_computation: got %s.",
2257 false_shape_string().c_str());
2258 }
2259 if (!ShapeUtil::Compatible(true_computation.result(),
2260 false_computation.result())) {
2261 auto shape_string = [&]() {
2262 return tensorflow::strings::Printf(
2263 "true_computation result: %s; false_computation result: %s.",
2264 ShapeUtil::HumanString(true_computation.result()).c_str(),
2265 ShapeUtil::HumanString(false_computation.result()).c_str());
2266 };
2267 return InvalidArgument(
2268 "the result of true_computation and false_computation must have the "
2269 "same shape: got %s.",
2270 shape_string().c_str());
2271 }
2272 return true_computation.result();
2273 }
2274
InferBroadcastShape(const Shape & operand,tensorflow::gtl::ArraySlice<int64> broadcast_sizes)2275 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2276 const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
2277 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast"));
2278 for (int64 size : broadcast_sizes) {
2279 if (size < 0) {
2280 return InvalidArgument("Broadcast with negative dimension size %lld.",
2281 size);
2282 }
2283 }
2284
2285 std::vector<int64> dimensions(operand.dimensions_size() +
2286 broadcast_sizes.size());
2287 std::copy(broadcast_sizes.begin(), broadcast_sizes.end(), dimensions.begin());
2288 std::copy(operand.dimensions().begin(), operand.dimensions().end(),
2289 dimensions.begin() + broadcast_sizes.size());
2290 return ShapeUtil::MakeShape(operand.element_type(), dimensions);
2291 }
2292
InferReshapeShape(const Shape & operand,tensorflow::gtl::ArraySlice<int64> dimensions,tensorflow::gtl::ArraySlice<int64> new_sizes)2293 /* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
2294 const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
2295 tensorflow::gtl::ArraySlice<int64> new_sizes) {
2296 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "reshape"));
2297
2298 Shape inferred_shape =
2299 ShapeUtil::MakeShape(operand.element_type(), new_sizes);
2300 VLOG(3) << "Reshape inferred shape: "
2301 << ShapeUtil::HumanString(inferred_shape);
2302
2303 if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
2304 return InvalidArgument(
2305 "reshape operation has mismatched element counts: from=%lld (%s) "
2306 "to=%lld (%s)",
2307 ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(),
2308 ShapeUtil::ElementsIn(inferred_shape),
2309 ShapeUtil::HumanString(inferred_shape).c_str());
2310 }
2311
2312 std::vector<int64> indices(ShapeUtil::Rank(operand));
2313 std::iota(indices.begin(), indices.end(), 0);
2314 if (dimensions.size() != ShapeUtil::Rank(operand) ||
2315 !std::is_permutation(dimensions.begin(), dimensions.end(),
2316 indices.begin())) {
2317 return InvalidArgument(
2318 "Reshape dimensions [%s] are not a permutation of the operand "
2319 "dimensions (operand shape is %s).",
2320 Join(dimensions, ",").c_str(), ShapeUtil::HumanString(operand).c_str());
2321 }
2322
2323 return inferred_shape;
2324 }
2325
InferTransposeShape(const Shape & operand,tensorflow::gtl::ArraySlice<int64> dimensions)2326 /* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
2327 const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
2328 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "transpose"));
2329
2330 std::vector<int64> indices(ShapeUtil::Rank(operand));
2331 std::iota(indices.begin(), indices.end(), 0);
2332 if (dimensions.size() != ShapeUtil::Rank(operand) ||
2333 !std::is_permutation(dimensions.begin(), dimensions.end(),
2334 indices.begin())) {
2335 return InvalidArgument(
2336 "Transpose dimensions not a permutation of the operand dimensions.");
2337 }
2338
2339 // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
2340 // we need output[i]=input[dimensions[i]] which is
2341 // Permute(Inverse(dimensions),input).
2342 return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
2343 }
2344
2345 // TODO(b/36794510): Make broadcast semantics more consistent, by supporting
2346 // "degenerate" cases, as with binary elementwise ops.
InferClampShape(const Shape & min,const Shape & operand,const Shape & max)2347 /* static */ StatusOr<Shape> ShapeInference::InferClampShape(
2348 const Shape& min, const Shape& operand, const Shape& max) {
2349 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min"));
2350 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand"));
2351 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max"));
2352 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) ||
2353 !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) {
2354 return InvalidArgument("clamp op with different operand types: %s, %s, %s",
2355 ShapeUtil::HumanString(min).c_str(),
2356 ShapeUtil::HumanString(operand).c_str(),
2357 ShapeUtil::HumanString(max).c_str());
2358 }
2359 if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
2360 ShapeUtil::IsScalar(min)) &&
2361 (ShapeUtil::CompatibleIgnoringFpPrecision(max, operand) ||
2362 ShapeUtil::IsScalar(max)))) {
2363 return operand;
2364 }
2365 if (ShapeUtil::IsScalar(operand)) {
2366 if (ShapeUtil::CompatibleIgnoringFpPrecision(min, max)) {
2367 return ShapeUtil::ChangeElementType(min, operand.element_type());
2368 } else if (ShapeUtil::IsScalar(min)) {
2369 return ShapeUtil::ChangeElementType(max, operand.element_type());
2370 } else if (ShapeUtil::IsScalar(max)) {
2371 return ShapeUtil::ChangeElementType(min, operand.element_type());
2372 }
2373 }
2374 return Unimplemented(
2375 "not yet implemented: %s, %s <clamp> %s", min.ShortDebugString().c_str(),
2376 max.ShortDebugString().c_str(), operand.ShortDebugString().c_str());
2377 }
2378
2379 // TODO(b/36794510): Make broadcast semantics more consistent, by supporting
2380 // "degenerate" cases, as with binary elementwise ops, as well as scalar
2381 // broadcast from all operands, not just the predicate.
InferSelectShape(const Shape & pred,const Shape & on_true,const Shape & on_false)2382 /* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
2383 const Shape& pred, const Shape& on_true, const Shape& on_false) {
2384 bool compatible;
2385 if (ShapeUtil::IsTuple(on_true)) {
2386 // Select only defines the top-level buffer, so if it's a tuple, the two
2387 // input must match exactly.
2388 compatible = ShapeUtil::Compatible(on_true, on_false);
2389 } else {
2390 compatible = ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false);
2391 }
2392 if (!compatible) {
2393 return InvalidArgument(
2394 "operands to select must be the same shape; got %s and %s",
2395 ShapeUtil::HumanString(on_true).c_str(),
2396 ShapeUtil::HumanString(on_false).c_str());
2397 }
2398 if (pred.element_type() != PRED) {
2399 return InvalidArgument(
2400 "select's pred operand must have PRED element type; got %s",
2401 ShapeUtil::HumanString(pred).c_str());
2402 }
2403 if (ShapeUtil::SameDimensions(pred, on_true) || ShapeUtil::Rank(pred) == 0) {
2404 // By this stage we know that pred's element type is PRED. Therefore, this
2405 // check restricts pred to be a PRED scalar, or a PRED array with the same
2406 // dimensions as on_true and on_false.
2407 return ShapeUtil::ChangeElementType(
2408 on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
2409 } else {
2410 return Unimplemented(
2411 "select operation with non-scalar predicate with dimensionality "
2412 " different from the other operands: %s",
2413 ShapeUtil::HumanString(pred).c_str());
2414 }
2415 }
2416
InferCallShape(tensorflow::gtl::ArraySlice<const Shape * > arg_shapes,const ProgramShape & to_apply)2417 /* static */ StatusOr<Shape> ShapeInference::InferCallShape(
2418 tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
2419 const ProgramShape& to_apply) {
2420 // The applied function's arity equals the number of arguments.
2421 if (arg_shapes.size() != to_apply.parameters_size()) {
2422 string computation_signature = ShapeUtil::HumanString(to_apply);
2423 string argument_shapes =
2424 Join(arg_shapes, ", ", [](string* out, const Shape* shape) {
2425 tensorflow::strings::StrAppend(out, ShapeUtil::HumanString(*shape));
2426 });
2427 return InvalidArgument(
2428 "Call applied function arity must match number of arguments; got: "
2429 "arity: %d, arguments: %zu; computation signature: %s; argument "
2430 "shapes: [%s]",
2431 to_apply.parameters_size(), arg_shapes.size(),
2432 computation_signature.c_str(), argument_shapes.c_str());
2433 }
2434
2435 // All arguments must be compatible with the program shape.
2436 for (int i = 0; i < arg_shapes.size(); ++i) {
2437 const Shape& arg_shape = *arg_shapes[i];
2438 const Shape& param_shape = to_apply.parameters(i);
2439 if (!ShapeUtil::Compatible(arg_shape, param_shape)) {
2440 return InvalidArgument(
2441 "Call parameter must match argument; got parameter %d shape: %s, "
2442 "argument shape: %s",
2443 i, ShapeUtil::HumanString(param_shape).c_str(),
2444 ShapeUtil::HumanString(arg_shape).c_str());
2445 }
2446 }
2447
2448 return to_apply.result();
2449 }
2450
ValidateGatherDimensionNumbers(const Shape & input_shape,tensorflow::gtl::ArraySlice<int64> gather_indices_shape,const GatherDimensionNumbers & dim_numbers)2451 static Status ValidateGatherDimensionNumbers(
2452 const Shape& input_shape,
2453 tensorflow::gtl::ArraySlice<int64> gather_indices_shape,
2454 const GatherDimensionNumbers& dim_numbers) {
2455 if (!c_is_sorted(dim_numbers.output_window_dims())) {
2456 return InvalidArgument(
2457 "Output window dimensions in gather op must be ascending; got: %s",
2458 Join(dim_numbers.output_window_dims(), ", ").c_str());
2459 }
2460
2461 if (c_adjacent_find(dim_numbers.output_window_dims()) !=
2462 dim_numbers.output_window_dims().end()) {
2463 return InvalidArgument(
2464 "Output window dimensions in gather op must not repeat; got: %s",
2465 Join(dim_numbers.output_window_dims(), ", ").c_str());
2466 }
2467
2468 const int64 output_window_dim_count = dim_numbers.output_window_dims_size();
2469 const int64 output_shape_rank =
2470 output_window_dim_count + gather_indices_shape.size();
2471
2472 for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) {
2473 int64 window_index = dim_numbers.output_window_dims(i);
2474 if (window_index < 0 || window_index >= output_shape_rank) {
2475 return InvalidArgument(
2476 "Window index %d in gather op is out of bounds; got %lld, but should "
2477 "have been in"
2478 "[0,%lld)",
2479 i, window_index, output_shape_rank);
2480 }
2481 }
2482
2483 if (dim_numbers.gather_dims_to_operand_dims_size() !=
2484 gather_indices_shape.back()) {
2485 return InvalidArgument(
2486 "There must be exactly as many elements in gather_dims_to_operand_dims "
2487 "as there are elements in the last dimension of %%gather_indices; got: "
2488 "%d, expected %lld",
2489 dim_numbers.gather_dims_to_operand_dims_size(),
2490 gather_indices_shape.back());
2491 }
2492
2493 for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) {
2494 int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i);
2495 if (gather_dim_to_input_dim < 0 ||
2496 gather_dim_to_input_dim >= input_shape.dimensions_size()) {
2497 return InvalidArgument(
2498 "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), "
2499 "got: %d->%lld",
2500 input_shape.dimensions_size(), i, gather_dim_to_input_dim);
2501 }
2502 }
2503
2504 std::vector<int64> sorted_gather_dims_to_operand_dims(
2505 dim_numbers.gather_dims_to_operand_dims().begin(),
2506 dim_numbers.gather_dims_to_operand_dims().end());
2507
2508 c_sort(sorted_gather_dims_to_operand_dims);
2509
2510 if (c_adjacent_find(sorted_gather_dims_to_operand_dims) !=
2511 sorted_gather_dims_to_operand_dims.end()) {
2512 return InvalidArgument(
2513 "Repeated dimensions are not allowed in gather_dims_to_operand_dims; "
2514 "got: %s",
2515 Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str());
2516 }
2517
2518 for (int64 elided_dim : dim_numbers.elided_window_dims()) {
2519 if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) {
2520 return InvalidArgument(
2521 "Invalid elided_window_dims set in gather op; valid range is [0, "
2522 "%d), got: %lld",
2523 input_shape.dimensions_size(), elided_dim);
2524 }
2525 }
2526
2527 if (!c_is_sorted(dim_numbers.elided_window_dims())) {
2528 return InvalidArgument(
2529 "elided_window_dims in gather op must be sorted; got: %s",
2530 Join(dim_numbers.elided_window_dims(), ", ").c_str());
2531 }
2532
2533 if (c_adjacent_find(dim_numbers.elided_window_dims()) !=
2534 dim_numbers.elided_window_dims().end()) {
2535 return InvalidArgument(
2536 "Repeated dimensions not allowed in elided_window_dims in gather op; "
2537 "got: %s",
2538 Join(dim_numbers.elided_window_dims(), ", ").c_str());
2539 }
2540
2541 return Status::OK();
2542 }
2543
InferGatherShape(const Shape & input_shape,const Shape & gather_indices_shape,const GatherDimensionNumbers & gather_dim_numbers,tensorflow::gtl::ArraySlice<int64> window_bounds)2544 /*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
2545 const Shape& input_shape, const Shape& gather_indices_shape,
2546 const GatherDimensionNumbers& gather_dim_numbers,
2547 tensorflow::gtl::ArraySlice<int64> window_bounds) {
2548 TF_RETURN_IF_ERROR(
2549 ExpectNotTupleOrOpaque(input_shape, "input tensor operand gather op"));
2550 TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
2551 gather_indices_shape, "gather indices operand of gather op"));
2552
2553 if (gather_indices_shape.dimensions_size() < 1) {
2554 return InvalidArgument(
2555 "Gather indices parameter must at least of rank 1; got %s",
2556 ShapeUtil::HumanString(gather_indices_shape).c_str());
2557 }
2558
2559 if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) {
2560 return InvalidArgument(
2561 "Gather indices parameter must be an integral tensor; got %s",
2562 ShapeUtil::HumanString(gather_indices_shape).c_str());
2563 }
2564
2565 std::vector<int64> expanded_gather_indices_shape;
2566 // We implicitly reshape gather indices of shape P[N] to P[N,1].
2567 expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size());
2568 c_copy(gather_indices_shape.dimensions(),
2569 std::back_inserter(expanded_gather_indices_shape));
2570 if (expanded_gather_indices_shape.size() == 1) {
2571 expanded_gather_indices_shape.push_back(1);
2572 }
2573
2574 TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
2575 input_shape, expanded_gather_indices_shape, gather_dim_numbers));
2576
2577 if (window_bounds.size() != input_shape.dimensions_size()) {
2578 return InvalidArgument(
2579 "Gather op must have one window bound for every input dimension; got: "
2580 "len(window_bounds)=%lu, input_shape.rank=%d",
2581 window_bounds.size(), input_shape.dimensions_size());
2582 }
2583
2584 if (window_bounds.size() !=
2585 gather_dim_numbers.output_window_dims_size() +
2586 gather_dim_numbers.elided_window_dims_size()) {
2587 return InvalidArgument(
2588 "All components of the window index in a gather op must either be a "
2589 "output window index or explicitly elided; got len(window_bounds)=%lu, "
2590 "output_window_bounds=%s, elided_window_bounds=%s",
2591 window_bounds.size(),
2592 Join(gather_dim_numbers.output_window_dims(), ",").c_str(),
2593 Join(gather_dim_numbers.elided_window_dims(), ",").c_str());
2594 }
2595
2596 for (int i = 0; i < window_bounds.size(); i++) {
2597 int64 window_bound = window_bounds[i];
2598 int64 corresponding_input_bound = input_shape.dimensions(i);
2599 if (window_bound < 0 || window_bound > corresponding_input_bound) {
2600 return InvalidArgument(
2601 "Window bound at index %d in gather op is out of range, must be "
2602 "within "
2603 "[0, %lld), got %lld",
2604 i, corresponding_input_bound + 1, window_bound);
2605 }
2606 }
2607
2608 for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) {
2609 if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) {
2610 return InvalidArgument(
2611 "Gather op can only elide window indices with bound 1, but bound is "
2612 "%lld for index %lld at position %d",
2613 window_bounds[gather_dim_numbers.elided_window_dims(i)],
2614 gather_dim_numbers.elided_window_dims(i), i);
2615 }
2616 }
2617
2618 int64 result_rank = gather_dim_numbers.output_window_dims_size() +
2619 (expanded_gather_indices_shape.size() - 1);
2620 int64 window_dims_seen = 0;
2621 int64 gather_dims_seen = 0;
2622 std::vector<int64> output_dim_bounds;
2623 output_dim_bounds.reserve(result_rank);
2624 for (int64 i = 0; i < result_rank; i++) {
2625 int64 current_bound;
2626 bool is_window_index =
2627 c_binary_search(gather_dim_numbers.output_window_dims(), i);
2628 if (is_window_index) {
2629 while (c_binary_search(gather_dim_numbers.elided_window_dims(),
2630 window_dims_seen)) {
2631 window_dims_seen++;
2632 }
2633 current_bound = window_bounds[window_dims_seen++];
2634 } else {
2635 current_bound = expanded_gather_indices_shape[gather_dims_seen++];
2636 }
2637
2638 output_dim_bounds.push_back(current_bound);
2639 }
2640
2641 return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds);
2642 }
2643
2644 } // namespace xla
2645