1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <numeric>
21 #include <set>
22 #include <string>
23
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 #include "absl/strings/str_join.h"
29 #include "absl/strings/string_view.h"
30 #include "tensorflow/compiler/xla/permutation_util.h"
31 #include "tensorflow/compiler/xla/primitive_util.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/compiler/xla/window_util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/math/math_util.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/protobuf.h"
42
43 namespace xla {
44 namespace {
45
46 using absl::StrFormat;
47 using absl::StrJoin;
48
49 // Returns true if no element is present in slice more than once.
AllUnique(absl::Span<const int64> slice)50 bool AllUnique(absl::Span<const int64> slice) {
51 return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
52 }
53
ExpectArray(const Shape & shape,absl::string_view op_type)54 Status ExpectArray(const Shape& shape, absl::string_view op_type) {
55 if (!shape.IsArray()) {
56 return InvalidArgument("Expected array argument for %s, but got %s.",
57 string(op_type), ShapeUtil::HumanString(shape));
58 }
59 return Status::OK();
60 }
61
// Validates that `reducer_shape` -- the program shape of the computation
// applied by a Reduce-style op -- is consistent with the reduction's init
// values and input element types.
//
// The reducer must take `inputs` accumulator parameters followed by `inputs`
// element parameters, and must return either a scalar (when inputs == 1) or
// a tuple of `inputs` scalars. Returns OK on success, InvalidArgument
// describing the first mismatch otherwise.
Status VerifyReducerShape(const ProgramShape& reducer_shape,
                          absl::Span<const Shape* const> init_value_shapes,
                          absl::Span<const PrimitiveType> input_element_types,
                          int64 inputs) {
  // One accumulator parameter plus one element parameter per reduced input.
  if (reducer_shape.parameters_size() != inputs * 2) {
    return InvalidArgument(
        "Reduction function must take %d parameters, but "
        "takes %d parameter(s).",
        inputs * 2, reducer_shape.parameters_size());
  }

  // Collect the per-input accumulator shapes: either the (array) result
  // itself for a single-input reduction, or each element of the result tuple
  // for a multi-input reduction.
  const Shape& accumulator_shape = reducer_shape.result();
  std::vector<const Shape*> accumulator_subshapes;
  if (accumulator_shape.IsArray()) {
    if (inputs != 1) {
      return InvalidArgument(
          "Reduction function must produce a tuple with %d elements, but "
          "produces a scalar",
          inputs);
    }
    accumulator_subshapes.push_back(&accumulator_shape);
  } else if (accumulator_shape.IsTuple()) {
    if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) {
      return InvalidArgument(
          "Reduction function must produce a tuple with %d elements, but has "
          "%d elements",
          inputs, ShapeUtil::TupleElementCount(accumulator_shape));
    }
    for (const Shape& element_shape : accumulator_shape.tuple_shapes()) {
      accumulator_subshapes.push_back(&element_shape);
    }
  } else {
    return InvalidArgument(
        "Reduction function must produce a scalar or tuple of scalars, but has "
        "shape: %s",
        ShapeUtil::HumanString(accumulator_shape));
  }

  // Every accumulator component must be a scalar (rank 0).
  for (const Shape* element_shape : accumulator_subshapes) {
    if (element_shape->rank() != 0) {
      return InvalidArgument(
          "Reduction function must return a scalar or tuple of scalars but "
          "returns shape: %s",
          ShapeUtil::HumanString(accumulator_shape));
    }
  }

  for (int64 i = 0; i < inputs; ++i) {
    // Check that the accumulator can be passed in as the first argument.
    // Note: comparing here and below with Compatible since we don't care about
    // layout in scalars - see b/26668201 for a longer-term vision.
    if (!ShapeUtil::Compatible(*accumulator_subshapes[i],
                               reducer_shape.parameters(i))) {
      return InvalidArgument(
          "Reduction function's %d-th parameter shape differs from the "
          "result shape: %s vs %s",
          i, ShapeUtil::HumanString(reducer_shape.parameters(i)),
          ShapeUtil::HumanString(*accumulator_subshapes[i]));
    }
    // Check that init_value's shapes are suitable for reducer_shape.
    if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i],
                                                  *init_value_shapes[i])) {
      return InvalidArgument(
          "Reduction function's accumulator shape at index %d differs from "
          "the init_value shape: %s vs %s",
          i, ShapeUtil::HumanString(*accumulator_subshapes[i]),
          ShapeUtil::HumanString(*init_value_shapes[i]));
    }
    // Check that the inputs can be passed in as the non-accumulator arguments.
    const Shape input_element_shape =
        ShapeUtil::MakeShape(input_element_types[i], {});
    if (!ShapeUtil::CompatibleIgnoringFpPrecision(
            input_element_shape, reducer_shape.parameters(inputs + i))) {
      return InvalidArgument(
          "Reduction function's %d-th parameter shape differs from the "
          "input type element type: %s vs %s",
          inputs + i,
          ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
          ShapeUtil::HumanString(input_element_shape));
    }
    // Check that the accumulator and inputs to the reducer function match.
    // If the accumulator is scalar, it must have the same type as the inputs
    // (up to fp precision). If it is a tuple, then the k-th element of the
    // tuple must have the same type as the K-th input (again, up to fp
    // precision.)
    if (!ShapeUtil::CompatibleIgnoringFpPrecision(
            *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) {
      return InvalidArgument(
          "Reduction function's %d-th parameter shape must "
          "match the result shape, but got %s vs %s.",
          inputs + i,
          ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
          ShapeUtil::HumanString(*accumulator_subshapes[i]));
    }
  }

  return Status::OK();
}
160
// Computes the shape produced by sliding `window` over `base_shape`, with
// result element type `element_type`. Validates each window dimension
// (positive size and stride, non-negative padding unless
// `allow_negative_padding`, dilation factors >= 1), then computes each
// output dimension via window_util::DilatedBound/StridedBound. Dynamic
// dimensions of the base shape are carried over to the output.
StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
                                       const Window& window,
                                       PrimitiveType element_type,
                                       bool allow_negative_padding) {
  // The window must describe exactly one dimension per base dimension.
  if (window.dimensions_size() != base_shape.rank()) {
    return InvalidArgument(
        "Window has dimension %d but base shape has dimension %d.",
        window.dimensions_size(), base_shape.rank());
  }

  std::vector<int64> output_dimensions(window.dimensions_size());
  std::vector<bool> output_is_dynamic(window.dimensions_size());
  for (int64 i = 0; i < window.dimensions_size(); ++i) {
    const auto& dim = window.dimensions(i);
    if (dim.size() <= 0) {
      return InvalidArgument("Window %s has a non-positive dimension.",
                             window.DebugString());
    }
    if (dim.stride() <= 0) {
      return InvalidArgument("Window %s has a non-positive stride.",
                             window.DebugString());
    }
    if (!allow_negative_padding && dim.padding_low() < 0) {
      return InvalidArgument("Window %s has a negative low padding.",
                             window.DebugString());
    }
    if (!allow_negative_padding && dim.padding_high() < 0) {
      return InvalidArgument("Window %s has a negative high padding.",
                             window.DebugString());
    }
    if (dim.base_dilation() < 1) {
      return InvalidArgument(
          "Window %s has a non-positive base area dilation factor.",
          window.DebugString());
    }
    if (dim.window_dilation() < 1) {
      return InvalidArgument(
          "Window %s has a non-positive window dilation factor.",
          window.DebugString());
    }

    // Dilate the base, apply padding, dilate the window, then count how many
    // strided window positions fit.
    const int64 dilated_base = window_util::DilatedBound(
        ShapeUtil::GetDimension(base_shape, i), dim.base_dilation());
    const int64 padded_dilated_base =
        dim.padding_low() + dilated_base + dim.padding_high();
    const int64 dilated_window =
        window_util::DilatedBound(dim.size(), dim.window_dilation());

    output_dimensions[i] = window_util::StridedBound(
        padded_dilated_base, dilated_window, dim.stride());
    output_is_dynamic[i] = base_shape.is_dynamic_dimension(i);
  }

  return ShapeUtil::MakeValidatedShape(element_type, output_dimensions,
                                       output_is_dynamic);
}
217
MaybeUpcast(PrimitiveType from_type,absl::optional<PrimitiveType> preferred_element_type)218 StatusOr<PrimitiveType> MaybeUpcast(
219 PrimitiveType from_type,
220 absl::optional<PrimitiveType> preferred_element_type) {
221 if (!preferred_element_type.has_value() ||
222 *preferred_element_type == from_type) {
223 return from_type;
224 }
225 if (primitive_util::IsIntegralType(from_type) !=
226 primitive_util::IsIntegralType(*preferred_element_type)) {
227 return InvalidArgument(
228 "`preferred_element_type` and the original type must both be integral "
229 "or both be floating point.");
230 }
231 if (!primitive_util::IsSignedIntegralType(from_type) !=
232 !primitive_util::IsSignedIntegralType(*preferred_element_type)) {
233 return InvalidArgument(
234 "`preferred_element_type` must have the same signedness as the "
235 "original type.");
236 }
237 if (!primitive_util::IsFloatingPointType(from_type) &&
238 primitive_util::BitWidth(*preferred_element_type) <
239 primitive_util::BitWidth(from_type)) {
240 return InvalidArgument(
241 "`preferred_element_type` must not be narrower than the original "
242 "type.");
243 }
244 return *preferred_element_type;
245 }
246
247 } // namespace
248
// Convenience overload: infers the unary-op result shape from the operand
// instruction's shape. Delegates to the Shape-based overload below.
/* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
    HloOpcode opcode, const HloInstruction* operand) {
  return InferUnaryOpShape(opcode, operand->shape());
}
253
// Infers the result shape of the unary operation `opcode` applied to an
// operand of shape `shape`. For most ops the result shape equals the operand
// shape after an element-type validity check; kReal/kImag/kAbs on complex
// operands and kIsFinite instead change the element type.
/* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
    HloOpcode opcode, const Shape& shape) {
  // There is no copy operation at the proto level, so handle copy explicitly.
  // A domain shape is the same as the input one.
  if (opcode == HloOpcode::kCopy || opcode == HloOpcode::kDomain) {
    return shape;
  }

  TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation"));

  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
  switch (opcode) {
    // Rounding ops: floating-point only.
    case HloOpcode::kFloor:
    case HloOpcode::kCeil:
    case HloOpcode::kRoundNearestAfz:
      if (!ShapeUtil::ElementIsFloating(shape)) {
        return InvalidArgument(
            "Expected element type in shape to be floating for %s operation; "
            "got %s.",
            HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
      }
      return shape;
    // Transcendental ops: floating-point or complex.
    case HloOpcode::kCos:
    case HloOpcode::kSin:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kLogistic:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kTanh:
      if (!ShapeUtil::ElementIsFloating(shape) &&
          !ShapeUtil::ElementIsComplex(shape)) {
        return InvalidArgument(
            "Expected element type in shape to be floating or complex for %s "
            "operation; got %s.",
            HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
      }
      return shape;
    // kReal/kImag of a complex operand yield the real component type;
    // applied to a floating-point operand they are identity-shaped.
    case HloOpcode::kReal:
    case HloOpcode::kImag:
      if (ShapeUtil::ElementIsComplex(shape)) {
        return ShapeUtil::ComplexComponentShape(shape);
      } else if (ShapeUtil::ElementIsFloating(shape)) {
        return shape;
      } else {
        return InvalidArgument(
            "Expected element type in shape to be floating or complex for "
            "%s operation; got %s.",
            HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
      }
    // kAbs of a complex operand yields the (real) component type.
    case HloOpcode::kAbs:
      if (ShapeUtil::ElementIsComplex(shape)) {
        return ShapeUtil::ChangeElementType(
            shape, primitive_util::ComplexComponentType(shape.element_type()));
      } else if (ShapeUtil::ElementIsSigned(shape)) {
        return shape;
      } else {
        return InvalidArgument(
            "Expected element type in shape to be floating or complex for "
            "%s operation; got %s.",
            HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
      }
    case HloOpcode::kClz:
      if (!ShapeUtil::ElementIsIntegral(shape)) {
        return InvalidArgument(
            "Expected an integral element type in argument to Clz "
            "operation; got %s.",
            PrimitiveType_Name(shape.element_type()));
      }
      return shape;
    case HloOpcode::kNegate:
      if (!ShapeUtil::ElementIsIntegral(shape) &&
          !ShapeUtil::ElementIsFloating(shape) &&
          !ShapeUtil::ElementIsComplex(shape)) {
        return InvalidArgument(
            "Expected element type in shape to be integral, floating or "
            "complex for %s operation; got %s.",
            HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
      }
      return shape;
    case HloOpcode::kPopulationCount:
      if (!ShapeUtil::ElementIsIntegral(shape)) {
        return InvalidArgument(
            "Expected an integral element type in argument to PopulationCount "
            "operation; got %s.",
            PrimitiveType_Name(shape.element_type()));
      }
      return shape;
    case HloOpcode::kSign:
      if (!ShapeUtil::ElementIsSigned(shape) &&
          !ShapeUtil::ElementIsComplex(shape)) {
        return InvalidArgument(
            "Expected element type in shape to be signed or complex for "
            "%s operation; got %s.",
            HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
      }
      return shape;

    case HloOpcode::kNot:
      if (shape.element_type() != PRED &&
          !primitive_util::IsIntegralType(shape.element_type())) {
        return InvalidArgument(
            "Expected pred or an integral element type in argument to Not "
            "operation; got %s.",
            PrimitiveType_Name(shape.element_type()));
      }
      return shape;

    // kIsFinite yields a PRED result regardless of the input element type.
    case HloOpcode::kIsFinite:
      if (!ShapeUtil::ElementIsFloating(shape)) {
        return InvalidArgument(
            "Expected element type in shape to be floating "
            "point for IsFinite "
            "operation; got %s.",
            PrimitiveType_Name(shape.element_type()));
      }
      return ShapeUtil::ChangeElementType(shape, PRED);

    default:
      return InvalidArgument(
          "Unknown operation for unary shape inference: \"%s\".",
          HloOpcodeString(opcode));
  }
}
381
// Infers the result shape of concatenating `arg_shapes` along `dimension`.
// All operands must be arrays of equal rank with (fp-precision-ignoring)
// identical element types, and must agree on every dimension other than the
// concatenation dimension. The result uses the highest-precision element type
// among the operands and the sum of the operands' sizes along `dimension`.
/* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
    absl::Span<const Shape* const> arg_shapes, const int64 dimension) {
  if (arg_shapes.empty()) {
    return InvalidArgument("Concatenate expects at least one argument.");
  }
  if (dimension < 0 || dimension >= arg_shapes[0]->rank()) {
    return InvalidArgument("Concatenate dimension out of bounds: %d.",
                           dimension);
  }
  // `arg_shape` is the first operand; each subsequent operand is validated
  // against it.
  const Shape* arg_shape = nullptr;
  PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
  for (const Shape* shape : arg_shapes) {
    TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation"));
    if (!arg_shape) {
      arg_shape = shape;
      element_type = arg_shape->element_type();
      continue;
    }
    if (arg_shape->rank() != shape->rank()) {
      return InvalidArgument(
          "Cannot concatenate arrays with different ranks: %d (%s) vs %d "
          "(%s).",
          arg_shape->rank(), ShapeUtil::HumanString(*arg_shape), shape->rank(),
          ShapeUtil::HumanString(*shape));
    }
    if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
      return InvalidArgument(
          "Cannot concatenate arrays with different element types: %s vs %s.",
          PrimitiveType_Name(arg_shape->element_type()),
          PrimitiveType_Name(shape->element_type()));
    }
    for (int64 dimension_number = 0; dimension_number < arg_shape->rank();
         ++dimension_number) {
      if (arg_shape->dimensions(dimension_number) !=
          shape->dimensions(dimension_number)) {
        if (dimension_number == dimension) {
          continue;  // It's okay to differ in the dimension we're
                     // concatenating.
        }
        return InvalidArgument(
            "Cannot concatenate arrays that differ in dimensions other than "
            "the one being concatenated (the other array dimensions must be "
            "the same): %s vs %s in dimension %d.",
            ShapeUtil::HumanString(*arg_shape), ShapeUtil::HumanString(*shape),
            dimension);
      }
    }
    // Keep the widest element type seen so far (fp precision may differ).
    element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
  }

  // The result dimensions are those of the first operand, with the
  // concatenation dimension summed over all operands.
  std::vector<int64> new_dimensions(arg_shape->dimensions().begin(),
                                    arg_shape->dimensions().end());
  for (size_t i = 1; i < arg_shapes.size(); ++i) {
    new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
  }

  Shape result = ShapeUtil::MakeShape(element_type, new_dimensions);

  // Set dynamic dimensions if any input has dynamic dimension.
  for (const Shape* shape : arg_shapes) {
    for (int64 i = 0; i < shape->dimensions_size(); ++i) {
      if (shape->is_dynamic_dimension(i)) {
        result.set_dynamic_dimension(i, true);
      }
    }
  }
  return result;
}
450
InferConvertShape(const Shape & operand_shape,PrimitiveType new_element_type)451 /* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
452 const Shape& operand_shape, PrimitiveType new_element_type) {
453 auto old_element_type = operand_shape.element_type();
454 if (primitive_util::IsComplexType(old_element_type) &&
455 !primitive_util::IsComplexType(new_element_type)) {
456 return Unimplemented(
457 "Conversion from complex to real type %s => %s is not implemented.",
458 ShapeUtil::HumanString(operand_shape),
459 PrimitiveType_Name(new_element_type));
460 }
461 if (!operand_shape.IsArray() ||
462 !primitive_util::IsArrayType(new_element_type)) {
463 // Note: we may want to support tuple conversions via this operation in the
464 // future, by recursing into the tuple elements to check all sub-conversions
465 // are valid. For now we just reject them, though.
466 return InvalidArgument(
467 "Convert does not allow non-arrays, so cannot convert from %s to %s.",
468 ShapeUtil::HumanString(operand_shape),
469 PrimitiveType_Name(new_element_type));
470 }
471
472 return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
473 }
474
InferBitcastConvertShape(const Shape & operand_shape,PrimitiveType new_element_type)475 /* static */ StatusOr<Shape> ShapeInference::InferBitcastConvertShape(
476 const Shape& operand_shape, PrimitiveType new_element_type) {
477 auto old_element_type = operand_shape.element_type();
478 if (primitive_util::IsComplexType(old_element_type) !=
479 primitive_util::IsComplexType(new_element_type)) {
480 return InvalidArgument("Conversion from complex to real type %s => %s.",
481 ShapeUtil::HumanString(operand_shape),
482 PrimitiveType_Name(new_element_type));
483 }
484 if (!operand_shape.IsArray() ||
485 !primitive_util::IsArrayType(new_element_type)) {
486 // Note: we may want to support tuple conversions via this operation in the
487 // future, by recursing into the tuple elements to check all sub-conversions
488 // are valid. For now we just reject them, though.
489 return InvalidArgument(
490 "Cannot convert from or to tuple type; requested conversion: %s => %s.",
491 ShapeUtil::HumanString(operand_shape),
492 PrimitiveType_Name(new_element_type));
493 }
494 if (primitive_util::BitWidth(old_element_type) !=
495 primitive_util::BitWidth(new_element_type)) {
496 return InvalidArgument(
497 "Cannot bitcast types with different bit-widths: %s => %s.",
498 PrimitiveType_Name(old_element_type),
499 PrimitiveType_Name(new_element_type));
500 }
501
502 return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
503 }
504
InferReducePrecisionShape(const Shape & operand_shape,const int exponent_bits,const int mantissa_bits)505 /* static */ StatusOr<Shape> ShapeInference::InferReducePrecisionShape(
506 const Shape& operand_shape, const int exponent_bits,
507 const int mantissa_bits) {
508 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
509 return InvalidArgument(
510 "Expected element type in shape to be floating point for "
511 "ReducePrecision operation; got %s.",
512 PrimitiveType_Name(operand_shape.element_type()));
513 }
514 if (exponent_bits < 1) {
515 // One exponent bit is necessary to distinguish 0 from infinity. Having
516 // no exponent bits doesn't produce a sensible number, so we require at
517 // least one.
518 return InvalidArgument("Expected exponent_bits >= 1; got %d.",
519 exponent_bits);
520 }
521 if (mantissa_bits < 0) {
522 // A number with no mantissa bits is still meaningful, however.
523 return InvalidArgument("Expected non-negative mantissa_bits; got %d.",
524 mantissa_bits);
525 }
526 return operand_shape;
527 }
528
// Infers the result shape of a Pad operation. Each output dimension is
//   input + edge_padding_low + edge_padding_high
//       + max(input - 1, 0) * interior_padding,
// i.e. interior padding inserts elements between existing ones. The operand
// must be an array, the padding value a static scalar of the same
// (fp-precision-ignoring) element type, and interior padding non-negative;
// edge padding may be negative as long as no dimension becomes negative.
/* static */ StatusOr<Shape> ShapeInference::InferPadShape(
    const Shape& operand_shape, const Shape& padding_value_shape,
    const PaddingConfig& padding_config) {
  if (!operand_shape.IsArray()) {
    return InvalidArgument(
        "Pad operation does not support tuple-shape operands.");
  }
  if (!ShapeUtil::IsScalar(padding_value_shape)) {
    return InvalidArgument(
        "Pad operation does not support non-scalar padding values.");
  }
  if (operand_shape.rank() != padding_config.dimensions_size()) {
    return InvalidArgument(
        "The rank of the operand and the padding configuration do not match: "
        "%s vs %s.",
        ShapeUtil::HumanString(operand_shape),
        padding_config.ShortDebugString());
  }
  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
                                                     padding_value_shape)) {
    return InvalidArgument(
        "The element types of the operands to Pad do not match.");
  }
  if (absl::c_any_of(padding_config.dimensions(),
                     [](const PaddingConfig::PaddingConfigDimension& p) {
                       return p.interior_padding() < 0;
                     })) {
    return InvalidArgument("Interior padding cannot be negative: %s",
                           padding_config.ShortDebugString());
  }

  if (!padding_value_shape.is_static()) {
    return InvalidArgument("Dynamic padding value is not supported");
  }

  std::vector<int64> dimensions(operand_shape.rank());
  std::vector<bool> is_dynamic(operand_shape.rank());
  for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
    const auto& p = padding_config.dimensions(i);
    // Interior padding applies between elements, hence the (input - 1)
    // factor, clamped at 0 for empty dimensions.
    dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() +
                    p.edge_padding_high() +
                    std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
                        p.interior_padding();
    if (dimensions[i] < 0) {
      return InvalidArgument("Padding result in negative size for dimension %d",
                             i);
    }
    is_dynamic[i] = operand_shape.is_dynamic_dimension(i);
  }

  return ShapeUtil::MakeShape(
      ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
      dimensions, is_dynamic);
}
583
584 // Current DotDimensionNumbers Requirements:
585 //
586 // Contracting Dimensions:
587 // *) Same number of contracting dimensions on both lhs and rhs.
588 // *) Contracting dimension size must be the same on both lhs and rhs.
589 //
590 // Batch Dimensions:
591 // *) Same number of batch dimensions on both lhs and rhs.
592 // *) Same batch dimension sizes on both lhs and rhs.
593 //
594
595 namespace {
596
// Validates that the contracting and batch dimension numbers in
// `dimension_numbers` are in range for the ranks of `lhs`/`rhs` and that no
// dimension number is repeated within one operand's contracting+batch set.
Status ValidateDotDimensionNumbers(
    const Shape& lhs, const Shape& rhs,
    const DotDimensionNumbers& dimension_numbers) {
  // Check that dimension numbers are in range.
  // True iff every listed dimension index is in [0, rank).
  auto dims_in_range = [](const int64 rank,
                          absl::Span<const int64> contracting_dims,
                          absl::Span<const int64> batch_dims) -> bool {
    auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
    return absl::c_all_of(contracting_dims, in_range) &&
           absl::c_all_of(batch_dims, in_range);
  };

  absl::Span<const int64> lhs_contracting_dimensions =
      AsInt64Slice(dimension_numbers.lhs_contracting_dimensions());
  absl::Span<const int64> rhs_contracting_dimensions =
      AsInt64Slice(dimension_numbers.rhs_contracting_dimensions());
  absl::Span<const int64> lhs_batch_dimensions =
      AsInt64Slice(dimension_numbers.lhs_batch_dimensions());
  absl::Span<const int64> rhs_batch_dimensions =
      AsInt64Slice(dimension_numbers.rhs_batch_dimensions());

  if (!dims_in_range(lhs.rank(), lhs_contracting_dimensions,
                     lhs_batch_dimensions) ||
      !dims_in_range(rhs.rank(), rhs_contracting_dimensions,
                     rhs_batch_dimensions)) {
    return InvalidArgument("A dimension number is out of range in Dot: %s.",
                           dimension_numbers.DebugString());
  }

  // Check that dimension numbers are unique.
  // Uses a shared set across contracting and batch dims, so a dimension may
  // not appear in both roles for the same operand.
  auto dims_unique = [](absl::Span<const int64> contracting_dims,
                        absl::Span<const int64> batch_dims) -> bool {
    absl::flat_hash_set<int64> dim_set;
    auto is_unique = [&dim_set](int64 i) -> bool {
      return dim_set.insert(i).second;
    };
    return absl::c_all_of(contracting_dims, is_unique) &&
           absl::c_all_of(batch_dims, is_unique);
  };

  if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
      !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) {
    return InvalidArgument("A dimension number is not unique in Dot: %s.",
                           dimension_numbers.DebugString());
  }

  return Status::OK();
}
645
646 } // namespace
647
InferDotOpShape(const Shape & lhs,const Shape & rhs,const DotDimensionNumbers & dimension_numbers,absl::optional<PrimitiveType> preferred_element_type)648 /* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
649 const Shape& lhs, const Shape& rhs,
650 const DotDimensionNumbers& dimension_numbers,
651 absl::optional<PrimitiveType> preferred_element_type) {
652 TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot"));
653 TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
654
655 auto fail = [lhs, rhs](const string& addendum) -> Status {
656 string message =
657 StrFormat("Cannot infer shape for dot operation: %s <dot> %s.",
658 ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
659 if (!addendum.empty()) {
660 message += " " + addendum;
661 }
662 return InvalidArgument("%s", message);
663 };
664
665 // Validate basic properties of dot dimension numbers.
666 TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers));
667
668 // Check that number of contracting dimensions match.
669 if (dimension_numbers.lhs_contracting_dimensions_size() !=
670 dimension_numbers.rhs_contracting_dimensions_size()) {
671 return fail(
672 "Must specify the same number of contracting dimensions for lhs and "
673 "rhs.");
674 }
675 // Check that contracting dimension sizes match.
676 for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
677 ++i) {
678 const int64 lhs_contracting_dimension =
679 dimension_numbers.lhs_contracting_dimensions(i);
680 const int64 rhs_contracting_dimension =
681 dimension_numbers.rhs_contracting_dimensions(i);
682 if (lhs.dimensions(lhs_contracting_dimension) !=
683 rhs.dimensions(rhs_contracting_dimension)) {
684 return fail("Contracting dimension sizes do not match.");
685 }
686 }
687
688 // Check that number of batch dimensions match.
689 if (dimension_numbers.lhs_batch_dimensions_size() !=
690 dimension_numbers.rhs_batch_dimensions_size()) {
691 return fail("Must the same number of batch dimensions for lhs and rhs.");
692 }
693
694 // Check that batch dimension numbers and sizes match.
695 for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
696 if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
697 rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
698 return fail("Batch dimension sizes must match for lhs/rhs.");
699 }
700 }
701
702 // The ranks of lhs and rhs are decremented by 1 respectively due to the
703 // contraction, and added for the rank of the result. When an input tensor is
704 // a scalar, its contribution to the rank of the result is 0.
705 // Generate the result dimensions in order, rhs dimensions followed by lhs
706 // dimensions except the contracted and batch dimensions.
707 std::vector<int64> dimensions;
708 std::vector<bool> is_dynamic;
709 for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
710 dimensions.push_back(lhs.dimensions(lhs_dim));
711 is_dynamic.push_back(lhs.is_dynamic_dimension(lhs_dim));
712 }
713 for (int64 i = 0; i < lhs.rank(); i++) {
714 if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
715 i) &&
716 !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
717 dimensions.push_back(lhs.dimensions(i));
718 is_dynamic.push_back(lhs.is_dynamic_dimension(i));
719 }
720 }
721 for (int64 i = 0; i < rhs.rank(); i++) {
722 if (!absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
723 i) &&
724 !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
725 dimensions.push_back(rhs.dimensions(i));
726 is_dynamic.push_back(rhs.is_dynamic_dimension(i));
727 }
728 }
729 TF_ASSIGN_OR_RETURN(
730 PrimitiveType type,
731 MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
732 preferred_element_type));
733 Shape result = ShapeUtil::MakeShape(type, dimensions, is_dynamic);
734
735 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
736 VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
737 return result;
738 }
739
// Infers the result shape of a binary op on two equal-rank shapes using
// degenerate-dimension (size-1) broadcasting: for each dimension the sizes
// must match or one side must be 1, and the output takes the non-1 size.
// A dimension is dynamic in the output if it is dynamic in either operand.
/* static */ StatusOr<Shape>
ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
                                                       const Shape& lhs,
                                                       const Shape& rhs) {
  TF_RET_CHECK(lhs.rank() == rhs.rank());

  // The shapes have to be compatible. That is, if some dimension d has a
  // different size in the two shapes, one of them has to be 1 (a "degenerate"
  // dimension). In that case, the output shape has the non-1 dimension size
  // from the lhs/rhs pair in every index.
  std::vector<int64> output_dimensions(lhs.rank());
  std::vector<bool> output_dimensions_is_dynamic(lhs.rank());
  for (int64 i = 0; i < lhs.rank(); ++i) {
    if (lhs.dimensions(i) == rhs.dimensions(i)) {
      output_dimensions[i] = lhs.dimensions(i);
    } else if (lhs.dimensions(i) == 1) {
      output_dimensions[i] = rhs.dimensions(i);
    } else if (rhs.dimensions(i) == 1) {
      output_dimensions[i] = lhs.dimensions(i);
    } else {
      return InvalidArgument(
          "Binary op %s with incompatible shapes: %s and %s.",
          HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
          ShapeUtil::HumanString(rhs));
    }
  }

  // Merge dynamic dimensions from two shapes.
  for (int64 i = 0; i < rhs.rank(); ++i) {
    if (rhs.is_dynamic_dimension(i) || lhs.is_dynamic_dimension(i)) {
      output_dimensions_is_dynamic[i] = true;
    }
  }

  // The result element type is the higher-precision of the two operand types.
  return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
                              output_dimensions, output_dimensions_is_dynamic);
}
777
// Infers the shape produced by "InDim" broadcasting: each dimension of the
// lower-rank operand is matched against a dimension of the higher-rank
// operand selected by broadcast_dimensions. Returns larger_shape with the
// matched dimensions (and their dynamism) overwritten by smaller_shape's,
// using the higher-precision element type of the two operands.
/* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
    const Shape& smaller_shape, const Shape& larger_shape,
    absl::Span<const int64> broadcast_dimensions) {
  if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
    // Reject "magic" inference for binops on different shapes, requiring
    // the user to provide an explicit broadcast dimension in this case.
    // See b/25177275 for more details.
    return InvalidArgument("Automatic shape inference not supported: %s and %s",
                           ShapeUtil::HumanString(smaller_shape),
                           ShapeUtil::HumanString(larger_shape));
  } else if (broadcast_dimensions.size() != smaller_shape.rank()) {
    return InvalidArgument(
        "Size of broadcast_dimensions has to match lower-rank operand's "
        "rank; "
        " lower-rank operand's rank is %d, size of broadcast_dimensions is "
        "%u.",
        smaller_shape.rank(), broadcast_dimensions.size());
  }

  // broadcast_dimensions is a sequence of dimensions; its length is equal to
  // the rank of the lower-rank operand. The lower-rank operand's dimensions
  // have to be compatible with the higher-rank operand's dimensions at indices
  // specified by broadcast_dimensions. Here compatible means the dimension
  // sizes are equal or in one of the shapes the dimension size is
  // one. Examples:
  //
  // smaller_shape   larger_shape   broadcast_dimensions   output_shape
  //   []              [2, 3]        {}                      [2, 3]
  //   [3]             [4, 3]        {1}                     [4, 3]
  //   [2, 3]          [2, 3, 4]     {0, 1}                  [2, 3, 4]
  //   [2, 1]          [2, 3, 4]     {0, 2}                  [2, 3, 1]
  //   [2, 3]          [2, 1, 4]     {0, 1}                  [2, 3, 4]
  //
  // The column output_shape may not be the final shape of the XLA
  // operation. After the "InDim" broadcasting implemented in this function
  // expands the rank, degenerate-dimension broadcasting (implemented in
  // InferDegenerateDimensionBroadcastShape) broadcasts dimensions of size one
  // up to match the dimension size of the other operand. For example, consider
  // the row in the table above with a smaller_shape of [2, 1]. The shape
  // returned by this function is [2, 3, 1] (output_shape) however, the result
  // shape of the XLA operation is [2, 3, 4] after degenerate-dimension
  // broadcasting.
  //
  // Invalid broadcasts:
  //
  // smaller_shape=[3], larger_shape=[4, 3], broadcast_dimensions={0}
  // Reason: Dimension zero** of larger_shape (size 4) is not compatible with
  //   dimension zero of smaller_shape(size 3). **Zero here comes from the value
  //   in broadcast_dimensions.
  //
  // smaller_shape=[2, 1], larger_shape=[2, 3, 4], broadcast_dimensions={1, 2}
  // Reason: Dimension one of larger_shape (size 3) is not compatible with
  //   dimension zero of smaller_shape(size 2)

  // The output shape is initially the larger_shape. Sizes of dimensions
  // specified in broadcast_dimensions are then changed to match the
  // corresponding dimension size in smaller_shape.
  Shape output_shape(larger_shape);
  output_shape.set_element_type(
      ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape));

  for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
    int64 dimension_to_match = broadcast_dimensions.at(i);
    // Each entry must be a valid dimension index into larger_shape.
    if (dimension_to_match < 0) {
      return InvalidArgument(
          "Broadcast dimension number (%d) cannot be negative.",
          dimension_to_match);
    }
    if (dimension_to_match >= larger_shape.dimensions_size()) {
      return InvalidArgument(
          "Broadcast dimension number (%d) too large; higher-rank "
          "operand has rank %d.",
          dimension_to_match, larger_shape.dimensions_size());
    }
    int64 small_dimension_size = smaller_shape.dimensions(i);
    int64 large_dimension_size = larger_shape.dimensions(dimension_to_match);
    bool small_is_dynamic = smaller_shape.is_dynamic_dimension(i);
    bool large_is_dynamic =
        larger_shape.is_dynamic_dimension(dimension_to_match);
    // Dimension sizes must be compatible: match or be degenerate (degenerate
    // case is handled by degenerate dimension broadcasting which occurs after
    // InDim broadcasting).
    if (small_dimension_size != large_dimension_size &&
        small_dimension_size != 1 && large_dimension_size != 1) {
      return InvalidArgument(
          "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i,
          small_dimension_size, large_dimension_size,
          ShapeUtil::HumanString(smaller_shape),
          ShapeUtil::HumanString(larger_shape));
    }
    // Dynamism must also agree, except that a static size-1 dimension may
    // pair with a dynamic one (it will be broadcast up anyway).
    if (small_is_dynamic != large_is_dynamic) {
      if (small_dimension_size == large_dimension_size ||
          (small_dimension_size == 1 && !small_is_dynamic) ||
          (large_dimension_size == 1 && !large_is_dynamic)) {
        // Do nothing. It's OK when the size-1 dimension is not static.
      } else {
        return InvalidArgument(
            "Broadcast dimension %d dynamism mismatch: %s and %s.", i,
            ShapeUtil::HumanString(smaller_shape),
            ShapeUtil::HumanString(larger_shape));
      }
    }
    // Make sure the broadcast dimensions are listed in a strictly increasing
    // order.
    if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) {
      return InvalidArgument(
          "Broadcast dimensions order is wrong: %d comes after %d.",
          dimension_to_match, broadcast_dimensions.at(i - 1));
    }

    output_shape.set_dimensions(dimension_to_match, small_dimension_size);
    output_shape.set_dynamic_dimension(dimension_to_match, small_is_dynamic);
  }

  return output_shape;
}
894
InferElementwiseBinaryOpShape(HloOpcode operation,const Shape & lhs,const Shape & rhs,absl::Span<const int64> broadcast_dimensions)895 /* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
896 HloOpcode operation, const Shape& lhs, const Shape& rhs,
897 absl::Span<const int64> broadcast_dimensions) {
898 TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
899 TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
900
901 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
902 return InvalidArgument(
903 "Binary op %s with different element types: %s and %s.",
904 HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
905 ShapeUtil::HumanString(rhs));
906 }
907
908 if (lhs.rank() == rhs.rank()) {
909 std::vector<int64> identity_dims(lhs.rank());
910 std::iota(identity_dims.begin(), identity_dims.end(), 0);
911 if (!broadcast_dimensions.empty() &&
912 broadcast_dimensions != identity_dims) {
913 return InvalidArgument(
914 "Broadcast dimensions field must either be not set or be the "
915 "identity on binary operations with operands of the same rank.");
916 }
917 }
918
919 if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
920 // If the shapes are the same other than layout, the output shape is the
921 // same (elementwise op).
922 Shape result = ShapeUtil::ChangeElementType(
923 lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
924
925 for (int64 i = 0; i < rhs.rank(); ++i) {
926 if (rhs.is_dynamic_dimension(i)) {
927 result.set_dynamic_dimension(i, true);
928 }
929 }
930
931 return result;
932
933 } else if (lhs.rank() == rhs.rank()) {
934 return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
935 } else {
936 // Ranks do not match, so perform InDim broadcasting using
937 // broadcast_dimensions. Scalar broadcasting is a special case of this.
938 const Shape& larger_shape = lhs.rank() > rhs.rank() ? lhs : rhs;
939 const Shape& smaller_shape = lhs.rank() > rhs.rank() ? rhs : lhs;
940
941 // After InDim broadcasting, perform degenerate dimensions broadcasting.
942 TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape,
943 InferInDimBroadcastShape(smaller_shape, larger_shape,
944 broadcast_dimensions));
945
946 return InferDegenerateDimensionBroadcastShape(
947 operation, indim_broadcast_shape, larger_shape);
948 }
949 }
950
InferBinaryOpShape(HloOpcode opcode,const HloInstruction * lhs,const HloInstruction * rhs)951 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
952 HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) {
953 return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(),
954 /*broadcast_dimensions=*/{});
955 }
956
// Infers the result shape of a binary HLO op applied to array operands,
// dispatching per opcode. Both operands must be arrays with (optionally)
// layouts; broadcast_dimensions is forwarded to elementwise inference.
/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
    HloOpcode opcode, const Shape& lhs, const Shape& rhs,
    absl::Span<const int64> broadcast_dimensions) {
  VLOG(2) << StrFormat(
      "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
      HloOpcodeString(opcode), ShapeUtil::HumanStringWithLayout(lhs),
      ShapeUtil::HumanStringWithLayout(rhs),
      StrJoin(broadcast_dimensions, ", "));

  // Debug-only validation; failures here indicate caller bugs.
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));

  TF_RETURN_IF_ERROR(ExpectArray(
      lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode))));
  TF_RETURN_IF_ERROR(ExpectArray(
      rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode))));
  switch (opcode) {
    // Ops below accept any element type (including PRED).
    case HloOpcode::kAdd:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
      return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
                                           broadcast_dimensions);

    // Arithmetic/shift ops reject PRED operands.
    case HloOpcode::kSubtract:
    case HloOpcode::kAtan2:
    case HloOpcode::kPower:
    case HloOpcode::kDivide:
    case HloOpcode::kRemainder:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      if (lhs.element_type() == PRED || rhs.element_type() == PRED) {
        return InvalidArgument(
            "Expected element type in shape to be arithmetic type for "
            "operation %s; got PRED.",
            HloOpcodeString(opcode));
      }
      return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
                                           broadcast_dimensions);

    // Complex compose: (F32, F32) -> C64 or (F64, F64) -> C128.
    case HloOpcode::kComplex: {
      if (!ShapeUtil::ElementIsFloating(lhs)) {
        return InvalidArgument(
            "Expected element type in shape to be floating for complex compose "
            "operation; got %s.",
            PrimitiveType_Name(lhs.element_type()));
      }
      TF_ASSIGN_OR_RETURN(const Shape& shape,
                          InferElementwiseBinaryOpShape(opcode, lhs, rhs,
                                                        broadcast_dimensions));
      if (lhs.element_type() == F32 && rhs.element_type() == F32) {
        return ShapeUtil::ChangeElementType(shape, C64);
      } else if (lhs.element_type() == F64 && rhs.element_type() == F64) {
        return ShapeUtil::ChangeElementType(shape, C128);
      } else {
        return Unimplemented("Complex component type is not implemented.");
      }
    }
    // Logical/bitwise ops require PRED or an integral element type.
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
      if (lhs.element_type() != PRED &&
          !primitive_util::IsIntegralType(lhs.element_type())) {
        return InvalidArgument(
            "Expected pred or integral type in argument to and/or operation; "
            "got %s.",
            PrimitiveType_Name(lhs.element_type()));
      }
      return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
                                           broadcast_dimensions);
    // Comparisons produce a PRED-typed result of the broadcast shape.
    case HloOpcode::kCompare: {
      TF_ASSIGN_OR_RETURN(const Shape& shape,
                          InferElementwiseBinaryOpShape(opcode, lhs, rhs,
                                                        broadcast_dimensions));
      return ShapeUtil::ChangeElementType(shape, PRED);
    }
    default:
      return Unimplemented(
          "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.",
          HloOpcodeString(opcode), lhs.ShortDebugString(),
          rhs.ShortDebugString());
  }
}
1041
InferTernaryOpShape(HloOpcode opcode,const HloInstruction * lhs,const HloInstruction * rhs,const HloInstruction * ehs)1042 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1043 HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs,
1044 const HloInstruction* ehs) {
1045 return InferTernaryOpShape(opcode, lhs->shape(), rhs->shape(), ehs->shape());
1046 }
1047
InferTernaryOpShape(HloOpcode opcode,const Shape & lhs,const Shape & rhs,const Shape & ehs)1048 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1049 HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) {
1050 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1051 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1052 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs));
1053 switch (opcode) {
1054 case HloOpcode::kClamp:
1055 return InferClampShape(lhs, rhs, ehs);
1056 case HloOpcode::kSelect:
1057 return InferSelectShape(lhs, rhs, ehs);
1058 case HloOpcode::kTupleSelect:
1059 return InferTupleSelectShape(lhs, rhs, ehs);
1060 default:
1061 return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1062 }
1063 }
1064
InferVariadicOpShape(HloOpcode opcode,absl::Span<const HloInstruction * const> operands)1065 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1066 HloOpcode opcode, absl::Span<const HloInstruction* const> operands) {
1067 std::vector<const Shape*> operand_shapes;
1068 operand_shapes.reserve(operands.size());
1069 for (const HloInstruction* operand : operands) {
1070 operand_shapes.push_back(&operand->shape());
1071 }
1072 return InferVariadicOpShape(opcode, operand_shapes);
1073 }
1074
InferVariadicOpShape(HloOpcode opcode,absl::Span<const Shape * const> operand_shapes)1075 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1076 HloOpcode opcode, absl::Span<const Shape* const> operand_shapes) {
1077 for (const Shape* shape : operand_shapes) {
1078 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
1079 }
1080 switch (opcode) {
1081 case HloOpcode::kTuple: {
1082 Shape result = ShapeUtil::MakeTupleShape({});
1083 result.mutable_tuple_shapes()->reserve(operand_shapes.size());
1084 for (const Shape* shape : operand_shapes) {
1085 ShapeUtil::AppendShapeToTuple(*shape, &result);
1086 }
1087 return result;
1088 }
1089 case HloOpcode::kSort: {
1090 if (operand_shapes.size() == 1) {
1091 return *operand_shapes[0];
1092 } else {
1093 for (int64 operand = 1; operand < operand_shapes.size(); ++operand) {
1094 if (!ShapeUtil::SameDimensions(*operand_shapes[0],
1095 *operand_shapes[operand])) {
1096 return InvalidArgument(
1097 "Sort keys and values dimensions must match. "
1098 "Keys shape is: %s\n, Values shape (operand index %lld) is: %s",
1099 ShapeUtil::HumanString(*operand_shapes[0]), operand,
1100 ShapeUtil::HumanString(*operand_shapes[operand]));
1101 }
1102 }
1103 std::vector<Shape> operand_shape_values;
1104 for (const Shape* operand_shape : operand_shapes) {
1105 operand_shape_values.push_back(*operand_shape);
1106 }
1107 return ShapeUtil::MakeTupleShape(operand_shape_values);
1108 }
1109 return InvalidArgument("Unexpected number of operands for sort");
1110 }
1111 default:
1112 return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1113 }
1114 }
1115
// Infers the result shape of a Map operation: to_apply (a scalar computation)
// is applied elementwise across all dimensions of the argument arrays. The
// result has to_apply's result element type and the (common) argument
// dimensions.
/* static */ StatusOr<Shape> ShapeInference::InferMapShape(
    absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
    absl::Span<const int64> dimensions) {
  if (arg_shapes.empty()) {
    return InvalidArgument("Map expects at least one argument.");
  }

  // All arguments must have the same shape.
  const Shape* arg_shape = arg_shapes[0];
  for (size_t i = 1; i < arg_shapes.size(); ++i) {
    TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map"));

    if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
      continue;
    }
    // Scalars of the same element type are allowed to mix with arrays; keep
    // tracking the first non-scalar shape seen as the reference shape.
    if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
                                                      *arg_shape)) {
      if (ShapeUtil::IsScalar(*arg_shapes[i])) {
        continue;
      }
      if (ShapeUtil::IsScalar(*arg_shape)) {
        arg_shape = arg_shapes[i];
        continue;
      }
    }

    // Incompatible: build a readable list of all operand shapes for the error.
    std::vector<string> pieces;
    for (const Shape* shape : arg_shapes) {
      pieces.push_back(ShapeUtil::HumanString(*shape));
    }
    return InvalidArgument(
        "Map operation requires all operands to have the same shape; got: "
        "%s.",
        StrJoin(pieces, ", "));
  }

  // Check that dimensions.size == arg_shape.dimensions_size() (we currently
  // only support mapping across all dimensions: i.e. scalar map functions).
  if (dimensions.size() != arg_shape->dimensions_size()) {
    return InvalidArgument(
        "Map applied to a subset of dimensions currently not supported: "
        "arg_dimension_size: %d, requested_map_dimensions_size: %u.",
        arg_shape->dimensions_size(), dimensions.size());
  }

  // Check that requested map dimensions numbers are monotonically increasing.
  for (int i = 0; i < dimensions.size(); ++i) {
    if (dimensions[i] != i) {
      return InvalidArgument(
          "Map requires monotonically increasing dimension numbers; got: %s.",
          StrJoin(dimensions, ", "));
    }
  }

  // The applied function's arity equals the number of arguments.
  if (arg_shapes.size() != to_apply.parameters_size()) {
    return InvalidArgument(
        "Map applied function arity must match number of arguments; got: "
        "arity: %d, arguments: %u.",
        to_apply.parameters_size(), arg_shapes.size());
  }

  // The parameters should all be scalars, and the output too.
  const Shape& output_shape = to_apply.result();
  if (!ShapeUtil::IsScalar(output_shape)) {
    return InvalidArgument(
        "Mapped computation's result has to be a scalar; got: %s.",
        ShapeUtil::HumanString(output_shape));
  }

  for (int i = 0; i < to_apply.parameters_size(); ++i) {
    const Shape& parameter_shape = to_apply.parameters(i);

    if (!ShapeUtil::IsScalar(parameter_shape)) {
      return InvalidArgument(
          "Mapped computation's parameter has to be a scalar; "
          "got parameter %d shape: %s.",
          i, ShapeUtil::HumanString(parameter_shape));
    }

    if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
                                                       *arg_shape)) {
      return InvalidArgument(
          "Mapped computation's parameter type has to match argument element "
          "type; got parameter %d shape: %s, argument shape: %s.",
          i, ShapeUtil::HumanString(parameter_shape),
          ShapeUtil::HumanString(*arg_shape));
    }
  }

  // Result: to_apply's scalar result type expanded to the argument dimensions.
  return ShapeUtil::MakeShape(output_shape.element_type(),
                              AsInt64Slice(arg_shape->dimensions()));
}
1209
InferBatchNormTrainingShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & offset_shape,int64 feature_index)1210 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormTrainingShape(
1211 const Shape& operand_shape, const Shape& scale_shape,
1212 const Shape& offset_shape, int64 feature_index) {
1213 TF_RETURN_IF_ERROR(
1214 ExpectArray(operand_shape, "operand of batch norm training"));
1215 TF_RETURN_IF_ERROR(
1216 ExpectArray(offset_shape, "offset input of batch norm training"));
1217 TF_RETURN_IF_ERROR(
1218 ExpectArray(scale_shape, "scale input of batch norm training"));
1219
1220 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
1221 Status::OK());
1222 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) ==
1223 Status::OK());
1224 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
1225 Status::OK());
1226
1227 if (feature_index >= operand_shape.rank()) {
1228 return InvalidArgument(
1229 "Expected feature_index of batch-norm-training to be "
1230 "smaller than the rank of operand_shape; "
1231 "got feature_index %d, and rank %d.",
1232 feature_index, operand_shape.rank());
1233 }
1234
1235 if (feature_index < 0) {
1236 return InvalidArgument(
1237 "Expected feature_index of batch-norm-training to "
1238 "be a non-negative number, got %d.",
1239 feature_index);
1240 }
1241
1242 if (operand_shape.rank() < 1) {
1243 return InvalidArgument(
1244 "Expected the rank of operand to "
1245 "batch-norm-training to be at least 1; got %d.",
1246 operand_shape.rank());
1247 }
1248
1249 if (offset_shape.rank() != 1) {
1250 return InvalidArgument(
1251 "Offset input of batch-norm-training must have"
1252 " rank 1, but has rank %d.",
1253 offset_shape.rank());
1254 }
1255
1256 if (scale_shape.rank() != 1) {
1257 return InvalidArgument(
1258 "Scale input of batch-norm-training must have"
1259 " rank 1, but has rank %d.",
1260 scale_shape.rank());
1261 }
1262
1263 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1264 return InvalidArgument(
1265 "The operand to batch-norm-training must have a floating point "
1266 "element type, but the shape is %s.",
1267 PrimitiveType_Name(operand_shape.element_type()));
1268 }
1269
1270 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1271 operand_shape)) {
1272 return InvalidArgument(
1273 "The inputs should have the same element type for batch-norm-training, "
1274 "but the shape of offset factor is %s "
1275 "and the shape of operand is %s.",
1276 PrimitiveType_Name(offset_shape.element_type()),
1277 PrimitiveType_Name(operand_shape.element_type()));
1278 }
1279
1280 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1281 operand_shape)) {
1282 return InvalidArgument(
1283 "The inputs should have the same element type for batch-norm-training, "
1284 "but the shape of scale factor is %s "
1285 "and the shape of operand is %s.",
1286 PrimitiveType_Name(scale_shape.element_type()),
1287 PrimitiveType_Name(operand_shape.element_type()));
1288 }
1289
1290 const int64 feature_count = operand_shape.dimensions(feature_index);
1291 Shape output_shape_for_mean_and_var =
1292 ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1293
1294 if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1295 return InvalidArgument(
1296 "The size of offset factor should be the same as feature count,"
1297 "but the size of offset factor is %d "
1298 "and the feature count is %d.",
1299 ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1300 }
1301
1302 if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1303 return InvalidArgument(
1304 "The size of scale factor should be the same as feature count,"
1305 "but the size of scale factor is %d "
1306 "and the feature count is %d.",
1307 ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1308 }
1309
1310 return ShapeUtil::MakeTupleShape({operand_shape,
1311 output_shape_for_mean_and_var,
1312 output_shape_for_mean_and_var});
1313 }
1314
InferBatchNormInferenceShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & offset_shape,const Shape & mean_shape,const Shape & variance_shape,int64 feature_index)1315 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
1316 const Shape& operand_shape, const Shape& scale_shape,
1317 const Shape& offset_shape, const Shape& mean_shape,
1318 const Shape& variance_shape, int64 feature_index) {
1319 TF_RETURN_IF_ERROR(
1320 ExpectArray(operand_shape, "operand of batch norm inference"));
1321 TF_RETURN_IF_ERROR(
1322 ExpectArray(offset_shape, "offset input of batch norm inference"));
1323 TF_RETURN_IF_ERROR(
1324 ExpectArray(scale_shape, "scale input of batch norm inference"));
1325
1326 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1327 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape));
1328 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1329 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1330 TF_RETURN_IF_ERROR(
1331 ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape));
1332
1333 if (feature_index >= operand_shape.rank()) {
1334 return InvalidArgument(
1335 "Expected feature_index of batch-norm-inference to be "
1336 "smaller than the rank of operand_shape; "
1337 "got feature_index %d, and rank %d.",
1338 feature_index, operand_shape.rank());
1339 }
1340
1341 if (feature_index < 0) {
1342 return InvalidArgument(
1343 "Expected feature_index of batch-norm-inference to "
1344 "be a non-negative number, got %d.",
1345 feature_index);
1346 }
1347
1348 if (operand_shape.rank() < 1) {
1349 return InvalidArgument(
1350 "Expected the rank of operand to "
1351 "batch-norm-inference to be at least 1; got %d.",
1352 operand_shape.rank());
1353 }
1354
1355 if (offset_shape.rank() != 1) {
1356 return InvalidArgument(
1357 "Offset input of batch-norm-inference must have"
1358 " rank 1, but has rank %d.",
1359 offset_shape.rank());
1360 }
1361
1362 if (scale_shape.rank() != 1) {
1363 return InvalidArgument(
1364 "Scale input of batch-norm-inference must have"
1365 " rank 1, but has rank %d.",
1366 scale_shape.rank());
1367 }
1368
1369 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1370 return InvalidArgument(
1371 "The operand to batch-norm-inference must have a floating point "
1372 "element type, but the shape is %s.",
1373 PrimitiveType_Name(operand_shape.element_type()));
1374 }
1375
1376 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1377 operand_shape)) {
1378 return InvalidArgument(
1379 "The inputs should have the same element type for "
1380 "batch-norm-inference, "
1381 "but the shape of offset factor is %s "
1382 "and the shape of operand is %s.",
1383 PrimitiveType_Name(offset_shape.element_type()),
1384 PrimitiveType_Name(operand_shape.element_type()));
1385 }
1386
1387 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1388 operand_shape)) {
1389 return InvalidArgument(
1390 "The inputs should have the same element type for "
1391 "batch-norm-inference, "
1392 "but the shape of scale factor is %s "
1393 "and the shape of operand is %s.",
1394 PrimitiveType_Name(scale_shape.element_type()),
1395 PrimitiveType_Name(operand_shape.element_type()));
1396 }
1397
1398 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1399 operand_shape)) {
1400 return InvalidArgument(
1401 "The inputs should have the same element type for "
1402 "batch-norm-inference, "
1403 "but the shape of mean is %s "
1404 "and the shape of operand is %s.",
1405 PrimitiveType_Name(mean_shape.element_type()),
1406 PrimitiveType_Name(operand_shape.element_type()));
1407 }
1408
1409 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
1410 operand_shape)) {
1411 return InvalidArgument(
1412 "The inputs should have the same element type for "
1413 "batch-norm-inference, "
1414 "but the shape of variance is %s "
1415 "and the shape of operand is %s.",
1416 PrimitiveType_Name(mean_shape.element_type()),
1417 PrimitiveType_Name(variance_shape.element_type()));
1418 }
1419
1420 const int64 feature_count = operand_shape.dimensions(feature_index);
1421 Shape output_shape_for_mean_and_var =
1422 ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1423
1424 if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1425 return InvalidArgument(
1426 "The size of offset factor should be the same as feature count,"
1427 "but the size of offset factor is %d "
1428 "and the feature count is %d.",
1429 ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1430 }
1431
1432 if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1433 return InvalidArgument(
1434 "The size of scale factor should be the same as feature count,"
1435 "but the size of scale factor is %d "
1436 "and the feature count is %d.",
1437 ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1438 }
1439
1440 if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1441 return InvalidArgument(
1442 "The size of mean should be the same as feature count,"
1443 "but the size of mean is %d "
1444 "and the feature count is %d.",
1445 ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1446 }
1447
1448 if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
1449 return InvalidArgument(
1450 "The size of variance should be the same as feature count,"
1451 "but the size of variance is %d "
1452 "and the feature count is %d.",
1453 ShapeUtil::GetDimension(variance_shape, 0), feature_count);
1454 }
1455
1456 return operand_shape;
1457 }
1458
InferBatchNormGradShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & mean_shape,const Shape & var_shape,const Shape & output_grad_shape,int64 feature_index)1459 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
1460 const Shape& operand_shape, const Shape& scale_shape,
1461 const Shape& mean_shape, const Shape& var_shape,
1462 const Shape& output_grad_shape, int64 feature_index) {
1463 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad"));
1464 TF_RETURN_IF_ERROR(
1465 ExpectArray(scale_shape, "scale input of batch norm grad"));
1466 TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad"));
1467 TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad"));
1468 TF_RETURN_IF_ERROR(
1469 ExpectArray(output_grad_shape, "output_grad input of batch norm grad"));
1470
1471 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1472 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1473 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1474 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(var_shape));
1475 TF_RETURN_IF_ERROR(
1476 ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape));
1477
1478 if (feature_index >= operand_shape.rank()) {
1479 return InvalidArgument(
1480 "Expected feature_index of batch-norm-grad to be "
1481 "smaller than the rank of operand_shape; "
1482 "got feature_index %d, and rank %d.",
1483 feature_index, operand_shape.rank());
1484 }
1485
1486 if (operand_shape.rank() != output_grad_shape.rank()) {
1487 return InvalidArgument(
1488 "Expected operand_shape of batch-norm-grad to have the same rank as"
1489 " output_grad_shape; got rank(oprand_shape) %d, and"
1490 " rank(output_grad_shape) %d.",
1491 operand_shape.rank(), output_grad_shape.rank());
1492 }
1493
1494 if (mean_shape.rank() != 1) {
1495 return InvalidArgument(
1496 "Mean input of batch-norm-grad must have"
1497 " rank 1, but has rank %d.",
1498 mean_shape.rank());
1499 }
1500
1501 if (scale_shape.rank() != 1) {
1502 return InvalidArgument(
1503 "Scale input of batch-norm-grad must have"
1504 " rank 1, but has rank %d.",
1505 scale_shape.rank());
1506 }
1507
1508 if (var_shape.rank() != 1) {
1509 return InvalidArgument(
1510 "Var input of batch-norm-grad must have"
1511 " rank 1, but has rank %d.",
1512 var_shape.rank());
1513 }
1514
1515 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1516 return InvalidArgument(
1517 "The operand to batch-norm-grad must have a floating point "
1518 "element type, but the shape is %s.",
1519 PrimitiveType_Name(operand_shape.element_type()));
1520 }
1521
1522 if (!ShapeUtil::ElementIsFloating(output_grad_shape)) {
1523 return InvalidArgument(
1524 "The output_grad to batch-norm-grad must have a floating point "
1525 "element type, but the shape is %s.",
1526 PrimitiveType_Name(output_grad_shape.element_type()));
1527 }
1528
1529 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
1530 operand_shape)) {
1531 return InvalidArgument(
1532 "The inputs should have the same element type for batch-norm-grad, "
1533 "but the element type of output_grad is %s "
1534 "and the element type of operand is %s.",
1535 PrimitiveType_Name(output_grad_shape.element_type()),
1536 PrimitiveType_Name(operand_shape.element_type()));
1537 }
1538
1539 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1540 operand_shape)) {
1541 return InvalidArgument(
1542 "The inputs should have the same element type for batch-norm-grad, "
1543 "but the element type of scale factor is %s "
1544 "and the element type of operand is %s.",
1545 PrimitiveType_Name(scale_shape.element_type()),
1546 PrimitiveType_Name(operand_shape.element_type()));
1547 }
1548
1549 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1550 operand_shape)) {
1551 return InvalidArgument(
1552 "The inputs should have the same element type for batch-norm-grad, "
1553 "but the element type of mean is %s "
1554 "and the element type of operand is %s.",
1555 PrimitiveType_Name(mean_shape.element_type()),
1556 PrimitiveType_Name(operand_shape.element_type()));
1557 }
1558
1559 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
1560 operand_shape)) {
1561 return InvalidArgument(
1562 "The inputs should have the same element type for batch-norm-grad, "
1563 "but the element type of mean is %s "
1564 "and the element type of operand is %s.",
1565 PrimitiveType_Name(mean_shape.element_type()),
1566 PrimitiveType_Name(operand_shape.element_type()));
1567 }
1568
1569 const int64 feature_count = operand_shape.dimensions(feature_index);
1570
1571 Shape feature_shape =
1572 ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1573
1574 if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1575 return InvalidArgument(
1576 "The size of mean should be the same as feature count,"
1577 "but the size of offset factor is %d "
1578 "and the feature count is %d.",
1579 ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1580 }
1581
1582 if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1583 return InvalidArgument(
1584 "The size of scale factor should be the same as feature count,"
1585 "but the size of scale factor is %d "
1586 "and the feature count is %d.",
1587 ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1588 }
1589
1590 if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) {
1591 return InvalidArgument(
1592 "The size of variance should be the same as feature count,"
1593 "but the size of variance is %d "
1594 "and the feature count is %d.",
1595 ShapeUtil::GetDimension(var_shape, 0), feature_count);
1596 }
1597
1598 // Verify operand_shape and output_grad_shape have same bounds.
1599 for (int64 i = 0; i < operand_shape.rank(); ++i) {
1600 if (ShapeUtil::GetDimension(operand_shape, i) !=
1601 ShapeUtil::GetDimension(output_grad_shape, i)) {
1602 return InvalidArgument(
1603 "The bounds of operand shape should be the same as output_grad's,"
1604 "but the bound of operand_shape at dimension %d is %d "
1605 "and the bound of output_grad_shape is %d.",
1606 i, ShapeUtil::GetDimension(operand_shape, i),
1607 ShapeUtil::GetDimension(output_grad_shape, i));
1608 }
1609 }
1610
1611 return ShapeUtil::MakeTupleShape(
1612 {operand_shape, feature_shape, feature_shape});
1613 }
1614
/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
    const Shape& lhs, const Shape& rhs, int64 feature_group_count,
    int64 batch_group_count, const Window& window,
    const ConvolutionDimensionNumbers& dnums,
    absl::optional<PrimitiveType> preferred_element_type) {
  // Infers the result shape of a (possibly grouped) convolution from the
  // operand shapes, the window, and the dimension-number mapping, validating
  // all consistency requirements along the way.
  TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
  TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));

  // Group counts must be positive, and batch grouping and feature grouping
  // are mutually exclusive.
  if (feature_group_count <= 0) {
    return InvalidArgument(
        "feature_group_count must be a positive number, got %d",
        feature_group_count);
  }

  if (batch_group_count <= 0) {
    return InvalidArgument(
        "batch_group_count must be a positive number, got %d",
        batch_group_count);
  }

  if (batch_group_count > 1 && feature_group_count > 1) {
    return InvalidArgument(
        "both batch_group_count %d and feature_group_count %d cannot be "
        "greater than 1",
        batch_group_count, feature_group_count);
  }

  // Input, kernel, and output must agree on the number of spatial dimensions.
  if (dnums.input_spatial_dimensions_size() !=
      dnums.kernel_spatial_dimensions_size()) {
    return InvalidArgument(
        "Both arguments to convolution must have same number of dimensions.\n"
        "Numbers: %s",
        dnums.DebugString());
  }

  if (dnums.input_spatial_dimensions_size() !=
      dnums.output_spatial_dimensions_size()) {
    return InvalidArgument(
        "Both input and output of convolution must have same number of "
        "dimensions.\nNumbers: %s",
        dnums.DebugString());
  }

  const int num_spatial_dims = dnums.input_spatial_dimensions_size();
  if (window.dimensions_size() != num_spatial_dims) {
    return InvalidArgument(
        "Window must have same number of dimensions as dimension numbers.\n"
        "Window: %s\nDimension numbers: %s.",
        window.DebugString(), dnums.DebugString());
  }

  // Each operand carries the spatial dimensions plus one batch and one
  // feature dimension.
  const int num_dims = num_spatial_dims + 2;
  if (lhs.rank() != num_dims) {
    return InvalidArgument(
        "The LHS argument to a convolution should have rank %d; lhs: %s.",
        num_dims, ShapeUtil::HumanString(lhs));
  }
  if (rhs.rank() != num_dims) {
    return InvalidArgument(
        "The RHS argument to a convolution should have rank %d; rhs: %s.",
        num_dims, ShapeUtil::HumanString(rhs));
  }
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));

  // Verifies that the input and window dimensions are a permutation of
  // the dimension numbers.
  std::vector<int64> input_dnums(num_dims);
  input_dnums[0] = dnums.input_batch_dimension();
  input_dnums[1] = dnums.input_feature_dimension();
  std::copy(dnums.input_spatial_dimensions().begin(),
            dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
  absl::c_sort(input_dnums);

  std::vector<int64> window_dnums(num_dims);
  window_dnums[0] = dnums.kernel_input_feature_dimension();
  window_dnums[1] = dnums.kernel_output_feature_dimension();
  std::copy(dnums.kernel_spatial_dimensions().begin(),
            dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
  absl::c_sort(window_dnums);

  std::vector<int64> output_dnums(num_dims);
  output_dnums[0] = dnums.output_batch_dimension();
  output_dnums[1] = dnums.output_feature_dimension();
  std::copy(dnums.output_spatial_dimensions().begin(),
            dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
  absl::c_sort(output_dnums);

  // After sorting, a valid dimension-number list is exactly [0, num_dims).
  std::vector<int64> expected_dnums(num_dims);
  std::iota(expected_dnums.begin(), expected_dnums.end(), 0);

  const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
  if (!absl::c_all_of(input_dnums, in_range) ||
      !absl::c_all_of(window_dnums, in_range) ||
      !absl::c_all_of(output_dnums, in_range)) {
    return InvalidArgument(
        "A dimension number is out of range in convolution: %s.",
        dnums.DebugString());
  }

  if (input_dnums != expected_dnums) {
    return InvalidArgument(
        "Input dimensions of convolution must contain each dimension exactly "
        "once: %s.",
        dnums.DebugString());
  }
  if (window_dnums != expected_dnums) {
    return InvalidArgument(
        "Window dimensions of convolution must contain each dimension exactly "
        "once: %s.",
        dnums.DebugString());
  }
  if (output_dnums != expected_dnums) {
    return InvalidArgument(
        "Output dimensions of convolution must contain each dimension exactly "
        "once: %s.",
        dnums.DebugString());
  }

  // Gather the logical (spatial..., feature, batch) sizes of both operands
  // through the dimension-number indirection.
  std::vector<int64> input_spatial_dims(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    input_spatial_dims[i] = lhs.dimensions(dnums.input_spatial_dimensions(i));
  }
  const int64 input_features = lhs.dimensions(dnums.input_feature_dimension());
  const int64 input_batch = lhs.dimensions(dnums.input_batch_dimension());

  std::vector<int64> kernel_spatial_dims(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    kernel_spatial_dims[i] = rhs.dimensions(dnums.kernel_spatial_dimensions(i));
  }
  const int64 kernel_input_features =
      rhs.dimensions(dnums.kernel_input_feature_dimension());
  const int64 kernel_output_features =
      rhs.dimensions(dnums.kernel_output_feature_dimension());

  if (kernel_output_features % batch_group_count != 0) {
    return InvalidArgument(
        "Expected output feature dimension size (value %d) to be a multiple of "
        "batch group count %d; got <conv>(%s, %s)\n"
        "Dimension numbers: {%s}.",
        kernel_output_features, batch_group_count, ShapeUtil::HumanString(lhs),
        ShapeUtil::HumanString(rhs), dnums.DebugString());
  }

  // For grouped convolutions, the LHS feature dimension splits evenly into
  // feature_group_count groups of size kernel_input_features.
  if (input_features % feature_group_count != 0 ||
      input_features / feature_group_count != kernel_input_features) {
    return InvalidArgument(
        "Expected LHS feature dimension (value %d) to be a multiple of "
        "feature_group_count (value %d), and LHS feature dimension / "
        "feature_group_count = RHS feature dimension (value %d); "
        "got <conv>(%s, %s)\n"
        "Dimension numbers: {%s}.",
        input_features, feature_group_count, kernel_input_features,
        ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
        dnums.DebugString());
  }

  if (kernel_output_features % feature_group_count > 0) {
    // A depthwise/grouped filter has the shape
    // [space0, .. spaceN, GROUP_SIZE, NUM_OUTPUT_FEATURES]. When
    // [space0, .. spaceN, GROUP_SIZE] is convolved with the input, a shape
    // [space0, .. spaceN, feature_group_count] is formed. Therefore, the output
    // feature count (which is equal to kernel output features) has to be a
    // multiple of feature_group_count.
    return InvalidArgument(
        "Expected output feature dimension (value %d) to be divisible by "
        "feature_group_count (value %d); "
        "got <conv>(%s, %s)\n"
        "Dimension numbers: {%s}.",
        kernel_output_features, feature_group_count,
        ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
        dnums.DebugString());
  }

  if (input_batch % batch_group_count != 0) {
    return InvalidArgument(
        "Expected input batch dimension (value %d) to be divisible by "
        "batch_group_count (value %d); "
        "got <conv>(%s, %s)\n"
        "Dimension numbers: {%s}.",
        input_batch, batch_group_count, ShapeUtil::HumanString(lhs),
        ShapeUtil::HumanString(rhs), dnums.DebugString());
  }

  // The window's per-dimension sizes must match the kernel's spatial sizes.
  std::vector<int64> window_dims(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    window_dims[i] = window.dimensions(i).size();
  }
  if (kernel_spatial_dims != window_dims) {
    return InvalidArgument(
        "Window dimensions do not match RHS shape:\n\t"
        "RHS shape: %s\n\t"
        "Window: {%s}\n\t"
        "Dimension numbers: {%s}.",
        ShapeUtil::HumanString(rhs), window.ShortDebugString(),
        dnums.ShortDebugString());
  }

  // The output spatial sizes come from applying the window to the input's
  // spatial dimensions.
  Shape base_shape =
      ShapeUtil::MakeShape(lhs.element_type(), input_spatial_dims);
  TF_ASSIGN_OR_RETURN(
      Shape window_output_shape,
      InferWindowOutputShape(base_shape, window, lhs.element_type(),
                             /*allow_negative_padding=*/true));

  // Assemble the full output dimension vector: batch (shrunk by batch
  // grouping), feature (kernel output features), and the windowed spatials.
  std::vector<int64> dimensions(num_dims);
  dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count;
  dimensions[dnums.output_feature_dimension()] = kernel_output_features;

  for (int i = 0; i < num_spatial_dims; ++i) {
    dimensions[dnums.output_spatial_dimensions(i)] =
        window_output_shape.dimensions(i);
  }
  // Propagate dynamic dimensions of the operands to the matching output
  // dimensions.
  std::vector<bool> is_dynamic(num_dims);
  for (int i = 0; i < num_dims; i++) {
    if (lhs.is_dynamic_dimension(i)) {
      if (i == dnums.input_batch_dimension()) {
        is_dynamic[dnums.output_batch_dimension()] = true;
      } else if (i == dnums.input_feature_dimension()) {
        // Input feature dimension is a contracting dimension, which does not
        // affect the output dimension size. So we need to do nothing.
      } else {
        for (int64 j = 0; j < dnums.output_spatial_dimensions_size(); ++j) {
          if (i == dnums.input_spatial_dimensions(j)) {
            // i is a spatial dimension, find corresponding output spatial
            // dimension.
            is_dynamic[dnums.output_spatial_dimensions(j)] = true;
          }
        }
      }
    }
    if (rhs.is_dynamic_dimension(i)) {
      if (i == dnums.kernel_input_feature_dimension()) {
        // Kernel feature dimension does not affect the output dimension size.
        // So we need to do nothing.
      } else if (i == dnums.kernel_output_feature_dimension()) {
        return InvalidArgument(
            "Dynamic output feature dim on convolution kernel is not "
            "supported: rhs shape is %s ",
            rhs.ToString());
      } else {
        for (int64 j = 0; j < dnums.kernel_spatial_dimensions_size(); ++j) {
          if (i == dnums.kernel_spatial_dimensions(j)) {
            // i is a spatial dimension, find corresponding output spatial
            // dimension.
            is_dynamic[dnums.output_spatial_dimensions(j)] = true;
          }
        }
      }
    }
  }
  // The result element type is the higher-precision operand type, optionally
  // upcast to the caller's preferred element type.
  TF_ASSIGN_OR_RETURN(
      PrimitiveType type,
      MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
                  preferred_element_type));
  return ShapeUtil::MakeShape(type, dimensions, is_dynamic);
}
1872
InferFftShape(const Shape & in,const FftType fft_type,const absl::Span<const int64> fft_length)1873 /* static */ StatusOr<Shape> ShapeInference::InferFftShape(
1874 const Shape& in, const FftType fft_type,
1875 const absl::Span<const int64> fft_length) {
1876 const int64 fft_rank = fft_length.size();
1877 if (fft_rank < 1 || fft_rank > 3) {
1878 return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank);
1879 }
1880 #define RET_CHECK_RANK(x) \
1881 if (x.dimensions_size() < fft_rank) { \
1882 return InvalidArgument( \
1883 "FFT of rank %d requires input of at least " \
1884 "same rank; got input of rank %d", \
1885 fft_rank, x.dimensions_size()); \
1886 }
1887 switch (fft_type) {
1888 case FFT:
1889 case IFFT:
1890 if (!primitive_util::IsComplexType(in.element_type())) {
1891 return InvalidArgument("%s requires complex input type, found %s.",
1892 FftType_Name(fft_type),
1893 PrimitiveType_Name(in.element_type()));
1894 }
1895 RET_CHECK_RANK(in);
1896 return in;
1897 case RFFT: {
1898 if (in.element_type() != F32 && in.element_type() != F64) {
1899 return InvalidArgument("RFFT requires F32 or F64 input type, found %s.",
1900 PrimitiveType_Name(in.element_type()));
1901 }
1902 RET_CHECK_RANK(in);
1903 for (int i = 0; i < fft_rank; i++) {
1904 if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1905 fft_length[i]) {
1906 return InvalidArgument(
1907 "RFFT requires innermost dimensions match fft_length but "
1908 "dimension %d is %d and should be %d.",
1909 in.dimensions_size() - fft_rank + i,
1910 in.dimensions(in.dimensions_size() - fft_rank + i),
1911 fft_length[i]);
1912 }
1913 }
1914 Shape result = ShapeUtil::ChangeElementType(
1915 in, in.element_type() == F32 ? C64 : C128);
1916 // Preserve the size of zero-sized dimensions.
1917 if (fft_length[fft_rank - 1] != 0) {
1918 result.set_dimensions(result.dimensions_size() - 1,
1919 fft_length[fft_rank - 1] / 2 + 1);
1920 }
1921 return result;
1922 }
1923 case IRFFT: {
1924 if (!primitive_util::IsComplexType(in.element_type())) {
1925 return InvalidArgument("IRFFT requires complex input type, found %s.",
1926 PrimitiveType_Name(in.element_type()));
1927 }
1928 RET_CHECK_RANK(in);
1929 Shape result = ShapeUtil::ComplexComponentShape(in);
1930 for (int i = 0; i < fft_rank - 1; i++) {
1931 if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1932 fft_length[i]) {
1933 return InvalidArgument(
1934 "IRFFT requires all but one innermost dimensions match "
1935 "fft_length, but dimension %d is %d and should be %d.",
1936 in.dimensions_size() - fft_rank + i,
1937 in.dimensions(in.dimensions_size() - fft_rank + i),
1938 fft_length[i]);
1939 }
1940 }
1941 // The size of zero-sized dimensions is preserved.
1942 if ((in.dimensions(in.dimensions_size() - 1) != 0 ||
1943 fft_length[fft_rank - 1] != 0) &&
1944 in.dimensions(in.dimensions_size() - 1) !=
1945 fft_length[fft_rank - 1] / 2 + 1) {
1946 return InvalidArgument(
1947 "IRFFT requires innermost dimension matches fft_length/2+1, but "
1948 "dimension %d is %d and should be %d.",
1949 in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1),
1950 fft_length[fft_rank - 1] / 2 + 1);
1951 }
1952 result.set_dimensions(result.dimensions_size() - 1,
1953 fft_length[fft_rank - 1]);
1954 return result;
1955 }
1956 default:
1957 LOG(FATAL) << "Unexpected fft_type: " << fft_type;
1958 }
1959 #undef RET_CHECK_RANK
1960 }
1961
InferTriangularSolveShape(const Shape & a,const Shape & b,const TriangularSolveOptions & options)1962 /* static */ StatusOr<Shape> ShapeInference::InferTriangularSolveShape(
1963 const Shape& a, const Shape& b, const TriangularSolveOptions& options) {
1964 if ((!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) ||
1965 a.element_type() != b.element_type()) {
1966 return InvalidArgument(
1967 "Expected element types in shape to be floating or complex and "
1968 "identical for TriangularSolve; got %s and %s.",
1969 PrimitiveType_Name(a.element_type()),
1970 PrimitiveType_Name(b.element_type()));
1971 }
1972 if (a.rank() < 2) {
1973 return InvalidArgument(
1974 "The 'a' argument to TriangularSolve must have rank >= 2, got shape %s",
1975 a.ToString());
1976 }
1977 if (b.rank() != a.rank()) {
1978 return InvalidArgument(
1979 "Arguments to triangular solve must have equal rank; got %s and %s.",
1980 b.ToString(), a.ToString());
1981 }
1982 if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
1983 return InvalidArgument(
1984 "The two minor dimensions of 'a' must have equal size, got %s.",
1985 a.ToString());
1986 }
1987 if (a.dimensions(a.rank() - 1) !=
1988 b.dimensions(b.rank() - (options.left_side() ? 2 : 1))) {
1989 return InvalidArgument(
1990 "The shared dimension of 'a' and 'b' does not match, got shapes %s and "
1991 "%s",
1992 a.ToString(), b.ToString());
1993 }
1994 absl::Span<const int64> a_batch_dims(a.dimensions());
1995 absl::Span<const int64> b_batch_dims(b.dimensions());
1996 a_batch_dims.remove_suffix(2);
1997 b_batch_dims.remove_suffix(2);
1998 if (a_batch_dims != b_batch_dims) {
1999 return InvalidArgument(
2000 "The leading batch dimensions of the arguments to triangular solve "
2001 "must be equal; got %s and %s.",
2002 b.ToString(), a.ToString());
2003 }
2004 if (!TriangularSolveOptions_Transpose_IsValid(options.transpose_a()) ||
2005 options.transpose_a() == TriangularSolveOptions::TRANSPOSE_INVALID) {
2006 return InvalidArgument(
2007 "Invalid transpose option value for triangular solve (%d).\n",
2008 options.transpose_a());
2009 }
2010 return b;
2011 }
2012
InferCholeskyShape(const Shape & a)2013 /* static */ StatusOr<Shape> ShapeInference::InferCholeskyShape(
2014 const Shape& a) {
2015 if (!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) {
2016 return InvalidArgument(
2017 "Expected element type in shape to be floating or complex for "
2018 "Cholesky; got %s.",
2019 PrimitiveType_Name(a.element_type()));
2020 }
2021 if (a.rank() < 2) {
2022 return InvalidArgument(
2023 "The 'a' argument to Cholesky must have rank >= 2, got shape %s",
2024 a.ToString());
2025 }
2026 if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
2027 return InvalidArgument(
2028 "The two minor dimensions of 'a' must have equal size, got %s.",
2029 a.ToString());
2030 }
2031 return a;
2032 }
2033
InferAllGatherShape(const Shape & operand_shape,int64 all_gather_dimension,int64 shard_count)2034 /* static */ StatusOr<Shape> ShapeInference::InferAllGatherShape(
2035 const Shape& operand_shape, int64 all_gather_dimension, int64 shard_count) {
2036 TF_RET_CHECK(all_gather_dimension >= 0);
2037 TF_RET_CHECK(all_gather_dimension < operand_shape.rank());
2038 TF_RET_CHECK(shard_count > 0);
2039 auto shape = operand_shape;
2040 shape.set_dimensions(all_gather_dimension,
2041 shard_count * shape.dimensions(all_gather_dimension));
2042 return shape;
2043 }
2044
InferAllReduceShape(absl::Span<const Shape * const> operand_shapes)2045 /* static */ StatusOr<Shape> ShapeInference::InferAllReduceShape(
2046 absl::Span<const Shape* const> operand_shapes) {
2047 for (const Shape* operand_shape : operand_shapes) {
2048 TF_RETURN_IF_ERROR(
2049 ExpectArray(*operand_shape, "operand of cross replica sum"));
2050 }
2051 if (operand_shapes.size() == 1) {
2052 return *operand_shapes[0];
2053 }
2054 std::vector<Shape> operand_shape_values;
2055 for (const Shape* operand_shape : operand_shapes) {
2056 operand_shape_values.push_back(*operand_shape);
2057 }
2058 return ShapeUtil::MakeTupleShape(operand_shape_values);
2059 }
2060
InferAllToAllShape(const Shape & shape,int64 split_dimension,int64 concat_dimension,int64 split_count)2061 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllShape(
2062 const Shape& shape, int64 split_dimension, int64 concat_dimension,
2063 int64 split_count) {
2064 TF_RET_CHECK(split_count > 0);
2065 if (split_dimension >= shape.rank() || split_dimension < 0) {
2066 return InvalidArgument(
2067 "AllToAll split_dimension %d is out-of-bounds in shape %s.",
2068 split_dimension, ShapeUtil::HumanString(shape));
2069 }
2070 if (concat_dimension >= shape.rank() || concat_dimension < 0) {
2071 return InvalidArgument(
2072 "AllToAll concat_dimension %d is out-of-bounds in shape %s.",
2073 concat_dimension, ShapeUtil::HumanString(shape));
2074 }
2075 if (shape.dimensions(split_dimension) % split_count != 0) {
2076 return InvalidArgument(
2077 "AllToAll split dimension size %d must be dividable by split_count "
2078 "%d.",
2079 shape.dimensions(split_dimension), split_count);
2080 }
2081 std::vector<int64> new_dimensions(shape.dimensions().begin(),
2082 shape.dimensions().end());
2083 new_dimensions[split_dimension] /= split_count;
2084 new_dimensions[concat_dimension] *= split_count;
2085 return ShapeUtil::MakeShape(shape.element_type(), new_dimensions);
2086 }
2087
InferAllToAllTupleShape(absl::Span<const Shape * const> operand_shapes)2088 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
2089 absl::Span<const Shape* const> operand_shapes) {
2090 // An Alltoall HLO instruction receives N operands (with the same shape) and
2091 // returns a tuple that contains N array shapes.
2092 TF_RET_CHECK(!operand_shapes.empty());
2093 for (int i = 0; i < operand_shapes.size(); i++) {
2094 if (!Shape::Equal().IgnoreMemorySpaceInLayout()(*operand_shapes[0],
2095 *operand_shapes[i])) {
2096 return InvalidArgument(
2097 "HLO all-to-all has operands with different shapes: the 0th "
2098 "operand shape %s, but the %dth operand has shape %s.",
2099 ShapeUtil::HumanString(*operand_shapes[0]), i,
2100 ShapeUtil::HumanString(*operand_shapes[i]));
2101 }
2102 }
2103
2104 return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes);
2105 }
2106
/* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteShape(
    const Shape& shape) {
  // Collective-permute only moves data between replicas; it never changes
  // the shape, so the only requirement is that the operand is an array.
  TF_RET_CHECK(shape.IsArray());
  return shape;
}
2112
/* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
    absl::Span<const Shape* const> arg_shapes,
    absl::Span<const int64> dimensions_to_reduce,
    const ProgramShape& to_apply) {
  // arg_shapes holds the N operands being reduced followed by their N
  // corresponding init values, so the count must be nonzero and even.
  if (arg_shapes.empty()) {
    return InvalidArgument("Reduce must have at least 2 arguments, has 0");
  }
  if (arg_shapes.size() % 2) {
    return InvalidArgument(
        "Reduce must have an even number of arguments, has %lu",
        arg_shapes.size());
  }
  int64 num_reduced_args = arg_shapes.size() / 2;
  auto reduced_args = arg_shapes.subspan(0, num_reduced_args);
  // Check that all of the reduced tensors have the same dimensions. The element
  // types may be different.
  for (int64 i = 1; i < num_reduced_args; ++i) {
    if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) {
      return InvalidArgument(
          "All reduced tensors must have the same dimension. Tensor 0 has "
          "shape %s, Tensor %d has shape %s",
          ShapeUtil::HumanString(*reduced_args[0]), i,
          ShapeUtil::HumanString(*reduced_args[i]));
    }
  }
  // Check that the dimensions to reduce are in-bounds for the given shape.
  // We've already verified all reduced tensors have the same dimensions, so it
  // doesn't matter which one we choose.
  const Shape& arg = *reduced_args[0];
  for (int64 dimension : dimensions_to_reduce) {
    if (dimension >= arg.rank() || dimension < 0) {
      return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.",
                             dimension, ShapeUtil::HumanString(arg));
    }
  }

  // The second half of arg_shapes are the init values; the reducer
  // computation must be compatible with the operand element types and the
  // init-value shapes.
  auto init_values = arg_shapes.subspan(num_reduced_args, arg_shapes.size());
  std::vector<PrimitiveType> element_types;
  for (const Shape* arg : reduced_args) {
    element_types.push_back(arg->element_type());
  }
  TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types,
                                        num_reduced_args));

  // Reject duplicate reduction dimensions; insert().second is false when the
  // element was already present.
  absl::flat_hash_set<int64> dimensions_to_reduce_set;
  for (int64 dim_to_reduce : dimensions_to_reduce) {
    if (!dimensions_to_reduce_set.insert(dim_to_reduce).second) {
      return InvalidArgument("Duplicate reduction dimension: %d",
                             dim_to_reduce);
    }
  }

  // The output keeps, in order, exactly the dimensions that are not reduced,
  // preserving each kept dimension's dynamic-ness.
  std::vector<int64> new_dimensions;
  std::vector<bool> new_is_dynamic;
  for (int i = 0; i < arg.rank(); ++i) {
    if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) {
      new_dimensions.push_back(arg.dimensions(i));
      new_is_dynamic.push_back(arg.is_dynamic_dimension(i));
    }
  }

  // A scalar reducer result yields a single array; a tuple result yields a
  // tuple of arrays (variadic reduce), one per reducer output element type.
  if (ShapeUtil::IsScalar(to_apply.result())) {
    return ShapeUtil::MakeShape(to_apply.result().element_type(),
                                new_dimensions, new_is_dynamic);
  } else {
    std::vector<Shape> result_subshapes;
    for (const Shape& subshape : to_apply.result().tuple_shapes()) {
      result_subshapes.push_back(ShapeUtil::MakeShape(
          subshape.element_type(), new_dimensions, new_is_dynamic));
    }
    return ShapeUtil::MakeTupleShape(result_subshapes);
  }
}
2186
/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
    const Shape& operand_shape, const Shape& init_value_shape,
    const Window& window, const ProgramShape& to_apply_shape) {
  // Validate the reduction computation against the operand element type and
  // init value, then defer to the computation-free overload for the actual
  // window-shape inference.
  TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
                                        {operand_shape.element_type()},
                                        /*inputs=*/1));
  return InferReduceWindowShape(operand_shape, init_value_shape, window);
}
2195
InferReduceWindowShape(absl::Span<const Shape * const> operands,absl::Span<const Shape * const> init_values,const Window & window,const ProgramShape & to_apply_shape)2196 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2197 absl::Span<const Shape* const> operands,
2198 absl::Span<const Shape* const> init_values, const Window& window,
2199 const ProgramShape& to_apply_shape) {
2200 auto number_of_input = operands.size();
2201 // Check that all of the reduced tensors have the same dimensions. The element
2202 // types may be different.
2203 for (int64 i = 1; i < number_of_input; ++i) {
2204 if (!ShapeUtil::SameDimensions(*operands[0], *operands[i])) {
2205 return InvalidArgument(
2206 "All reduced tensors must have the same dimension. Tensor 0 has "
2207 "shape %s, Tensor %d has shape %s",
2208 ShapeUtil::HumanString(*operands[0]), i,
2209 ShapeUtil::HumanString(*operands[i]));
2210 }
2211 }
2212 std::vector<PrimitiveType> operand_element_type_vec;
2213 for (const Shape* s : operands) {
2214 operand_element_type_vec.push_back(s->element_type());
2215 }
2216 TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_values,
2217 operand_element_type_vec,
2218 /*inputs=*/number_of_input));
2219 std::vector<Shape> output_shape_vec;
2220 for (int i = 0; i < operands.size(); ++i) {
2221 TF_ASSIGN_OR_RETURN(
2222 auto cur_output_shape,
2223 InferReduceWindowShape(*operands[i], *init_values[i], window));
2224 output_shape_vec.push_back(cur_output_shape);
2225 }
2226 if (ShapeUtil::IsScalar(to_apply_shape.result())) {
2227 CHECK_EQ(output_shape_vec.size(), 1);
2228 return output_shape_vec[0];
2229 } else {
2230 return ShapeUtil::MakeTupleShape(output_shape_vec);
2231 }
2232 }
2233
/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
    const Shape& operand_shape, const Shape& init_value_shape,
    const Window& window) {
  // The result element type comes from the init value; reduce-window does not
  // permit negative window padding.
  TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window"));
  return InferWindowOutputShape(operand_shape, window,
                                init_value_shape.element_type(),
                                /*allow_negative_padding=*/false);
}
2242
InferSelectAndScatterShape(const Shape & operand_shape,const ProgramShape & select_shape,const Window & window,const Shape & source_shape,const Shape & init_value_shape,const ProgramShape & scatter_shape)2243 /* static */ StatusOr<Shape> ShapeInference::InferSelectAndScatterShape(
2244 const Shape& operand_shape, const ProgramShape& select_shape,
2245 const Window& window, const Shape& source_shape,
2246 const Shape& init_value_shape, const ProgramShape& scatter_shape) {
2247 TF_RETURN_IF_ERROR(
2248 ExpectArray(operand_shape, "operand of select-and-scatter"));
2249
2250 // Check if the select function has a proper shape of (T,T) -> PRED.
2251 if (select_shape.parameters_size() != 2) {
2252 return InvalidArgument(
2253 "Select function must take 2 parameters, but "
2254 "takes %d parameter(s).",
2255 select_shape.parameters_size());
2256 }
2257 const Shape& select_result_shape = select_shape.result();
2258 if (!ShapeUtil::Compatible(select_result_shape,
2259 ShapeUtil::MakeShape(PRED, {}))) {
2260 return InvalidArgument("Select function must have rank-0 PRED result.");
2261 }
2262 const Shape& operand_element_shape =
2263 ShapeUtil::MakeShape(operand_shape.element_type(), {});
2264 if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2265 select_shape.parameters(0))) {
2266 return InvalidArgument(
2267 "Select function's first parameter shape currently must "
2268 "match the operand element shape, but got %s vs %s.",
2269 ShapeUtil::HumanString(select_shape.parameters(0)),
2270 ShapeUtil::HumanString(operand_element_shape));
2271 }
2272 if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2273 select_shape.parameters(1))) {
2274 return InvalidArgument(
2275 "Select function's second parameter shape currently must "
2276 "match the operand element shape, but got %s vs %s.",
2277 ShapeUtil::HumanString(select_shape.parameters(1)),
2278 ShapeUtil::HumanString(operand_element_shape));
2279 }
2280
2281 // Check if the scatter function has a proper shape as a reduction.
2282 TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape},
2283 {source_shape.element_type()},
2284 /*inputs=*/1));
2285
2286 // Check if the result shape of window operation matches the source shape.
2287 TF_ASSIGN_OR_RETURN(const Shape& window_result_shape,
2288 InferWindowOutputShape(operand_shape, window,
2289 operand_shape.element_type(),
2290 /*allow_negative_padding=*/false));
2291 if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape,
2292 window_result_shape)) {
2293 return InvalidArgument(
2294 "Source shape does not match the shape of window-reduced operand: "
2295 "source(%s), window-reduced operand(%s).",
2296 ShapeUtil::HumanString(source_shape),
2297 ShapeUtil::HumanString(window_result_shape));
2298 }
2299
2300 return operand_shape;
2301 }
2302
InferGetDimensionSizeShape(const Shape & shape,int64 dimension)2303 /* static */ StatusOr<Shape> ShapeInference::InferGetDimensionSizeShape(
2304 const Shape& shape, int64 dimension) {
2305 if (dimension < 0 || dimension >= shape.rank()) {
2306 return InvalidArgument("GetDimensionSize dimension out of bounds: %d.",
2307 dimension);
2308 }
2309
2310 // TODO(b/119580730): Remove this restriction when very large dimension size
2311 // is needed.
2312 if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
2313 return InvalidArgument(
2314 "GetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2315 "INT_MAX limit.",
2316 ShapeUtil::HumanString(shape), dimension);
2317 }
2318
2319 return ShapeUtil::MakeShape(S32, {});
2320 }
2321
InferSetDimensionSizeShape(const Shape & shape,const Shape & val_shape,int64 dimension)2322 /* static */ StatusOr<Shape> ShapeInference::InferSetDimensionSizeShape(
2323 const Shape& shape, const Shape& val_shape, int64 dimension) {
2324 if (dimension < 0 || dimension >= shape.rank()) {
2325 return InvalidArgument("SetDimensionSize dimension out of bounds: %d.",
2326 dimension);
2327 }
2328
2329 if (val_shape.rank() != 0 || val_shape.element_type() != S32) {
2330 return InvalidArgument(
2331 "SetDimensionSize's value has to be S32 scalar, got %s",
2332 val_shape.ToString());
2333 }
2334 // TODO(b/119580730): Remove this restriction when very large dimension size
2335 // is needed.
2336 if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
2337 return InvalidArgument(
2338 "SetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2339 "INT_MAX limit.",
2340 ShapeUtil::HumanString(shape), dimension);
2341 }
2342
2343 Shape result = shape;
2344 result.set_dynamic_dimension(dimension, true);
2345 return result;
2346 }
2347
InferWindowFromDimensions(absl::Span<const int64> window_dimensions,absl::Span<const int64> window_strides,absl::Span<const std::pair<int64,int64>> padding,absl::Span<const int64> lhs_dilation,absl::Span<const int64> rhs_dilation)2348 /* static */ StatusOr<Window> ShapeInference::InferWindowFromDimensions(
2349 absl::Span<const int64> window_dimensions,
2350 absl::Span<const int64> window_strides,
2351 absl::Span<const std::pair<int64, int64>> padding,
2352 absl::Span<const int64> lhs_dilation,
2353 absl::Span<const int64> rhs_dilation) {
2354 const auto verify_size = [&](const size_t x, const char* x_name) {
2355 if (x == 0 || x == window_dimensions.size()) {
2356 return Status::OK();
2357 } else {
2358 return InvalidArgument(
2359 "%s", absl::StrCat(
2360 "Window has different number of window dimensions than of ",
2361 x_name,
2362 "\nNumber of window dimensions: ", window_dimensions.size(),
2363 "\nNumber of ", x_name, ": ", x, "\n"));
2364 }
2365 };
2366 TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides"));
2367 TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries"));
2368 TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors"));
2369 TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors"));
2370
2371 Window window;
2372 for (size_t i = 0; i < window_dimensions.size(); i++) {
2373 auto dim = window.add_dimensions();
2374 dim->set_size(window_dimensions[i]);
2375 if (!window_strides.empty()) {
2376 dim->set_stride(window_strides[i]);
2377 } else {
2378 dim->set_stride(1);
2379 }
2380 if (!padding.empty()) {
2381 dim->set_padding_low(padding[i].first);
2382 dim->set_padding_high(padding[i].second);
2383 } else {
2384 dim->set_padding_low(0);
2385 dim->set_padding_high(0);
2386 }
2387 if (!lhs_dilation.empty()) {
2388 dim->set_base_dilation(lhs_dilation[i]);
2389 } else {
2390 dim->set_base_dilation(1);
2391 }
2392 if (!rhs_dilation.empty()) {
2393 dim->set_window_dilation(rhs_dilation[i]);
2394 } else {
2395 dim->set_window_dilation(1);
2396 }
2397 dim->set_window_reversal(false);
2398 }
2399 return window;
2400 }
2401
/* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
    const Shape& arg, absl::Span<const int64> starts,
    absl::Span<const int64> limits, absl::Span<const int64> strides) {
  // Prefixes a validation failure with the full slice configuration so the
  // resulting error message is self-describing.
  auto error = [&](const string& message) {
    return InvalidArgument(
        "%s in slice operation; argument shape: %s; starts: {%s}; limits: "
        "{%s}; strides: {%s}.",
        message, ShapeUtil::HumanString(arg), StrJoin(starts, ","),
        StrJoin(limits, ","), StrJoin(strides, ","));
  };
  TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice"));
  VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}",
                       ShapeUtil::HumanString(arg), StrJoin(starts, ", "),
                       StrJoin(limits, ", "));

  // starts/limits/strides are parallel arrays with one entry per operand
  // dimension; all three must have the same length as the operand's rank.
  if (starts.size() != limits.size()) {
    return error(StrFormat("slice start and limit sizes differ: %u vs %u",
                           starts.size(), limits.size()));
  }

  if (starts.size() != strides.size()) {
    return error(StrFormat("slice start and strides sizes differ: %u vs %u",
                           starts.size(), strides.size()));
  }

  if (starts.size() != arg.rank()) {
    return InvalidArgument(
        "Slice index count does not match argument rank: %u vs %d.",
        starts.size(), arg.rank());
  }

  // Validate each dimension's [start, limit) window and accumulate the
  // resulting per-dimension output size.
  std::vector<int64> sizes;
  for (int64 dimension = 0; dimension < starts.size(); ++dimension) {
    int64 start_index = starts[dimension];
    int64 limit_index = limits[dimension];
    int64 stride = strides[dimension];
    if (start_index < 0) {
      return InvalidArgument("Negative start index to slice: %d.", start_index);
    }
    if (limit_index > arg.dimensions(dimension)) {
      return error(
          StrFormat("limit index (%d) must be less than or equal to dimension "
                    "size (%d)",
                    limit_index, arg.dimensions(dimension)));
    }
    VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index);
    VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index);
    if (start_index > limit_index) {
      return error(
          StrFormat("limit index (%d) must be greater or equal to "
                    "start index (%d) in slice with positive stride",
                    limit_index, start_index));
    }
    if (stride <= 0) {
      return InvalidArgument("Stride (%d) must be positive.", stride);
    }
    // ceil((limit - start) / stride): the number of elements selected from
    // this dimension.
    sizes.push_back((limit_index - start_index + stride - 1) / stride);
  }

  std::vector<bool> is_dynamic(arg.rank());
  for (int64 i = 0; i < arg.dimensions_size(); ++i) {
    // Slicing 1 out of a dynamic dimension eliminates the dynamic dimension.
    if (sizes[i] == 1) {
      continue;
    }
    is_dynamic[i] = arg.is_dynamic_dimension(i);
  }

  return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic);
}
2472
/* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
    const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
    absl::Span<const int64> slice_sizes, bool allow_scalar_indices) {
  TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
  auto number_of_indices = start_index_shapes.size();
  // Two index encodings are accepted: a single rank-1 vector of start
  // indices (legacy), or one scalar index operand per operand dimension.
  // TODO(b/118437727): Remove this path.
  if (!allow_scalar_indices ||
      (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
    // Legacy path: exactly one rank-1 integral operand whose length equals
    // the operand's rank.
    if (number_of_indices != 1) {
      return InvalidArgument(
          "Dynamic slice should have exactly 1 index operand, has %d.",
          number_of_indices);
    }

    const Shape& start_indices_shape = start_index_shapes[0];
    VLOG(2) << StrFormat(
        "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
        ShapeUtil::HumanString(operand_shape),
        ShapeUtil::HumanString(start_indices_shape),
        StrJoin(slice_sizes, ", "));

    TF_RETURN_IF_ERROR(
        ExpectArray(start_indices_shape, "start indices of dynamic slice"));

    if (start_indices_shape.rank() != 1) {
      return InvalidArgument(
          "Dynamic slice start indices of rank %d must be rank1.",
          start_indices_shape.rank());
    }

    if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
      return InvalidArgument(
          "Dynamic slice start indices must be of integral type.");
    }

    const int64 start_num_dims = start_indices_shape.dimensions(0);
    if (operand_shape.rank() != start_num_dims) {
      return InvalidArgument(
          "Dynamic slice start number of dimensions %d (%s) must match rank "
          "%d of slice input (%s).",
          start_num_dims, ShapeUtil::HumanString(start_indices_shape),
          operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
    }
  } else {
    // Scalar-indices path: one integral scalar per operand dimension, all
    // with identical shapes.
    VLOG(2) << StrFormat("slicing shape %s a with slice_sizes={%s}",
                         ShapeUtil::HumanString(operand_shape),
                         StrJoin(slice_sizes, ", "));

    if (operand_shape.rank() != number_of_indices) {
      return InvalidArgument(
          "Dynamic slice start number of dimensions %d must match rank "
          "%d of slice input (%s).",
          number_of_indices, operand_shape.rank(),
          ShapeUtil::HumanString(operand_shape));
    }

    if (number_of_indices > 0) {
      const Shape& first_index_shape = start_index_shapes[0];
      if (!ShapeUtil::IsScalar(first_index_shape)) {
        return InvalidArgument("Dynamic slice indices must be scalar, not %s.",
                               ShapeUtil::HumanString(first_index_shape));
      }
      if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
        return InvalidArgument(
            "Dynamic slice start indices must be of integral type.");
      }
      for (const Shape& index_shape : start_index_shapes) {
        if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
          return InvalidArgument(
              "Dynamic slice start indices must all have the same shape, got "
              "mismatching indices with shapes %s and %s.",
              ShapeUtil::HumanString(first_index_shape),
              ShapeUtil::HumanString(index_shape));
        }
      }
    }
  }

  // slice_sizes determines the output: one entry per operand dimension, each
  // non-negative and no larger than the corresponding input dimension.
  if (slice_sizes.size() != operand_shape.rank()) {
    return InvalidArgument(
        "Dynamic slice index count does not match argument rank: %u vs %d.",
        slice_sizes.size(), operand_shape.rank());
  }

  for (int64 dim = 0; dim < slice_sizes.size(); ++dim) {
    const int64 input_dim_size = operand_shape.dimensions(dim);
    const int64 slice_dim_size = slice_sizes[dim];
    if (slice_dim_size < 0) {
      return InvalidArgument("Negative size index to dynamic slice: %d.",
                             slice_dim_size);
    }
    if (slice_dim_size > input_dim_size) {
      return InvalidArgument(
          "Slice dim size %d greater than dynamic slice dimension: %d.",
          slice_dim_size, input_dim_size);
    }
    VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size);
  }

  // The result shape is fully static with dimensions equal to slice_sizes.
  return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
}
2574
/* static */ StatusOr<Shape> ShapeInference::InferDynamicUpdateSliceShape(
    const Shape& operand_shape, const Shape& update_shape,
    absl::Span<const Shape> start_index_shapes, bool allow_scalar_indices) {
  TF_RETURN_IF_ERROR(
      ExpectArray(operand_shape, "operand of dynamic update slice"));
  TF_RETURN_IF_ERROR(
      ExpectArray(update_shape, "update of dynamic update slice"));

  auto number_of_indices = start_index_shapes.size();
  // Two index encodings are accepted: a single rank-1 vector of start
  // indices (legacy), or one scalar index operand per operand dimension.
  // TODO(b/118437727): Remove this path.
  if (!allow_scalar_indices ||
      (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
    // Legacy path: exactly one rank-1 integral operand whose length equals
    // the operand's rank.
    if (number_of_indices != 1) {
      return InvalidArgument(
          "Dynamic update slice should have exactly 1 index operand, has %d.",
          number_of_indices);
    }
    const Shape& start_indices_shape = start_index_shapes[0];
    TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape,
                                   "start indices of dynamic update slice"));

    VLOG(2) << StrFormat(
        "updating slice of shape %s at dynamic start_indices %s with update "
        "shape %s",
        ShapeUtil::HumanString(operand_shape),
        ShapeUtil::HumanString(start_indices_shape),
        ShapeUtil::HumanString(update_shape));

    if (start_indices_shape.rank() != 1) {
      return InvalidArgument(
          "Dynamic update slice start indices of rank %d must be rank1.",
          start_indices_shape.rank());
    }

    if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
      return InvalidArgument(
          "Dynamic update slice start indices must be of integral type.");
    }

    const int64 start_num_dims = start_indices_shape.dimensions(0);
    if (operand_shape.rank() != start_num_dims) {
      return InvalidArgument(
          "Dynamic update slice start number of dimensions %d (%s) must match "
          "rank %d of slice input (%s).",
          start_num_dims, ShapeUtil::HumanString(start_indices_shape),
          operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
    }
  } else {
    // Scalar-indices path: one integral scalar per operand dimension, all
    // with identical shapes.
    VLOG(2) << StrFormat("updating slice of shape %s with update shape %s",
                         ShapeUtil::HumanString(operand_shape),
                         ShapeUtil::HumanString(update_shape));

    if (operand_shape.rank() != number_of_indices) {
      return InvalidArgument(
          "Dynamic update slice start number of dimensions %d must match "
          "rank %d of slice input (%s).",
          number_of_indices, operand_shape.rank(),
          ShapeUtil::HumanString(operand_shape));
    }

    if (number_of_indices > 0) {
      const Shape& first_index_shape = start_index_shapes[0];
      if (!ShapeUtil::IsScalar(first_index_shape)) {
        return InvalidArgument(
            "Dynamic update slice indices must be scalar, not %s.",
            ShapeUtil::HumanString(first_index_shape));
      }
      if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
        return InvalidArgument(
            "Dynamic update slice start indices must be of integral type.");
      }
      for (const Shape& index_shape : start_index_shapes) {
        if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
          return InvalidArgument(
              "Dynamic update slice start indices must all have the same "
              "shape, got mismatching indices with shapes %s and %s.",
              ShapeUtil::HumanString(first_index_shape),
              ShapeUtil::HumanString(index_shape));
        }
      }
    }
  }

  // The update must have the same rank and element type as the operand, and
  // must fit within the operand in every dimension.
  if (update_shape.rank() != operand_shape.rank()) {
    return InvalidArgument(
        "Dynamic update slice update rank does not match argument rank: "
        "%d vs %d.",
        update_shape.rank(), operand_shape.rank());
  }

  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
                                                     update_shape)) {
    return InvalidArgument(
        "Dynamic update slice update element type does not match argument. "
        "operand.element_type: %s vs update.element_type: %s.",
        PrimitiveType_Name(operand_shape.element_type()),
        PrimitiveType_Name(update_shape.element_type()));
  }

  for (int64 dim = 0; dim < operand_shape.rank(); ++dim) {
    const int64 input_dim_size = operand_shape.dimensions(dim);
    const int64 update_dim_size = update_shape.dimensions(dim);
    if (update_dim_size < 0) {
      return InvalidArgument(
          "Size index %d to dynamic update slice must be >= 0.",
          update_dim_size);
    }
    if (update_dim_size > input_dim_size) {
      return InvalidArgument(
          "Update dim size %d greater than dynamic slice dimension: %d.",
          update_dim_size, input_dim_size);
    }
    VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size);
  }

  // The result has the operand's shape; only dynamic-dimension markers below
  // can differ from the operand.
  auto result_shape = operand_shape;

  // If any of the operand shape is dynamic, the result dimension is also
  // dynamic.
  // If update shape is dynamic, only propagate dynamic dimension to result if
  // the update is a full update (update_shape[i] == operand_shape[i]).
  for (int64 i = 0; i < update_shape.rank(); ++i) {
    if (operand_shape.is_dynamic_dimension(i)) {
      result_shape.set_dynamic_dimension(i, true);
    }

    if (update_shape.is_dynamic_dimension(i) &&
        update_shape.dimensions(i) == operand_shape.dimensions(i)) {
      // When update/replace a full dimension, propagate dynamic dimension to
      // the result.
      result_shape.set_dynamic_dimension(i, true);
    }
  }

  return result_shape;
}
2711
InferReverseShape(const Shape & operand_shape,absl::Span<const int64> dimensions)2712 /*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
2713 const Shape& operand_shape, absl::Span<const int64> dimensions) {
2714 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
2715 if (!AllUnique(dimensions)) {
2716 return InvalidArgument("a dimension number is duplicated in reverse");
2717 }
2718 for (int64 dimension : dimensions) {
2719 if (dimension >= operand_shape.rank() || dimension < 0) {
2720 return InvalidArgument(
2721 "One of the reverse dimensions (%d) is out-of-bounds in shape %s.",
2722 dimension, ShapeUtil::HumanString(operand_shape));
2723 }
2724 }
2725 return operand_shape;
2726 }
2727
InferGetTupleElementShape(const Shape & arg,int64 index)2728 /* static */ StatusOr<Shape> ShapeInference::InferGetTupleElementShape(
2729 const Shape& arg, int64 index) {
2730 if (!arg.IsTuple()) {
2731 return InvalidArgument(
2732 "Cannot infer shape: attempting to index into non-tuple: %s.",
2733 ShapeUtil::HumanString(arg));
2734 }
2735
2736 if (index < 0 || index >= arg.tuple_shapes_size()) {
2737 return InvalidArgument(
2738 "Cannot infer shape: attempt to index out of tuple bounds: %d "
2739 ">= %d in shape %s.",
2740 index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg));
2741 }
2742
2743 return arg.tuple_shapes(index);
2744 }
2745
InferWhileShape(const ProgramShape & condition,const ProgramShape & body,const Shape & init)2746 /* static */ StatusOr<Shape> ShapeInference::InferWhileShape(
2747 const ProgramShape& condition, const ProgramShape& body,
2748 const Shape& init) {
2749 // Check the number of parameters for given computations.
2750 if (condition.parameters_size() != 1) {
2751 return InvalidArgument("Condition must take 1 arguments; got %d.",
2752 condition.parameters_size());
2753 }
2754 if (body.parameters_size() != 1) {
2755 return InvalidArgument("Body must take 1 arguments; got %d.",
2756 body.parameters_size());
2757 }
2758
2759 auto shape_string = [&]() {
2760 return StrFormat(
2761 "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition),
2762 ShapeUtil::HumanString(body), ShapeUtil::HumanString(init));
2763 };
2764
2765 // Check the shapes of computation parameters and return types.
2766 if (!ShapeUtil::Compatible(condition.result(),
2767 ShapeUtil::MakeShape(PRED, {}))) {
2768 return InvalidArgument("Condition must return a boolean; got %s.",
2769 shape_string());
2770 }
2771 if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) ||
2772 !ShapeUtil::Compatible(body.result(), body.parameters(0)) ||
2773 !ShapeUtil::Compatible(body.result(), init)) {
2774 return InvalidArgument(
2775 "The parameter of condition and body, the result of the body, and init "
2776 "must all have the same shape; got %s.",
2777 shape_string());
2778 }
2779
2780 return init;
2781 }
2782
/* static */ StatusOr<Shape> ShapeInference::InferConditionalShape(
    const Shape& branch_index,
    absl::Span<const ProgramShape> branch_computations,
    absl::Span<const Shape> branch_operands) {
  // The selector is either a PRED scalar (classic two-branch conditional) or
  // an S32 scalar (N-way branch selector).
  if (!ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(PRED, {})) &&
      !ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(S32, {}))) {
    return InvalidArgument("branch_index must be bool or int32; got %s.",
                           ShapeUtil::HumanString(branch_index));
  }
  // PRED selectors require exactly two branches; S32 selectors require at
  // least one. Each branch computation is paired with one operand.
  if (branch_index.element_type() == PRED) {
    TF_RET_CHECK(2 == branch_computations.size());
  } else {
    TF_RET_CHECK(!branch_computations.empty());
  }
  TF_RET_CHECK(branch_computations.size() == branch_operands.size());

  for (int j = 0; j < branch_computations.size(); ++j) {
    // Each branch computation takes exactly one argument: its operand.
    if (branch_computations[j].parameters_size() != 1) {
      return InvalidArgument(
          "branch computation %d must take 1 argument; got %d.", j,
          branch_computations[j].parameters_size());
    }
    if (!ShapeUtil::Compatible(branch_computations[j].parameters(0),
                               branch_operands[j])) {
      auto shape_string = [&]() {
        return StrFormat("operand: %s; computation: %s",
                         ShapeUtil::HumanString(branch_operands[j]),
                         ShapeUtil::HumanString(branch_computations[j]));
      };
      return InvalidArgument(
          "branch operand %d must match the shape of the only parameter of "
          "branch computation %d: got %s.",
          j, j, shape_string());
    }

    // All branches must agree on the result shape; branch 0 is the reference.
    if (!ShapeUtil::Compatible(branch_computations[0].result(),
                               branch_computations[j].result())) {
      auto shape_string = [&]() {
        return StrFormat(
            "branch 0 computation result: %s; branch %d computation result: %s",
            ShapeUtil::HumanString(branch_computations[0].result()), j,
            ShapeUtil::HumanString(branch_computations[j].result()));
      };
      return InvalidArgument(
          "the result of branch 0 computation and branch %d computation must "
          "have the same shape: got %s.",
          j, shape_string());
    }
  }
  // All branch results are compatible, so branch 0's result shape is the
  // conditional's result shape.
  return branch_computations[0].result();
}
2834
InferBroadcastShape(const Shape & operand,absl::Span<const int64> broadcast_sizes)2835 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2836 const Shape& operand, absl::Span<const int64> broadcast_sizes) {
2837 TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
2838 for (int64 size : broadcast_sizes) {
2839 if (size < 0) {
2840 return InvalidArgument("Broadcast with negative dimension size %d.",
2841 size);
2842 }
2843 }
2844
2845 std::vector<int64> dimensions(operand.dimensions_size() +
2846 broadcast_sizes.size());
2847 std::copy(broadcast_sizes.begin(), broadcast_sizes.end(), dimensions.begin());
2848 std::copy(operand.dimensions().begin(), operand.dimensions().end(),
2849 dimensions.begin() + broadcast_sizes.size());
2850
2851 Shape result = ShapeUtil::MakeShape(operand.element_type(), dimensions);
2852 for (int64 i = 0; i < operand.dimensions_size(); ++i) {
2853 result.set_dynamic_dimension(broadcast_sizes.size() + i,
2854 operand.is_dynamic_dimension(i));
2855 }
2856 return result;
2857 }
2858
/* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
    const Shape& operand_shape, const Shape& output_shape,
    absl::Span<const int64> broadcast_dimensions) {
  // InDim-style broadcast: broadcast_dimensions[i] names the output dimension
  // that operand dimension i maps to; the output shape is given explicitly.
  TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast"));
  TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast"));
  const int64 operand_rank = operand_shape.rank();
  const int64 output_rank = output_shape.rank();
  if (operand_rank > output_rank) {
    return InvalidArgument(
        "InDim style broadcast must be to an equal or higher ranked shape; "
        "operand rank: %lld; output rank: %lld",
        operand_rank, output_rank);
  }
  // One mapping entry per operand dimension.
  if (operand_rank != broadcast_dimensions.size()) {
    return InvalidArgument(
        "Size of broadcast_dimensions has to match operand's rank; operand "
        "rank: %lld, size of broadcast_dimensions %u.",
        operand_rank, broadcast_dimensions.size());
  }
  for (int64 i = 0; i < operand_rank; i++) {
    // Each mapped-to dimension must exist in the output.
    if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] >= output_rank) {
      return InvalidArgument("Broadcast dimension %lld is out of bound",
                             broadcast_dimensions[i]);
    }
    // An operand dimension either matches the output dimension it maps to, or
    // is a degenerate (size-1) dimension that gets broadcast.
    if (operand_shape.dimensions(i) !=
            output_shape.dimensions(broadcast_dimensions[i]) &&
        operand_shape.dimensions(i) != 1) {
      return InvalidArgument(
          "Input dimension should be either 1 or equal to the output dimension "
          "it is broadcasting into; the %lldth operand dimension is %lld, the "
          "%lldth output dimension is %lld.",
          i, operand_shape.dimensions(i), broadcast_dimensions[i],
          output_shape.dimensions(broadcast_dimensions[i]));
    }
    // Dynamic-ness must agree between each operand dimension and the output
    // dimension it maps to.
    if (operand_shape.is_dynamic_dimension(i) !=
        output_shape.is_dynamic_dimension(broadcast_dimensions[i])) {
      return InvalidArgument(
          "Broadcast input and output dynamism mismatch: %s and %s",
          operand_shape.ToString(), output_shape.ToString());
    }
    // Make sure the broadcast dimensions are listed in a strictly increasing
    // order.
    if (i > 0 && broadcast_dimensions[i - 1] >= broadcast_dimensions[i]) {
      return InvalidArgument(
          "Broadcast dimensions order is wrong: %d comes after %d.",
          broadcast_dimensions[i], broadcast_dimensions.at(i - 1));
    }
  }

  // The caller-supplied output shape is the inferred result.
  return output_shape;
}
2910
InferDynamicReshapeShape(const Shape & operand,absl::Span<const Shape * const> dim_size_shapes,absl::Span<const int64> new_size_bounds,const std::vector<bool> & dims_are_dynamic)2911 /* static */ StatusOr<Shape> ShapeInference::InferDynamicReshapeShape(
2912 const Shape& operand, absl::Span<const Shape* const> dim_size_shapes,
2913 absl::Span<const int64> new_size_bounds,
2914 const std::vector<bool>& dims_are_dynamic) {
2915 if (new_size_bounds.size() != dims_are_dynamic.size()) {
2916 return InvalidArgument(
2917 "DynamicReshape has to have the same number of elements in new_sizes "
2918 "(%d) and dims_are_dynamic (%d)",
2919 new_size_bounds.size(), dims_are_dynamic.size());
2920 }
2921
2922 for (const Shape* dim_size_shape : dim_size_shapes) {
2923 if (dim_size_shape->element_type() != S32 && dim_size_shape->rank() != 0) {
2924 return InvalidArgument(
2925 "DynamicReshape's dim size has to be scalar S32, got (%s): ",
2926 dim_size_shape->ToString());
2927 }
2928 }
2929
2930 Shape inferred_shape = ShapeUtil::MakeShape(
2931 operand.element_type(), new_size_bounds, dims_are_dynamic);
2932 if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
2933 return InvalidArgument(
2934 "Reshape operation has mismatched element counts: from=%d (%s) "
2935 "to=%d (%s).",
2936 ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
2937 ShapeUtil::ElementsIn(inferred_shape),
2938 ShapeUtil::HumanString(inferred_shape));
2939 }
2940 return inferred_shape;
2941 }
2942
/* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
    const Shape& operand, absl::Span<const int64> dimensions,
    absl::Span<const int64> new_sizes, int64 inferred_dimension) {
  TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));

  // The static result shape is fully determined by new_sizes; the rest of
  // this function only validates and decides dynamic-dimension placement.
  Shape inferred_shape =
      ShapeUtil::MakeShape(operand.element_type(), new_sizes);
  VLOG(3) << "Reshape inferred shape: "
          << ShapeUtil::HumanString(inferred_shape);

  // A reshape must conserve the total element count.
  if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
    return InvalidArgument(
        "Reshape operation has mismatched element counts: from=%d (%s) "
        "to=%d (%s).",
        ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
        ShapeUtil::ElementsIn(inferred_shape),
        ShapeUtil::HumanString(inferred_shape));
  }

  // `dimensions` is the operand-dimension ordering applied before reshaping;
  // it must be a permutation of [0, operand.rank()).
  std::vector<int64> indices(operand.rank());
  std::iota(indices.begin(), indices.end(), 0);
  if (dimensions.size() != operand.rank() ||
      !std::is_permutation(dimensions.begin(), dimensions.end(),
                           indices.begin())) {
    return InvalidArgument(
        "Reshape dimensions [%s] are not a permutation of the operand "
        "dimensions (operand shape is %s).",
        StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
  }

  // Propagate dynamic dimension.
  // CommonFactors partitions input and output dimensions into aligned groups
  // whose element-count products match; each dynamic input dimension is
  // mapped to an output dimension within its group.
  auto common_factors = CommonFactors(operand.dimensions(), new_sizes);
  for (int64 input_dim = 0; input_dim < operand.rank(); ++input_dim) {
    if (!operand.is_dynamic_dimension(input_dim)) {
      continue;
    }

    string reshape_debug_str = absl::StrFormat(
        "output: %s, input: %s, input_dim: "
        "%lld",
        ShapeUtil::HumanString(inferred_shape), ShapeUtil::HumanString(operand),
        input_dim);

    int64 input_dim_start = -1;
    int64 input_dim_end = -1;
    int64 output_dim_start = -1;
    int64 output_dim_end = -1;
    // Find common_factors that the input_dim belongs to.
    for (int64 i = 0; i < common_factors.size() - 1; ++i) {
      auto start = common_factors[i];
      auto end = common_factors[i + 1];
      if (input_dim >= start.first && input_dim < end.first) {
        input_dim_start = start.first;
        input_dim_end = end.first;
        output_dim_start = start.second;
        output_dim_end = end.second;
        break;
      }
    }
    if ((input_dim_end - input_dim_start) > 1 &&
        (output_dim_end - output_dim_start) > 1) {
      // We don't support the case when a dynamic dimension is both combined
      // with and splitted into other dimensions:
      //
      //  [x, yz]
      //     | Reshape
      //  [xy, z]
      return Unimplemented(
          "Dynamic input dimension to reshape that is both splitted and "
          "combined is not supported: %s",
          reshape_debug_str);
    }

    for (auto common_factor : common_factors) {
      //
      // For reshapes like:
      //  [<=5]
      //    | Reshape
      //  [1, 5]
      //
      //  The input dynamic dimension can go into either dynamic dimensions.
      //  However, the return value of common factors only returns
      //  input: 5
      //  output: 5
      //
      //  We need to expand common factor to include degenerated output
      //  dimensions:
      //  input: 5
      //  output: 1, 5
      //
      //  such that in the logic later on we can consider both dimensions as
      //  candidate.
      if (common_factor.first == input_dim_start) {
        output_dim_start = std::min(output_dim_start, common_factor.second);
      }
      if (common_factor.first == input_dim_end) {
        output_dim_end = std::max(output_dim_end, common_factor.second);
      }
    }

    // Calculate output dynamic reshape dimension.
    int64 output_dynamic_dimension = -1;

    if (operand.dimensions(input_dim) == 1 && !new_sizes.empty()) {
      // If dynamic dimension is size 1, it can only be most-major or
      // most-minor.
      if (input_dim == 0) {
        output_dynamic_dimension = 0;
      }
      if (input_dim == operand.rank() - 1) {
        output_dynamic_dimension = new_sizes.size() - 1;
      }

      if (output_dynamic_dimension == -1) {
        return Unimplemented(
            "Dynamic degenerated dimension that's not most-minor nor "
            "most-major is not supported: %s",
            reshape_debug_str);
      }
    }

    if (output_dynamic_dimension == -1 &&
        output_dim_end - output_dim_start == 1) {
      // Only one possible output dimension.
      output_dynamic_dimension = output_dim_start;
    }
    if (output_dynamic_dimension == -1 &&
        output_dim_end - output_dim_start > 1) {
      // Multiple outputs can be dynamic, use "inferred_dimension" to tie-break.
      output_dynamic_dimension = inferred_dimension;
    }

    if (output_dynamic_dimension != -1) {
      // TODO(yunxing): Turn this into a CHECK.
      inferred_shape.set_dynamic_dimension(output_dynamic_dimension, true);
    } else {
      // No unique candidate was found; fall back to marking the single
      // non-degenerate output dimension in the group, if there is one.
      std::vector<int64> output_non_degenerated;
      for (int64 i = output_dim_start; i < output_dim_end; ++i) {
        if (new_sizes[i] != 1) {
          output_non_degenerated.push_back(i);
        }
      }
      if (output_non_degenerated.size() == 1) {
        inferred_shape.set_dynamic_dimension(output_non_degenerated[0], true);
      }
    }
  }

  return inferred_shape;
}
3093
InferTransposeShape(const Shape & operand,absl::Span<const int64> dimensions)3094 /* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
3095 const Shape& operand, absl::Span<const int64> dimensions) {
3096 TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
3097
3098 if (dimensions.size() != operand.rank() || !IsPermutation(dimensions)) {
3099 return InvalidArgument(
3100 "Transpose dimensions [%s] are not a permutation of the operand "
3101 "dimensions (operand shape is %s).",
3102 StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
3103 }
3104
3105 // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
3106 // we need output[i]=input[dimensions[i]] which is
3107 // Permute(Inverse(dimensions),input).
3108 return ShapeUtil::PermuteDimensions(dimensions, operand);
3109 }
3110
InferClampShape(const Shape & min,const Shape & operand,const Shape & max)3111 /* static */ StatusOr<Shape> ShapeInference::InferClampShape(
3112 const Shape& min, const Shape& operand, const Shape& max) {
3113 TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min"));
3114 TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand"));
3115 TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max"));
3116
3117 if (!ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
3118 !ShapeUtil::CompatibleIgnoringFpPrecision(max, operand)) {
3119 return InvalidArgument(
3120 "Clamp with different shapes: %s, %s, %s.", ShapeUtil::HumanString(min),
3121 ShapeUtil::HumanString(operand), ShapeUtil::HumanString(max));
3122 }
3123 return operand;
3124 }
3125
InferSelectShape(const Shape & pred,const Shape & on_true,const Shape & on_false)3126 /* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
3127 const Shape& pred, const Shape& on_true, const Shape& on_false) {
3128 TF_RETURN_IF_ERROR(ExpectArray(pred, "select pred"));
3129 TF_RETURN_IF_ERROR(ExpectArray(on_true, "select on-true"));
3130 TF_RETURN_IF_ERROR(ExpectArray(on_false, "select on-false"));
3131
3132 if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) {
3133 return InvalidArgument(
3134 "Operands to select must be the same shape; got %s and %s.",
3135 ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
3136 }
3137 if (pred.element_type() != PRED) {
3138 return InvalidArgument(
3139 "Select's pred operand must have PRED element type; got %s.",
3140 ShapeUtil::HumanString(pred));
3141 }
3142 if (!Shape::Equal()
3143 .IgnoreElementType()
3144 .IgnoreLayout()
3145 .IgnoreDynamicDimension()(pred, on_true)) {
3146 return InvalidArgument(
3147 "Operands to select and predicate must be the same shape; got %s and "
3148 "%s.",
3149 ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(pred));
3150 }
3151
3152 return ShapeUtil::ChangeElementType(
3153 pred, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
3154 }
3155
InferTupleSelectShape(const Shape & pred,const Shape & on_true,const Shape & on_false)3156 /* static */ StatusOr<Shape> ShapeInference::InferTupleSelectShape(
3157 const Shape& pred, const Shape& on_true, const Shape& on_false) {
3158 // Select only defines the top-level buffer, so if it's a tuple, the two
3159 // input must match exactly.
3160 if (!ShapeUtil::Compatible(on_true, on_false)) {
3161 return InvalidArgument(
3162 "Operands to tuple-select must be the same shape; got %s and %s.",
3163 ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
3164 }
3165 if (pred.element_type() != PRED) {
3166 return InvalidArgument(
3167 "TupleSelect's pred operand must have PRED element type; got %s.",
3168 ShapeUtil::HumanString(pred));
3169 }
3170 if (!ShapeUtil::IsScalar(pred)) {
3171 return InvalidArgument(
3172 "TupleSelect operation with non-scalar predicate: %s.",
3173 ShapeUtil::HumanString(pred));
3174 }
3175 return on_true;
3176 }
3177
InferCallShape(absl::Span<const Shape * const> arg_shapes,const ProgramShape & to_apply)3178 /* static */ StatusOr<Shape> ShapeInference::InferCallShape(
3179 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply) {
3180 // The applied function's arity equals the number of arguments.
3181 if (arg_shapes.size() != to_apply.parameters_size()) {
3182 string computation_signature = ShapeUtil::HumanString(to_apply);
3183 string argument_shapes =
3184 StrJoin(arg_shapes, ", ", [](string* out, const Shape* shape) {
3185 absl::StrAppend(out, ShapeUtil::HumanString(*shape));
3186 });
3187 return InvalidArgument(
3188 "Call applied function arity must match number of arguments; got: "
3189 "arity: %d, arguments: %u; computation signature: %s; argument "
3190 "shapes: [%s].",
3191 to_apply.parameters_size(), arg_shapes.size(), computation_signature,
3192 argument_shapes);
3193 }
3194
3195 // All arguments must be compatible with the program shape.
3196 for (int i = 0; i < arg_shapes.size(); ++i) {
3197 const Shape& arg_shape = *arg_shapes[i];
3198 const Shape& param_shape = to_apply.parameters(i);
3199 if (!ShapeUtil::Compatible(arg_shape, param_shape)) {
3200 return InvalidArgument(
3201 "Call parameter must match argument; got parameter %d shape: %s, "
3202 "argument shape: %s.",
3203 i, ShapeUtil::HumanString(param_shape),
3204 ShapeUtil::HumanString(arg_shape));
3205 }
3206 }
3207
3208 return to_apply.result();
3209 }
3210
// Checks the GatherDimensionNumbers of a gather op against the operand shape
// and the (already expanded, see InferGatherShape) start_indices bounds.
// Returns OK iff all constraints hold; otherwise returns an InvalidArgument
// describing the first violated constraint. Checks are ordered, so the error
// emitted for malformed inputs is deterministic.
static Status ValidateGatherDimensionNumbers(
    const Shape& input_shape, absl::Span<const int64> start_indices_shape,
    const GatherDimensionNumbers& dim_numbers) {
  // offset_dims must be sorted ascending...
  if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
    return InvalidArgument(
        "Output window dimensions in gather op must be ascending; got: %s.",
        StrJoin(dim_numbers.offset_dims(), ", "));
  }

  // ...and, being sorted, duplicates show up as adjacent equal elements.
  if (absl::c_adjacent_find(dim_numbers.offset_dims()) !=
      dim_numbers.offset_dims().end()) {
    return InvalidArgument(
        "Output window dimensions in gather op must not repeat; got: %s.",
        StrJoin(dim_numbers.offset_dims(), ", "));
  }

  // The gather output rank is one offset dimension per entry in offset_dims
  // plus one batch dimension per start_indices dimension other than
  // index_vector_dim (hence the -1).
  const int64 output_offset_dim_count = dim_numbers.offset_dims_size();
  const int64 output_shape_rank =
      output_offset_dim_count + start_indices_shape.size() - 1;

  // Every offset dim must index into the output shape.
  for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) {
    int64 offset_dim = dim_numbers.offset_dims(i);
    if (offset_dim < 0 || offset_dim >= output_shape_rank) {
      return InvalidArgument(
          "Offset dimension %d in gather op is out of bounds; got %d, but "
          "should "
          "have been in [0,%d).",
          i, offset_dim, output_shape_rank);
    }
  }

  // start_index_map maps each component of an index vector to an operand
  // dimension, so its length must equal the index-vector length.
  if (dim_numbers.start_index_map_size() !=
      start_indices_shape[dim_numbers.index_vector_dim()]) {
    return InvalidArgument(
        "Gather op has %d elements in start_index_map and the "
        "bound of dimension index_vector_dim=%d of start_indices is "
        "%d. These two numbers must be equal.",
        dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(),
        start_indices_shape[dim_numbers.index_vector_dim()]);
  }

  // Each mapped operand dimension must exist in the operand.
  for (int i = 0; i < dim_numbers.start_index_map_size(); i++) {
    int64 operand_dim_for_start_index_i = dim_numbers.start_index_map(i);
    if (operand_dim_for_start_index_i < 0 ||
        operand_dim_for_start_index_i >= input_shape.dimensions_size()) {
      return InvalidArgument(
          "Invalid start_index_map; domain is [0, %d), got: %d->%d.",
          input_shape.dimensions_size(), i, operand_dim_for_start_index_i);
    }
  }

  // start_index_map itself need not be sorted, so sort a copy to detect
  // duplicates via adjacent_find.
  std::vector<int64> sorted_start_index_map(
      dim_numbers.start_index_map().begin(),
      dim_numbers.start_index_map().end());

  absl::c_sort(sorted_start_index_map);

  if (absl::c_adjacent_find(sorted_start_index_map) !=
      sorted_start_index_map.end()) {
    return InvalidArgument(
        "Repeated dimensions are not allowed in start_index_map; "
        "got: %s.",
        StrJoin(dim_numbers.start_index_map(), ", "));
  }

  // collapsed_slice_dims entries must be valid operand dimensions...
  for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) {
    if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
      return InvalidArgument(
          "Invalid collapsed_slice_dims set in gather op; valid range is [0, "
          "%d), got: %d.",
          input_shape.dimensions_size(), collapsed_dim);
    }
  }

  // ...sorted ascending...
  if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) {
    return InvalidArgument(
        "collapsed_slice_dims in gather op must be sorted; got: %s",
        StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
  }

  // ...and free of duplicates (adjacent check is sufficient once sorted).
  if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
      dim_numbers.collapsed_slice_dims().end()) {
    return InvalidArgument(
        "Repeated dimensions not allowed in collapsed_slice_dims in gather op; "
        "got: %s.",
        StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
  }

  return Status::OK();
}
3301
// Infers the result shape of a gather op from the operand shape, the
// start_indices shape, the gather dimension numbers and the per-dimension
// slice sizes. Also forwards dynamic-dimension information: a dimension of the
// result is dynamic if it comes from a dynamic start_indices dimension, or
// from slicing an entire dynamic operand dimension.
/*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
    const Shape& input_shape, const Shape& start_indices_shape,
    const GatherDimensionNumbers& gather_dim_numbers,
    absl::Span<const int64> slice_sizes) {
  TF_RETURN_IF_ERROR(
      ExpectArray(input_shape, "input tensor operand of gather op"));
  TF_RETURN_IF_ERROR(
      ExpectArray(start_indices_shape, "gather indices operand of gather op"));

  if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
    return InvalidArgument(
        "Gather indices parameter must be an integral tensor; got %s.",
        ShapeUtil::HumanString(start_indices_shape));
  }

  // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
  // index_vector_dim is rank(P). The bounds of this expanded shape is
  // stored in expanded_start_indices_shape.

  // index_vector_dim may equal rank(start_indices) (the implicit trailing "1"
  // case handled below), hence "<" rather than "<=" here.
  if (start_indices_shape.dimensions_size() <
          gather_dim_numbers.index_vector_dim() ||
      gather_dim_numbers.index_vector_dim() < 0) {
    return InvalidArgument(
        "Gather index leaf dimension must be within [0, rank(start_indices) + "
        "1). rank(start_indices) is %d and gather index leaf dimension is "
        "%d.",
        start_indices_shape.dimensions_size(),
        gather_dim_numbers.index_vector_dim());
  }

  std::vector<int64> expanded_start_indices_shape;
  // Also tracks if an output dimension is dynamic.
  std::vector<bool> expanded_start_indices_shape_dynamic_dimensions;
  expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size());
  expanded_start_indices_shape_dynamic_dimensions.reserve(
      start_indices_shape.dimensions_size());
  absl::c_copy(start_indices_shape.dimensions(),
               std::back_inserter(expanded_start_indices_shape));
  absl::c_copy(
      start_indices_shape.dynamic_dimensions(),
      std::back_inserter(expanded_start_indices_shape_dynamic_dimensions));
  // Append the implicit trailing dimension of bound 1 (always static) when
  // index_vector_dim == rank(start_indices).
  if (expanded_start_indices_shape.size() ==
      gather_dim_numbers.index_vector_dim()) {
    expanded_start_indices_shape.push_back(1);
    expanded_start_indices_shape_dynamic_dimensions.push_back(false);
  }

  TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
      input_shape, expanded_start_indices_shape, gather_dim_numbers));

  // slice_sizes has exactly one entry per operand dimension.
  if (slice_sizes.size() != input_shape.dimensions_size()) {
    return InvalidArgument(
        "Gather op must have one slice size for every input dimension; got: "
        "len(slice_sizes)=%lu, input_shape.rank=%d.",
        slice_sizes.size(), input_shape.dimensions_size());
  }

  // Every sliced operand dimension is either kept as an offset dimension of
  // the output or collapsed away, so the counts must add up.
  if (slice_sizes.size() !=
      gather_dim_numbers.offset_dims_size() +
          gather_dim_numbers.collapsed_slice_dims_size()) {
    return InvalidArgument(
        "All components of the offset index in a gather op must either be a "
        "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
        "output_slice_sizes=%s, collapsed_slice_dims=%s.",
        slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","),
        StrJoin(gather_dim_numbers.collapsed_slice_dims(), ","));
  }

  // Each slice size must fit within its operand dimension's bound.
  for (int i = 0; i < slice_sizes.size(); i++) {
    int64 slice_size = slice_sizes[i];
    int64 corresponding_input_size = input_shape.dimensions(i);
    if (slice_size < 0 || slice_size > corresponding_input_size) {
      return InvalidArgument(
          "Slice size at index %d in gather op is out of range, must be "
          "within [0, %d), got %d.",
          i, corresponding_input_size + 1, slice_size);
    }
  }

  // A collapsed slice dimension must be degenerate (slice size 0 or 1) so
  // that dropping it loses no data.
  for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) {
    if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] > 1) {
      return InvalidArgument(
          "Gather op can only collapse slice dims with bound 1 or 0, but bound "
          "is %d for index %d at position %d.",
          slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)],
          gather_dim_numbers.collapsed_slice_dims(i), i);
    }
  }

  // Assemble the output: offset dims take their bounds from slice_sizes
  // (skipping collapsed operand dims), batch dims take theirs from the
  // expanded start_indices shape (skipping index_vector_dim).
  int64 result_rank = gather_dim_numbers.offset_dims_size() +
                      (expanded_start_indices_shape.size() - 1);
  int64 offset_dims_seen = 0;   // cursor into operand/slice_sizes dims
  int64 gather_dims_seen = 0;   // cursor into expanded start_indices dims
  std::vector<int64> output_dim_bounds;
  output_dim_bounds.reserve(result_rank);

  std::vector<bool> output_dim_is_dynamic;
  output_dim_is_dynamic.reserve(result_rank);
  for (int64 i = 0; i < result_rank; i++) {
    int64 current_bound;
    bool dim_dynamic = false;
    bool is_window_index =
        absl::c_binary_search(gather_dim_numbers.offset_dims(), i);
    if (is_window_index) {
      // Skip operand dimensions that were collapsed out of the output.
      while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(),
                                   offset_dims_seen)) {
        offset_dims_seen++;
      }
      // Gathering an entire dynamic dimension creates dynamic dimension.
      //
      // e.g.,:
      //
      // gather(input: [1,<=2,1], slice_sizes={1,2,1})
      //
      // creates
      //
      // [<=2, 1]
      if (slice_sizes[offset_dims_seen] ==
          input_shape.dimensions(offset_dims_seen)) {
        dim_dynamic = input_shape.is_dynamic_dimension(offset_dims_seen);
      }
      current_bound = slice_sizes[offset_dims_seen++];
    } else {
      // Skip the index-vector dimension; it does not appear in the output.
      if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) {
        gather_dims_seen++;
      }
      // Forward dynamic dimensions from indices.
      dim_dynamic =
          expanded_start_indices_shape_dynamic_dimensions[gather_dims_seen];

      current_bound = expanded_start_indices_shape[gather_dims_seen++];
    }
    output_dim_is_dynamic.push_back(dim_dynamic);
    output_dim_bounds.push_back(current_bound);
  }

  return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds,
                              output_dim_is_dynamic);
}
3441
3442 namespace {
3443
// Checks the ScatterDimensionNumbers of a scatter op against the operand
// shape, the (already expanded) scatter_indices bounds, and the updates shape.
// Returns OK iff all constraints hold; otherwise returns an InvalidArgument
// describing the first violated constraint. Checks are ordered, so the error
// emitted for malformed inputs is deterministic.
Status ValidateScatterDimensionNumbers(
    const Shape& operand_shape, absl::Span<const int64> scatter_indices_shape,
    const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
  // Validate update_window_dims in ScatterDimensionNumbers.
  if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {
    return InvalidArgument(
        "update_window_dims in scatter op must be sorted; got: %s.",
        StrJoin(dim_numbers.update_window_dims(), ", "));
  }
  // Since the list is sorted, duplicates must be adjacent.
  if (absl::c_adjacent_find(dim_numbers.update_window_dims()) !=
      dim_numbers.update_window_dims().end()) {
    return InvalidArgument(
        "update_window_dims in scatter op must not repeat; got: %s.",
        StrJoin(dim_numbers.update_window_dims(), ", "));
  }
  // Each update window dim must index into the updates tensor.
  const int64 updates_rank = updates_shape.rank();
  for (int64 window_dim : dim_numbers.update_window_dims()) {
    if (window_dim < 0 || window_dim >= updates_rank) {
      return InvalidArgument(
          "Invalid update_window_dims set in scatter op; valid range is [0, "
          "%d). got: %d.",
          updates_rank, window_dim);
    }
  }

  // Validate inserted_window_dims in ScatterDimensionNumbers.
  if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) {
    return InvalidArgument(
        "inserted_window_dims in scatter op must be sorted; got: %s.",
        StrJoin(dim_numbers.inserted_window_dims(), ", "));
  }
  if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) !=
      dim_numbers.inserted_window_dims().end()) {
    return InvalidArgument(
        "inserted_window_dims in scatter op must not repeat; got: %s.",
        StrJoin(dim_numbers.inserted_window_dims(), ", "));
  }
  // Each inserted window dim must be a valid operand dimension.
  for (int64 inserted_dim : dim_numbers.inserted_window_dims()) {
    if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) {
      return InvalidArgument(
          "Invalid inserted_window_dims set in scatter op; valid range is [0, "
          "%d), got: %d.",
          operand_shape.dimensions_size(), inserted_dim);
    }
  }

  // Validate window size.
  // Every operand dimension is covered either by an update window dim or an
  // inserted window dim, so together they account for the full operand rank.
  auto window_size = dim_numbers.update_window_dims_size() +
                     dim_numbers.inserted_window_dims_size();
  if (window_size != operand_shape.rank()) {
    return InvalidArgument(
        "Scatter op has window of size %d; doesn't match operand of rank %d.",
        window_size, operand_shape.rank());
  }

  // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers.
  // The mapping has one entry per component of the index vector.
  if (dim_numbers.scatter_dims_to_operand_dims_size() !=
      scatter_indices_shape[dim_numbers.index_vector_dim()]) {
    return InvalidArgument(
        "Scatter op has %d elements in scatter_dims_to_operand_dims and the "
        "bound of dimension index_vector_dim=%d of scatter_indices is %d. "
        "These two numbers must be equal.",
        dim_numbers.scatter_dims_to_operand_dims_size(),
        dim_numbers.index_vector_dim(),
        scatter_indices_shape[dim_numbers.index_vector_dim()]);
  }
  // Each mapped operand dimension must exist in the operand.
  for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) {
    int64 scatter_dim_to_operand_dim =
        dim_numbers.scatter_dims_to_operand_dims(i);
    if (scatter_dim_to_operand_dim < 0 ||
        scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) {
      return InvalidArgument(
          "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), "
          "got: %d->%d.",
          operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim);
    }
  }
  // The mapping need not be sorted, so sort a copy to detect duplicates.
  std::vector<int64> sorted_scatter_dims_to_operand_dims(
      dim_numbers.scatter_dims_to_operand_dims().begin(),
      dim_numbers.scatter_dims_to_operand_dims().end());
  absl::c_sort(sorted_scatter_dims_to_operand_dims);
  if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
      sorted_scatter_dims_to_operand_dims.end()) {
    return InvalidArgument(
        "Repeated dimensions not allowed in scatter_dims_to_operand_dims; "
        "got: %s.",
        StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", "));
  }

  return Status::OK();
}
3535
3536 } // namespace
3537
// Infers the result shape of a scatter op. Validates the operand, indices and
// updates shapes against the scatter dimension numbers, and checks that the
// update computation has a valid reducer signature. On success the result is
// simply the operand's shape, since scatter updates the operand in place.
/*static*/ StatusOr<Shape> ShapeInference::InferScatterShape(
    const Shape& operand_shape, const Shape& scatter_indices_shape,
    const Shape& updates_shape, const ProgramShape& to_apply_shape,
    const ScatterDimensionNumbers& scatter_dim_numbers) {
  TF_RETURN_IF_ERROR(
      ExpectArray(operand_shape, "operand tensor of scatter op"));
  TF_RETURN_IF_ERROR(
      ExpectArray(scatter_indices_shape, "scatter indices of scatter op"));
  TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op"));

  if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) {
    return InvalidArgument(
        "Scatter indices parameter must be an integral tensor; got %s.",
        ShapeUtil::HumanString(scatter_indices_shape));
  }

  // index_vector_dim may equal rank(scatter_indices) (the implicit trailing
  // "1" case handled below), hence "<" rather than "<=" here.
  if (scatter_indices_shape.dimensions_size() <
          scatter_dim_numbers.index_vector_dim() ||
      scatter_dim_numbers.index_vector_dim() < 0) {
    return InvalidArgument(
        "Scatter index leaf dimension must be within [0, rank(scatter_indices)"
        " + 1). rank(scatter_indices) is %d and scatter index leaf dimension "
        "is %d.",
        scatter_indices_shape.dimensions_size(),
        scatter_dim_numbers.index_vector_dim());
  }

  // Check if the update computation has a proper shape as a reduction.
  const Shape init_value_shape =
      ShapeUtil::MakeShape(operand_shape.element_type(), {});
  TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
                                        {updates_shape.element_type()},
                                        /*inputs=*/1));

  // Implicitly expand scatter_indices of shape [A,B,C] to [A,B,C,1] when
  // index_vector_dim == rank(scatter_indices).
  std::vector<int64> expanded_scatter_indices_shape =
      SpanToVector(scatter_indices_shape.dimensions());
  if (expanded_scatter_indices_shape.size() ==
      scatter_dim_numbers.index_vector_dim()) {
    expanded_scatter_indices_shape.push_back(1);
  }

  // The updates tensor has one dim per scatter (batch) dimension of the
  // expanded indices (all but index_vector_dim) plus the update window dims.
  int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 +
                                scatter_dim_numbers.update_window_dims_size();
  if (updates_shape.rank() != expected_updates_rank) {
    return InvalidArgument("Updates tensor must be of rank %d; got %d.",
                           expected_updates_rank, updates_shape.rank());
  }

  TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers(
      operand_shape, expanded_scatter_indices_shape, updates_shape,
      scatter_dim_numbers));

  // Compute, per update window dim (in order), the largest slice the operand
  // admits: operand dims that are inserted_window_dims are skipped.
  int64 inserted_dims_seen = 0;
  std::vector<int64> max_update_slice_sizes;
  for (int i = 0; i < operand_shape.dimensions_size(); ++i) {
    if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() &&
        scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) {
      ++inserted_dims_seen;
    } else {
      max_update_slice_sizes.push_back(operand_shape.dimensions(i));
    }
  }
  // Each update window dim must fit within the corresponding operand bound.
  for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) {
    auto update_window_dim = scatter_dim_numbers.update_window_dims(i);
    if (updates_shape.dimensions(update_window_dim) >
        max_update_slice_sizes[i]) {
      return InvalidArgument(
          "Bounds of the window dimensions of updates must not exceed the "
          "bounds of the corresponding dimensions of operand. For dimension "
          "%d, updates bound is %d, operand bound is %d.",
          update_window_dim, updates_shape.dimensions(update_window_dim),
          max_update_slice_sizes[i]);
    }
  }

  // The non-window dims of updates must match the scatter (batch) dims of the
  // expanded indices, in order, skipping index_vector_dim.
  int64 scatter_dims_seen = 0;
  for (int64 i = 0; i < updates_shape.rank(); ++i) {
    bool is_update_window_dim =
        absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i);
    if (is_update_window_dim) {
      continue;
    }
    if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) {
      ++scatter_dims_seen;
    }
    if (updates_shape.dimensions(i) !=
        expanded_scatter_indices_shape[scatter_dims_seen]) {
      return InvalidArgument(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices. For "
          "scatter dimension %d, updates bound is %d, scatter_indices "
          "bound is %d.",
          i, updates_shape.dimensions(i),
          expanded_scatter_indices_shape[scatter_dims_seen]);
    }
    ++scatter_dims_seen;
  }

  // Scatter produces a tensor of the same shape as the operand it updates.
  return operand_shape;
}
3638
3639 } // namespace xla
3640