/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/shape_inference.h"

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <set>
#include <string>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/statusor.h"

namespace xla {
namespace {

using absl::StrFormat;
using absl::StrJoin;

// Returns true if no element is present in slice more than once.
bool AllUnique(absl::Span<const int64> slice) {
  return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
}

Status ExpectArray(const Shape& shape, absl::string_view op_type) {
  if (!shape.IsArray()) {
    return InvalidArgument("Expected array argument for %s, but got %s.",
                           string(op_type), ShapeUtil::HumanString(shape));
  }
  return Status::OK();
}

Status VerifyReducerShape(const ProgramShape& reducer_shape,
                          absl::Span<const Shape* const> init_value_shapes,
                          absl::Span<const PrimitiveType> input_element_types,
                          int64_t inputs) {
  if (reducer_shape.parameters_size() != inputs * 2) {
    return InvalidArgument(
        "Reduction function must take %d parameters, but "
        "takes %d parameter(s).",
        inputs * 2, reducer_shape.parameters_size());
  }

  const Shape& accumulator_shape = reducer_shape.result();
  std::vector<const Shape*> accumulator_subshapes;
  if (accumulator_shape.IsArray()) {
    if (inputs != 1) {
      return InvalidArgument(
          "Reduction function must produce a tuple with %d elements, but "
          "produces a scalar",
          inputs);
    }
    accumulator_subshapes.push_back(&accumulator_shape);
  } else if (accumulator_shape.IsTuple()) {
    if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) {
      return InvalidArgument(
          "Reduction function must produce a tuple with %d elements, but has "
          "%d elements",
          inputs, ShapeUtil::TupleElementCount(accumulator_shape));
    }
    for (const Shape& element_shape : accumulator_shape.tuple_shapes()) {
      accumulator_subshapes.push_back(&element_shape);
    }
  } else {
    return InvalidArgument(
        "Reduction function must produce a scalar or tuple of scalars, but "
        "has shape: %s",
        ShapeUtil::HumanString(accumulator_shape));
  }

  for (const Shape* element_shape : accumulator_subshapes) {
    if (element_shape->rank() != 0) {
      return InvalidArgument(
          "Reduction function must return a scalar or tuple of scalars but "
          "returns shape: %s",
          ShapeUtil::HumanString(accumulator_shape));
    }
  }

  for (int64_t i = 0; i < inputs; ++i) {
    // Check that the accumulator can be passed in as the first argument.
    // Note: comparing here and below with Compatible since we don't care about
    // layout in scalars - see b/26668201 for a longer-term vision.
    if (!ShapeUtil::Compatible(*accumulator_subshapes[i],
                               reducer_shape.parameters(i))) {
      return InvalidArgument(
          "Reduction function's %d-th parameter shape differs from the "
          "result shape: %s vs %s",
          i, ShapeUtil::HumanString(reducer_shape.parameters(i)),
          ShapeUtil::HumanString(*accumulator_subshapes[i]));
    }
    // Check that init_value's shapes are suitable for reducer_shape.
    if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i],
                                                  *init_value_shapes[i])) {
      return InvalidArgument(
          "Reduction function's accumulator shape at index %d differs from "
          "the init_value shape: %s vs %s",
          i, ShapeUtil::HumanString(*accumulator_subshapes[i]),
          ShapeUtil::HumanString(*init_value_shapes[i]));
    }
    // Check that the inputs can be passed in as the non-accumulator arguments.
    const Shape input_element_shape =
        ShapeUtil::MakeShape(input_element_types[i], {});
    if (!ShapeUtil::CompatibleIgnoringFpPrecision(
            input_element_shape, reducer_shape.parameters(inputs + i))) {
      return InvalidArgument(
          "Reduction function's %d-th parameter shape differs from the "
          "corresponding input element type: %s vs %s",
          inputs + i,
          ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
          ShapeUtil::HumanString(input_element_shape));
    }
    // Check that the accumulator and inputs to the reducer function match.
    // If the accumulator is scalar, it must have the same type as the inputs
    // (up to fp precision). If it is a tuple, then the k-th element of the
    // tuple must have the same type as the k-th input (again, up to fp
    // precision).
    if (!ShapeUtil::CompatibleIgnoringFpPrecision(
            *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) {
      return InvalidArgument(
          "Reduction function's %d-th parameter shape must "
          "match the result shape, but got %s vs %s.",
          inputs + i,
          ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
          ShapeUtil::HumanString(*accumulator_subshapes[i]));
    }
  }

  return Status::OK();
}
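
// Illustrative example of the contract checked above: for a two-input reduce,
// the reducer must take 2 * inputs scalar parameters (accumulators first,
// then current elements) and return one scalar per input, e.g.
//   (f32[] acc0, s32[] acc1, f32[] x0, s32[] x1) -> (f32[], s32[])
// A non-tuple (scalar) result is accepted only when inputs == 1.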

StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
                                       const Window& window,
                                       PrimitiveType element_type,
                                       bool allow_negative_padding) {
  if (window.dimensions_size() != base_shape.rank()) {
    return InvalidArgument(
        "Window has dimension %d but base shape has dimension %d.",
        window.dimensions_size(), base_shape.rank());
  }

  std::vector<int64> output_dimensions(window.dimensions_size());
  std::vector<bool> output_is_dynamic(window.dimensions_size());
  for (int64_t i = 0; i < window.dimensions_size(); ++i) {
    const auto& dim = window.dimensions(i);
    if (dim.size() <= 0) {
      return InvalidArgument("Window %s has a non-positive dimension.",
                             window.DebugString());
    }
    if (dim.stride() <= 0) {
      return InvalidArgument("Window %s has a non-positive stride.",
                             window.DebugString());
    }
    if (!allow_negative_padding && dim.padding_low() < 0) {
      return InvalidArgument("Window %s has a negative low padding.",
                             window.DebugString());
    }
    if (!allow_negative_padding && dim.padding_high() < 0) {
      return InvalidArgument("Window %s has a negative high padding.",
                             window.DebugString());
    }
    if (dim.base_dilation() < 1) {
      return InvalidArgument(
          "Window %s has a non-positive base area dilation factor.",
          window.DebugString());
    }
    if (dim.window_dilation() < 1) {
      return InvalidArgument(
          "Window %s has a non-positive window dilation factor.",
          window.DebugString());
    }

    const int64_t dilated_base = window_util::DilatedBound(
        ShapeUtil::GetDimension(base_shape, i), dim.base_dilation());
    const int64_t padded_dilated_base =
        dim.padding_low() + dilated_base + dim.padding_high();
    const int64_t dilated_window =
        window_util::DilatedBound(dim.size(), dim.window_dilation());

    output_dimensions[i] = window_util::StridedBound(
        padded_dilated_base, dilated_window, dim.stride());
    output_is_dynamic[i] = base_shape.is_dynamic_dimension(i);
  }

  return ShapeUtil::MakeValidatedShape(element_type, output_dimensions,
                                       output_is_dynamic);
}
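
// Worked example (illustrative): a base dimension of size 5 with
// padding_low = 1, padding_high = 1, window size 3, stride 2, and no
// dilation gives
//   dilated_base = 5, padded_dilated_base = 1 + 5 + 1 = 7,
//   dilated_window = 3, output = StridedBound(7, 3, 2) = (7 - 3) / 2 + 1 = 3.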

StatusOr<PrimitiveType> MaybeUpcast(
    PrimitiveType from_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  if (!preferred_element_type.has_value() ||
      *preferred_element_type == from_type) {
    return from_type;
  }
  if (!primitive_util::IsFloatingPointType(from_type) &&
      primitive_util::BitWidth(*preferred_element_type) <
          primitive_util::BitWidth(from_type)) {
    return InvalidArgument(
        "`preferred_element_type` must not be narrower than the original "
        "type.");
  }
  return *preferred_element_type;
}
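
// Illustrative behavior: MaybeUpcast(BF16, F32) yields F32 and
// MaybeUpcast(F32, BF16) yields BF16 (floating-point types may be narrowed),
// while MaybeUpcast(S32, S16) fails because a non-floating-point type must
// not be narrowed.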

}  // namespace

/* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
    HloOpcode opcode, const HloInstruction* operand) {
  return InferUnaryOpShape(opcode, operand->shape());
}

/* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
    HloOpcode opcode, const Shape& shape) {
  // There is no copy operation at the proto level, so handle copy explicitly.
  // A domain shape is the same as the input one.
  if (opcode == HloOpcode::kCopy || opcode == HloOpcode::kDomain) {
    return shape;
  }

  TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation"));

  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
  switch (opcode) {
    case HloOpcode::kFloor:
    case HloOpcode::kCeil:
    case HloOpcode::kRoundNearestAfz:
      if (!ShapeUtil::ElementIsFloating(shape)) {
        return InvalidArgument(
            "Expected element type in shape to be floating for %s operation; "
            "got %s.",
            HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
      }
      return shape;
    case HloOpcode::kCos:
    case HloOpcode::kSin:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kLogistic:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kTanh:
      if (!ShapeUtil::ElementIsFloating(shape) &&
          !ShapeUtil::ElementIsComplex(shape)) {
        return InvalidArgument(
            "Expected element type in shape to be floating or complex for %s "
            "operation; got %s.",
            HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
      }
      return shape;
    case HloOpcode::kReal:
    case HloOpcode::kImag:
      if (ShapeUtil::ElementIsComplex(shape)) {
        return ShapeUtil::ComplexComponentShape(shape);
      } else if (ShapeUtil::ElementIsFloating(shape)) {
        return shape;
      } else {
        return InvalidArgument(
            "Expected element type in shape to be floating or complex for "
            "%s operation; got %s.",
            HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
      }
    case HloOpcode::kAbs:
      if (ShapeUtil::ElementIsComplex(shape)) {
        return ShapeUtil::ChangeElementType(
            shape, primitive_util::ComplexComponentType(shape.element_type()));
      } else if (ShapeUtil::ElementIsSigned(shape)) {
        return shape;
      } else {
        return InvalidArgument(
            "Expected element type in shape to be floating or complex for "
            "%s operation; got %s.",
            HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
      }
    case HloOpcode::kClz:
      if (!ShapeUtil::ElementIsIntegral(shape)) {
        return InvalidArgument(
            "Expected an integral element type in argument to Clz "
            "operation; got %s.",
            PrimitiveType_Name(shape.element_type()));
      }
      return shape;
    case HloOpcode::kNegate:
      if (!ShapeUtil::ElementIsIntegral(shape) &&
          !ShapeUtil::ElementIsFloating(shape) &&
          !ShapeUtil::ElementIsComplex(shape)) {
        return InvalidArgument(
            "Expected element type in shape to be integral, floating or "
            "complex for %s operation; got %s.",
            HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
      }
      return shape;
    case HloOpcode::kPopulationCount:
      if (!ShapeUtil::ElementIsIntegral(shape)) {
        return InvalidArgument(
            "Expected an integral element type in argument to PopulationCount "
            "operation; got %s.",
            PrimitiveType_Name(shape.element_type()));
      }
      return shape;
    case HloOpcode::kSign:
      if (!ShapeUtil::ElementIsSigned(shape) &&
          !ShapeUtil::ElementIsComplex(shape)) {
        return InvalidArgument(
            "Expected element type in shape to be signed or complex for "
            "%s operation; got %s.",
            HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
      }
      return shape;

    case HloOpcode::kNot:
      if (shape.element_type() != PRED &&
          !primitive_util::IsIntegralType(shape.element_type())) {
        return InvalidArgument(
            "Expected pred or an integral element type in argument to Not "
            "operation; got %s.",
            PrimitiveType_Name(shape.element_type()));
      }
      return shape;

    case HloOpcode::kIsFinite:
      if (!ShapeUtil::ElementIsFloating(shape)) {
        return InvalidArgument(
            "Expected element type in shape to be floating point for IsFinite "
            "operation; got %s.",
            PrimitiveType_Name(shape.element_type()));
      }
      return ShapeUtil::ChangeElementType(shape, PRED);

    default:
      return InvalidArgument(
          "Unknown operation for unary shape inference: \"%s\".",
          HloOpcodeString(opcode));
  }
}
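
// Illustrative examples of the unary inference above: most unary ops are
// shape-preserving, e.g. kNegate on s32[4] infers s32[4]; kAbs on c64[3]
// infers f32[3] (the complex component type); kIsFinite on f32[4] infers
// pred[4].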

/* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
    absl::Span<const Shape* const> arg_shapes, const int64_t dimension) {
  if (arg_shapes.empty()) {
    return InvalidArgument("Concatenate expects at least one argument.");
  }
  if (dimension < 0 || dimension >= arg_shapes[0]->rank()) {
    return InvalidArgument("Concatenate dimension out of bounds: %d.",
                           dimension);
  }
  const Shape* arg_shape = nullptr;
  PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
  for (const Shape* shape : arg_shapes) {
    TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation"));
    if (!arg_shape) {
      arg_shape = shape;
      element_type = arg_shape->element_type();
      continue;
    }
    if (arg_shape->rank() != shape->rank()) {
      return InvalidArgument(
          "Cannot concatenate arrays with different ranks: %d (%s) vs %d "
          "(%s).",
          arg_shape->rank(), ShapeUtil::HumanString(*arg_shape), shape->rank(),
          ShapeUtil::HumanString(*shape));
    }
    if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
      return InvalidArgument(
          "Cannot concatenate arrays with different element types: %s vs %s.",
          PrimitiveType_Name(arg_shape->element_type()),
          PrimitiveType_Name(shape->element_type()));
    }
    for (int64_t dimension_number = 0; dimension_number < arg_shape->rank();
         ++dimension_number) {
      if (arg_shape->dimensions(dimension_number) !=
          shape->dimensions(dimension_number)) {
        if (dimension_number == dimension) {
          continue;  // It's okay to differ in the dimension we're
                     // concatenating.
        }
        return InvalidArgument(
            "Cannot concatenate arrays that differ in dimensions other than "
            "the one being concatenated (the other array dimensions must be "
            "the same): %s vs %s in dimension %d.",
            ShapeUtil::HumanString(*arg_shape), ShapeUtil::HumanString(*shape),
            dimension);
      }
    }
    element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
  }

  std::vector<int64> new_dimensions(arg_shape->dimensions().begin(),
                                    arg_shape->dimensions().end());
  for (size_t i = 1; i < arg_shapes.size(); ++i) {
    new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
  }

  Shape result = ShapeUtil::MakeShape(element_type, new_dimensions);

  // Set dynamic dimensions if any input has a dynamic dimension.
  for (const Shape* shape : arg_shapes) {
    for (int64_t i = 0; i < shape->dimensions_size(); ++i) {
      if (shape->is_dynamic_dimension(i)) {
        result.set_dynamic_dimension(i, true);
      }
    }
  }
  return result;
}
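
// Illustrative example: concatenating f32[2,3] and f32[4,3] along dimension 0
// infers f32[6,3]; concatenating bf16[2,3] and f32[4,3] is also accepted and
// infers f32[6,3], since the higher-precision element type is chosen.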

/* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
    const Shape& operand_shape, PrimitiveType new_element_type) {
  auto old_element_type = operand_shape.element_type();
  if (primitive_util::IsComplexType(old_element_type) &&
      !primitive_util::IsComplexType(new_element_type)) {
    return Unimplemented(
        "Conversion from complex to real type %s => %s is not implemented.",
        ShapeUtil::HumanString(operand_shape),
        PrimitiveType_Name(new_element_type));
  }
  if (!operand_shape.IsArray() ||
      !primitive_util::IsArrayType(new_element_type)) {
    // Note: we may want to support tuple conversions via this operation in
    // the future, by recursing into the tuple elements to check all
    // sub-conversions are valid. For now we just reject them, though.
    return InvalidArgument(
        "Convert does not allow non-arrays, so cannot convert from %s to %s.",
        ShapeUtil::HumanString(operand_shape),
        PrimitiveType_Name(new_element_type));
  }

  return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
}

/* static */ StatusOr<Shape> ShapeInference::InferBitcastConvertShape(
    const Shape& operand_shape, PrimitiveType new_element_type) {
  auto old_element_type = operand_shape.element_type();
  if (primitive_util::IsComplexType(old_element_type) !=
      primitive_util::IsComplexType(new_element_type)) {
    return InvalidArgument(
        "Cannot convert between real and complex types: %s => %s.",
        ShapeUtil::HumanString(operand_shape),
        PrimitiveType_Name(new_element_type));
  }
  if (!operand_shape.IsArray() ||
      !primitive_util::IsArrayType(new_element_type)) {
    // Note: we may want to support tuple conversions via this operation in
    // the future, by recursing into the tuple elements to check all
    // sub-conversions are valid. For now we just reject them, though.
    return InvalidArgument(
        "Cannot convert from or to tuple type; requested conversion: %s => "
        "%s.",
        ShapeUtil::HumanString(operand_shape),
        PrimitiveType_Name(new_element_type));
  }
  if (primitive_util::BitWidth(old_element_type) !=
      primitive_util::BitWidth(new_element_type)) {
    return InvalidArgument(
        "Cannot bitcast types with different bit-widths: %s => %s.",
        PrimitiveType_Name(old_element_type),
        PrimitiveType_Name(new_element_type));
  }

  return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
}

/* static */ StatusOr<Shape> ShapeInference::InferReducePrecisionShape(
    const Shape& operand_shape, const int exponent_bits,
    const int mantissa_bits) {
  if (!ShapeUtil::ElementIsFloating(operand_shape)) {
    return InvalidArgument(
        "Expected element type in shape to be floating point for "
        "ReducePrecision operation; got %s.",
        PrimitiveType_Name(operand_shape.element_type()));
  }
  if (exponent_bits < 1) {
    // One exponent bit is necessary to distinguish 0 from infinity. Having
    // no exponent bits doesn't produce a sensible number, so we require at
    // least one.
    return InvalidArgument("Expected exponent_bits >= 1; got %d.",
                           exponent_bits);
  }
  if (mantissa_bits < 0) {
    // A number with no mantissa bits is still meaningful, however.
    return InvalidArgument("Expected non-negative mantissa_bits; got %d.",
                           mantissa_bits);
  }
  return operand_shape;
}

/* static */ StatusOr<Shape> ShapeInference::InferPadShape(
    const Shape& operand_shape, const Shape& padding_value_shape,
    const PaddingConfig& padding_config) {
  if (!operand_shape.IsArray()) {
    return InvalidArgument(
        "Pad operation does not support tuple-shape operands.");
  }
  if (!ShapeUtil::IsScalar(padding_value_shape)) {
    return InvalidArgument(
        "Pad operation does not support non-scalar padding values.");
  }
  if (operand_shape.rank() != padding_config.dimensions_size()) {
    return InvalidArgument(
        "The rank of the operand and the padding configuration do not match: "
        "%s vs %s.",
        ShapeUtil::HumanString(operand_shape),
        padding_config.ShortDebugString());
  }
  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
                                                     padding_value_shape)) {
    return InvalidArgument(
        "The element types of the operands to Pad do not match.");
  }
  if (absl::c_any_of(padding_config.dimensions(),
                     [](const PaddingConfig::PaddingConfigDimension& p) {
                       return p.interior_padding() < 0;
                     })) {
    return InvalidArgument("Interior padding cannot be negative: %s.",
                           padding_config.ShortDebugString());
  }

  if (!padding_value_shape.is_static()) {
    return InvalidArgument("Dynamic padding value is not supported.");
  }

  std::vector<int64> dimensions(operand_shape.rank());
  std::vector<bool> is_dynamic(operand_shape.rank());
  for (int64_t i = 0; i < operand_shape.dimensions_size(); ++i) {
    const auto& p = padding_config.dimensions(i);
    dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() +
                    p.edge_padding_high() +
                    std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
                        p.interior_padding();
    if (dimensions[i] < 0) {
      return InvalidArgument(
          "Padding results in a negative size for dimension %d.", i);
    }
    is_dynamic[i] = operand_shape.is_dynamic_dimension(i);
  }

  return ShapeUtil::MakeShape(
      ShapeUtil::HigherPrecisionElementType(operand_shape,
                                            padding_value_shape),
      dimensions, is_dynamic);
}
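
// Worked example (illustrative): a dimension of size 3 padded with
// edge_padding_low = 1, edge_padding_high = 2, and interior_padding = 1
// infers size 3 + 1 + 2 + (3 - 1) * 1 = 8.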

// Current DotDimensionNumbers Requirements:
//
// Contracting Dimensions:
// *) Same number of contracting dimensions on both lhs and rhs.
// *) Contracting dimension size must be the same on both lhs and rhs.
//
// Batch Dimensions:
// *) Same number of batch dimensions on both lhs and rhs.
// *) Same batch dimension sizes on both lhs and rhs.

namespace {

Status ValidateDotDimensionNumbers(
    const Shape& lhs, const Shape& rhs,
    const DotDimensionNumbers& dimension_numbers) {
  // Check that dimension numbers are in range.
  auto dims_in_range = [](const int64_t rank,
                          absl::Span<const int64> contracting_dims,
                          absl::Span<const int64> batch_dims) -> bool {
    auto in_range = [&rank](int64_t i) -> bool { return 0 <= i && i < rank; };
    return absl::c_all_of(contracting_dims, in_range) &&
           absl::c_all_of(batch_dims, in_range);
  };

  absl::Span<const int64> lhs_contracting_dimensions =
      AsInt64Slice(dimension_numbers.lhs_contracting_dimensions());
  absl::Span<const int64> rhs_contracting_dimensions =
      AsInt64Slice(dimension_numbers.rhs_contracting_dimensions());
  absl::Span<const int64> lhs_batch_dimensions =
      AsInt64Slice(dimension_numbers.lhs_batch_dimensions());
  absl::Span<const int64> rhs_batch_dimensions =
      AsInt64Slice(dimension_numbers.rhs_batch_dimensions());

  if (!dims_in_range(lhs.rank(), lhs_contracting_dimensions,
                     lhs_batch_dimensions) ||
      !dims_in_range(rhs.rank(), rhs_contracting_dimensions,
                     rhs_batch_dimensions)) {
    return InvalidArgument("A dimension number is out of range in Dot: %s.",
                           dimension_numbers.DebugString());
  }

  // Check that dimension numbers are unique.
  auto dims_unique = [](absl::Span<const int64> contracting_dims,
                        absl::Span<const int64> batch_dims) -> bool {
    absl::flat_hash_set<int64> dim_set;
    auto is_unique = [&dim_set](int64_t i) -> bool {
      return dim_set.insert(i).second;
    };
    return absl::c_all_of(contracting_dims, is_unique) &&
           absl::c_all_of(batch_dims, is_unique);
  };

  if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
      !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) {
    return InvalidArgument("A dimension number is not unique in Dot: %s.",
                           dimension_numbers.DebugString());
  }

  return Status::OK();
}
634
635 } // namespace
636
InferDotOpShape(const Shape & lhs,const Shape & rhs,const DotDimensionNumbers & dimension_numbers,absl::optional<PrimitiveType> preferred_element_type)637 /* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
638 const Shape& lhs, const Shape& rhs,
639 const DotDimensionNumbers& dimension_numbers,
640 absl::optional<PrimitiveType> preferred_element_type) {
641 TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot"));
642 TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
643
644 auto fail = [lhs, rhs](const string& addendum) -> Status {
645 string message =
646 StrFormat("Cannot infer shape for dot operation: %s <dot> %s.",
647 ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
648 if (!addendum.empty()) {
649 message += " " + addendum;
650 }
651 return InvalidArgument("%s", message);
652 };
653
654 // Validate basic properties of dot dimension numbers.
655 TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers));
656
657 // Check that number of contracting dimensions match.
658 if (dimension_numbers.lhs_contracting_dimensions_size() !=
659 dimension_numbers.rhs_contracting_dimensions_size()) {
660 return fail(
661 "Must specify the same number of contracting dimensions for lhs and "
662 "rhs.");
663 }
664 // Check that contracting dimension sizes match.
665 for (int64_t i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
666 ++i) {
667 const int64_t lhs_contracting_dimension =
668 dimension_numbers.lhs_contracting_dimensions(i);
669 const int64_t rhs_contracting_dimension =
670 dimension_numbers.rhs_contracting_dimensions(i);
671 if (lhs.dimensions(lhs_contracting_dimension) !=
672 rhs.dimensions(rhs_contracting_dimension)) {
673 return fail("Contracting dimension sizes do not match.");
674 }
675 }
676
677 // Check that number of batch dimensions match.
678 if (dimension_numbers.lhs_batch_dimensions_size() !=
679 dimension_numbers.rhs_batch_dimensions_size()) {
680 return fail("Must the same number of batch dimensions for lhs and rhs.");
681 }
682
683 // Check that batch dimension numbers and sizes match.
684 for (int64_t i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
685 if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
686 rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
687 return fail("Batch dimension sizes must match for lhs/rhs.");
688 }
689 }
690
691 // The ranks of lhs and rhs are decremented by 1 respectively due to the
692 // contraction, and added for the rank of the result. When an input tensor is
693 // a scalar, its contribution to the rank of the result is 0.
694 // Generate the result dimensions in order, rhs dimensions followed by lhs
695 // dimensions except the contracted and batch dimensions.
696 std::vector<int64> dimensions;
697 std::vector<bool> is_dynamic;
698 for (int64_t lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
699 dimensions.push_back(lhs.dimensions(lhs_dim));
700 is_dynamic.push_back(lhs.is_dynamic_dimension(lhs_dim));
701 }
702 for (int64_t i = 0; i < lhs.rank(); i++) {
703 if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
704 i) &&
705 !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
706 dimensions.push_back(lhs.dimensions(i));
707 is_dynamic.push_back(lhs.is_dynamic_dimension(i));
708 }
709 }
710 for (int64_t i = 0; i < rhs.rank(); i++) {
711 if (!absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
712 i) &&
713 !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
714 dimensions.push_back(rhs.dimensions(i));
715 is_dynamic.push_back(rhs.is_dynamic_dimension(i));
716 }
717 }
718 TF_ASSIGN_OR_RETURN(
719 PrimitiveType type,
720 MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
721 preferred_element_type));
722 Shape result = ShapeUtil::MakeShape(type, dimensions, is_dynamic);
723
724 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
725 VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
726 return result;
727 }
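
// Illustrative example: lhs f32[5,2,3] <dot> rhs f32[5,3,4] with batch
// dimensions {0}/{0} and contracting dimensions {2}/{1} infers f32[5,2,4]:
// batch dimensions first, then the remaining lhs dimensions, then the
// remaining rhs dimensions.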

/* static */ StatusOr<Shape>
ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
                                                       const Shape& lhs,
                                                       const Shape& rhs) {
  TF_RET_CHECK(lhs.rank() == rhs.rank());

  // The shapes have to be compatible. That is, if some dimension d has a
  // different size in the two shapes, one of them has to be 1 (a "degenerate"
  // dimension). In that case, the output shape has the non-1 dimension size
  // from the lhs/rhs pair in every index.
  std::vector<int64> output_dimensions(lhs.rank());
  std::vector<bool> output_dimensions_is_dynamic(lhs.rank());
  for (int64_t i = 0; i < lhs.rank(); ++i) {
    if (lhs.dimensions(i) == rhs.dimensions(i)) {
      output_dimensions[i] = lhs.dimensions(i);
    } else if (lhs.dimensions(i) == 1) {
      output_dimensions[i] = rhs.dimensions(i);
    } else if (rhs.dimensions(i) == 1) {
      output_dimensions[i] = lhs.dimensions(i);
    } else {
      return InvalidArgument(
          "Binary op %s with incompatible shapes: %s and %s.",
          HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
          ShapeUtil::HumanString(rhs));
    }
  }

  // Merge dynamic dimensions from the two shapes.
  for (int64_t i = 0; i < rhs.rank(); ++i) {
    if (rhs.is_dynamic_dimension(i) || lhs.is_dynamic_dimension(i)) {
      output_dimensions_is_dynamic[i] = true;
    }
  }

  return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
                              output_dimensions, output_dimensions_is_dynamic);
}
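
// Illustrative example: f32[2,1] <op> f32[2,3] infers f32[2,3], while
// f32[2,2] <op> f32[2,3] is rejected because neither mismatched dimension
// is 1.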

/* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
    const Shape& smaller_shape, const Shape& larger_shape,
    absl::Span<const int64> broadcast_dimensions) {
  if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
    // Reject "magic" inference for binops on different shapes, requiring
    // the user to provide an explicit broadcast dimension in this case.
    // See b/25177275 for more details.
    return InvalidArgument(
        "Automatic shape inference not supported: %s and %s.",
        ShapeUtil::HumanString(smaller_shape),
        ShapeUtil::HumanString(larger_shape));
  } else if (broadcast_dimensions.size() != smaller_shape.rank()) {
    return InvalidArgument(
        "Size of broadcast_dimensions has to match lower-rank operand's "
        "rank; lower-rank operand's rank is %d, size of broadcast_dimensions "
        "is %u.",
        smaller_shape.rank(), broadcast_dimensions.size());
  }

  // broadcast_dimensions is a sequence of dimensions; its length is equal to
  // the rank of the lower-rank operand. The lower-rank operand's dimensions
  // have to be compatible with the higher-rank operand's dimensions at indices
  // specified by broadcast_dimensions. Here compatible means the dimension
  // sizes are equal or in one of the shapes the dimension size is
  // one. Examples:
  //
  // smaller_shape   larger_shape   broadcast_dimensions   output_shape
  //   []              [2, 3]         {}                     [2, 3]
  //   [3]             [4, 3]         {1}                    [4, 3]
  //   [2, 3]          [2, 3, 4]      {0, 1}                 [2, 3, 4]
  //   [2, 1]          [2, 3, 4]      {0, 2}                 [2, 3, 1]
  //   [2, 3]          [2, 1, 4]      {0, 1}                 [2, 3, 4]
  //
  // The column output_shape may not be the final shape of the XLA
  // operation. After the "InDim" broadcasting implemented in this function
  // expands the rank, degenerate-dimension broadcasting (implemented in
  // InferDegenerateDimensionBroadcastShape) broadcasts dimensions of size one
  // up to match the dimension size of the other operand. For example, consider
  // the row in the table above with a smaller_shape of [2, 1]. The shape
  // returned by this function is [2, 3, 1] (output_shape); however, the result
  // shape of the XLA operation is [2, 3, 4] after degenerate-dimension
  // broadcasting.
  //
  // Invalid broadcasts:
  //
  // smaller_shape=[3], larger_shape=[4, 3], broadcast_dimensions={0}
  // Reason: Dimension zero** of larger_shape (size 4) is not compatible with
  // dimension zero of smaller_shape (size 3). **Zero here comes from the
  // value in broadcast_dimensions.
  //
  // smaller_shape=[2, 1], larger_shape=[2, 3, 4], broadcast_dimensions={1, 2}
  // Reason: Dimension one of larger_shape (size 3) is not compatible with
  // dimension zero of smaller_shape (size 2).

  // The output shape is initially the larger_shape. Sizes of dimensions
  // specified in broadcast_dimensions are then changed to match the
  // corresponding dimension size in smaller_shape.
  Shape output_shape(larger_shape);
  output_shape.set_element_type(
      ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape));

  for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
    int64_t dimension_to_match = broadcast_dimensions.at(i);
    if (dimension_to_match < 0) {
      return InvalidArgument(
          "Broadcast dimension number (%d) cannot be negative.",
          dimension_to_match);
    }
    if (dimension_to_match >= larger_shape.dimensions_size()) {
      return InvalidArgument(
          "Broadcast dimension number (%d) too large; higher-rank "
          "operand has rank %d.",
          dimension_to_match, larger_shape.dimensions_size());
    }
    int64_t small_dimension_size = smaller_shape.dimensions(i);
    int64_t large_dimension_size = larger_shape.dimensions(dimension_to_match);
    bool small_is_dynamic = smaller_shape.is_dynamic_dimension(i);
    bool large_is_dynamic =
        larger_shape.is_dynamic_dimension(dimension_to_match);
    // Dimension sizes must be compatible: match or be degenerate (the
    // degenerate case is handled by degenerate dimension broadcasting which
    // occurs after InDim broadcasting).
    if (small_dimension_size != large_dimension_size &&
        small_dimension_size != 1 && large_dimension_size != 1) {
      return InvalidArgument(
          "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i,
          small_dimension_size, large_dimension_size,
          ShapeUtil::HumanString(smaller_shape),
          ShapeUtil::HumanString(larger_shape));
    }
    if (small_is_dynamic != large_is_dynamic) {
      if (small_dimension_size == large_dimension_size ||
          (small_dimension_size == 1 && !small_is_dynamic) ||
          (large_dimension_size == 1 && !large_is_dynamic)) {
        // Do nothing. It's OK when the size-1 dimension is not static.
      } else {
        return InvalidArgument(
            "Broadcast dimension %d dynamism mismatch: %s and %s.", i,
            ShapeUtil::HumanString(smaller_shape),
            ShapeUtil::HumanString(larger_shape));
      }
    }
    // Make sure the broadcast dimensions are listed in a strictly increasing
    // order.
    if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) {
      return InvalidArgument(
          "Broadcast dimensions order is wrong: %d comes after %d.",
          dimension_to_match, broadcast_dimensions.at(i - 1));
    }

    output_shape.set_dimensions(dimension_to_match, small_dimension_size);
    output_shape.set_dynamic_dimension(dimension_to_match, small_is_dynamic);
  }

  return output_shape;
}

/* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
    HloOpcode operation, const Shape& lhs, const Shape& rhs,
    absl::Span<const int64> broadcast_dimensions) {
  TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
  TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));

  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
    return InvalidArgument(
        "Binary op %s with different element types: %s and %s.",
        HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
        ShapeUtil::HumanString(rhs));
  }

  if (lhs.rank() == rhs.rank()) {
    std::vector<int64> identity_dims(lhs.rank());
    std::iota(identity_dims.begin(), identity_dims.end(), 0);
    if (!broadcast_dimensions.empty() &&
        broadcast_dimensions != identity_dims) {
      return InvalidArgument(
          "Broadcast dimensions field must either be not set or be the "
          "identity on binary operations with operands of the same rank.");
    }
  }

  if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
    // If the shapes are the same other than layout, the output shape is the
    // same (elementwise op).
    Shape result = ShapeUtil::ChangeElementType(
        lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));

    for (int64_t i = 0; i < rhs.rank(); ++i) {
      if (rhs.is_dynamic_dimension(i)) {
        result.set_dynamic_dimension(i, true);
      }
    }

    return result;

  } else if (lhs.rank() == rhs.rank()) {
    return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
  } else {
    // Ranks do not match, so perform InDim broadcasting using
    // broadcast_dimensions. Scalar broadcasting is a special case of this.
    const Shape& larger_shape = lhs.rank() > rhs.rank() ? lhs : rhs;
    const Shape& smaller_shape = lhs.rank() > rhs.rank() ? rhs : lhs;

    // After InDim broadcasting, perform degenerate dimensions broadcasting.
    TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape,
                        InferInDimBroadcastShape(smaller_shape, larger_shape,
                                                 broadcast_dimensions));

    return InferDegenerateDimensionBroadcastShape(
        operation, indim_broadcast_shape, larger_shape);
  }
}

/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
    HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) {
  return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(),
                            /*broadcast_dimensions=*/{});
}

/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
    HloOpcode opcode, const Shape& lhs, const Shape& rhs,
    absl::Span<const int64> broadcast_dimensions) {
  VLOG(2) << StrFormat(
      "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
      HloOpcodeString(opcode), ShapeUtil::HumanStringWithLayout(lhs),
      ShapeUtil::HumanStringWithLayout(rhs),
      StrJoin(broadcast_dimensions, ", "));

  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));

  TF_RETURN_IF_ERROR(ExpectArray(
      lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode))));
  TF_RETURN_IF_ERROR(ExpectArray(
      rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode))));
  switch (opcode) {
    case HloOpcode::kAdd:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
      return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
                                           broadcast_dimensions);

    case HloOpcode::kSubtract:
    case HloOpcode::kAtan2:
    case HloOpcode::kPower:
    case HloOpcode::kDivide:
    case HloOpcode::kRemainder:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      if (lhs.element_type() == PRED || rhs.element_type() == PRED) {
        return InvalidArgument(
            "Expected element type in shape to be arithmetic type for "
            "operation %s; got PRED.",
            HloOpcodeString(opcode));
      }
      return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
                                           broadcast_dimensions);

    case HloOpcode::kComplex: {
      if (!ShapeUtil::ElementIsFloating(lhs)) {
        return InvalidArgument(
            "Expected element type in shape to be floating for complex "
            "compose operation; got %s.",
            PrimitiveType_Name(lhs.element_type()));
      }
      TF_ASSIGN_OR_RETURN(const Shape& shape,
                          InferElementwiseBinaryOpShape(opcode, lhs, rhs,
                                                        broadcast_dimensions));
      if (lhs.element_type() == F32 && rhs.element_type() == F32) {
        return ShapeUtil::ChangeElementType(shape, C64);
      } else if (lhs.element_type() == F64 && rhs.element_type() == F64) {
        return ShapeUtil::ChangeElementType(shape, C128);
      } else {
        return Unimplemented("Complex component type is not implemented.");
      }
    }
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
      if (lhs.element_type() != PRED &&
          !primitive_util::IsIntegralType(lhs.element_type())) {
        return InvalidArgument(
            "Expected pred or integral type in argument to and/or operation; "
            "got %s.",
            PrimitiveType_Name(lhs.element_type()));
      }
      return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
                                           broadcast_dimensions);
    case HloOpcode::kCompare: {
      TF_ASSIGN_OR_RETURN(const Shape& shape,
                          InferElementwiseBinaryOpShape(opcode, lhs, rhs,
                                                        broadcast_dimensions));
      return ShapeUtil::ChangeElementType(shape, PRED);
    }
    default:
      return Unimplemented(
          "Binary op shape inference: %s; lhs: %s; rhs: %s is not "
          "implemented.",
          HloOpcodeString(opcode), lhs.ShortDebugString(),
          rhs.ShortDebugString());
  }
}

/* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
    HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs,
    const HloInstruction* ehs) {
  return InferTernaryOpShape(opcode, lhs->shape(), rhs->shape(), ehs->shape());
}

/* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
    HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) {
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs));
  switch (opcode) {
    case HloOpcode::kClamp:
      return InferClampShape(lhs, rhs, ehs);
    case HloOpcode::kSelect:
      return InferSelectShape(lhs, rhs, ehs);
    case HloOpcode::kTupleSelect:
      return InferTupleSelectShape(lhs, rhs, ehs);
    default:
      return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
  }
}

/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
    HloOpcode opcode, absl::Span<const HloInstruction* const> operands) {
  std::vector<const Shape*> operand_shapes;
  operand_shapes.reserve(operands.size());
  for (const HloInstruction* operand : operands) {
    operand_shapes.push_back(&operand->shape());
  }
  return InferVariadicOpShape(opcode, operand_shapes);
}

/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
    HloOpcode opcode, absl::Span<const Shape* const> operand_shapes) {
  for (const Shape* shape : operand_shapes) {
    TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
  }
  switch (opcode) {
    case HloOpcode::kTuple: {
      Shape result = ShapeUtil::MakeTupleShape({});
      result.mutable_tuple_shapes()->reserve(operand_shapes.size());
      for (const Shape* shape : operand_shapes) {
        ShapeUtil::AppendShapeToTuple(*shape, &result);
      }
      return result;
    }
    case HloOpcode::kSort: {
      if (operand_shapes.size() == 1) {
        return *operand_shapes[0];
      } else {
        for (int64_t operand = 1; operand < operand_shapes.size(); ++operand) {
          if (!ShapeUtil::SameDimensions(*operand_shapes[0],
                                         *operand_shapes[operand])) {
            return InvalidArgument(
                "Sort keys and values dimensions must match. "
                "Keys shape is: %s\n, Values shape (operand index %lld) is: %s",
                ShapeUtil::HumanString(*operand_shapes[0]), operand,
                ShapeUtil::HumanString(*operand_shapes[operand]));
          }
        }
        std::vector<Shape> operand_shape_values;
        for (const Shape* operand_shape : operand_shapes) {
          operand_shape_values.push_back(*operand_shape);
        }
        return ShapeUtil::MakeTupleShape(operand_shape_values);
      }
      return InvalidArgument("Unexpected number of operands for sort");
    }
    default:
      return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
  }
}

/* static */ StatusOr<Shape> ShapeInference::InferMapShape(
    absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
    absl::Span<const int64> dimensions) {
  if (arg_shapes.empty()) {
    return InvalidArgument("Map expects at least one argument.");
  }

  // All arguments must have the same shape.
  const Shape* arg_shape = arg_shapes[0];
  for (size_t i = 1; i < arg_shapes.size(); ++i) {
    TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map"));

    if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
      continue;
    }
    if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
                                                      *arg_shape)) {
      if (ShapeUtil::IsScalar(*arg_shapes[i])) {
        continue;
      }
      if (ShapeUtil::IsScalar(*arg_shape)) {
        arg_shape = arg_shapes[i];
        continue;
      }
    }

    std::vector<string> pieces;
    for (const Shape* shape : arg_shapes) {
      pieces.push_back(ShapeUtil::HumanString(*shape));
    }
    return InvalidArgument(
        "Map operation requires all operands to have the same shape; got: "
        "%s.",
        StrJoin(pieces, ", "));
  }

  // Check that dimensions.size == arg_shape.dimensions_size() (we currently
  // only support mapping across all dimensions: i.e. scalar map functions).
  if (dimensions.size() != arg_shape->dimensions_size()) {
    return InvalidArgument(
        "Map applied to a subset of dimensions currently not supported: "
        "arg_dimension_size: %d, requested_map_dimensions_size: %u.",
        arg_shape->dimensions_size(), dimensions.size());
  }

  // Check that requested map dimension numbers are monotonically increasing.
  for (int i = 0; i < dimensions.size(); ++i) {
    if (dimensions[i] != i) {
      return InvalidArgument(
          "Map requires monotonically increasing dimension numbers; got: %s.",
          StrJoin(dimensions, ", "));
    }
  }

  // The applied function's arity equals the number of arguments.
  if (arg_shapes.size() != to_apply.parameters_size()) {
    return InvalidArgument(
        "Map applied function arity must match number of arguments; got: "
        "arity: %d, arguments: %u.",
        to_apply.parameters_size(), arg_shapes.size());
  }

  // The parameters should all be scalars, and the output too.
  const Shape& output_shape = to_apply.result();
  if (!ShapeUtil::IsScalar(output_shape)) {
    return InvalidArgument(
        "Mapped computation's result has to be a scalar; got: %s.",
        ShapeUtil::HumanString(output_shape));
  }

  for (int i = 0; i < to_apply.parameters_size(); ++i) {
    const Shape& parameter_shape = to_apply.parameters(i);

    if (!ShapeUtil::IsScalar(parameter_shape)) {
      return InvalidArgument(
          "Mapped computation's parameter has to be a scalar; "
          "got parameter %d shape: %s.",
          i, ShapeUtil::HumanString(parameter_shape));
    }

    if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
                                                       *arg_shape)) {
      return InvalidArgument(
          "Mapped computation's parameter type has to match argument element "
          "type; got parameter %d shape: %s, argument shape: %s.",
          i, ShapeUtil::HumanString(parameter_shape),
          ShapeUtil::HumanString(*arg_shape));
    }
  }

  return ShapeUtil::MakeShape(output_shape.element_type(),
                              AsInt64Slice(arg_shape->dimensions()));
}
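
// Illustrative example: mapping a scalar computation (f32[], f32[]) -> f32[]
// over two f32[2,3] operands with dimensions {0, 1} infers f32[2,3].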

/* static */ StatusOr<Shape> ShapeInference::InferBatchNormTrainingShape(
    const Shape& operand_shape, const Shape& scale_shape,
    const Shape& offset_shape, int64_t feature_index) {
  TF_RETURN_IF_ERROR(
      ExpectArray(operand_shape, "operand of batch norm training"));
  TF_RETURN_IF_ERROR(
      ExpectArray(offset_shape, "offset input of batch norm training"));
  TF_RETURN_IF_ERROR(
      ExpectArray(scale_shape, "scale input of batch norm training"));

  TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
               Status::OK());
  TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) ==
               Status::OK());
  TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
               Status::OK());

  if (feature_index >= operand_shape.rank()) {
    return InvalidArgument(
        "Expected feature_index of batch-norm-training to be "
        "smaller than the rank of operand_shape; "
        "got feature_index %d, and rank %d.",
        feature_index, operand_shape.rank());
  }

  if (feature_index < 0) {
    return InvalidArgument(
        "Expected feature_index of batch-norm-training to "
        "be a non-negative number, got %d.",
        feature_index);
  }

  if (operand_shape.rank() < 1) {
    return InvalidArgument(
        "Expected the rank of operand to "
        "batch-norm-training to be at least 1; got %d.",
        operand_shape.rank());
  }

  if (offset_shape.rank() != 1) {
    return InvalidArgument(
        "Offset input of batch-norm-training must have"
        " rank 1, but has rank %d.",
        offset_shape.rank());
  }

  if (scale_shape.rank() != 1) {
    return InvalidArgument(
        "Scale input of batch-norm-training must have"
        " rank 1, but has rank %d.",
        scale_shape.rank());
  }

  if (!ShapeUtil::ElementIsFloating(operand_shape)) {
    return InvalidArgument(
        "The operand to batch-norm-training must have a floating point "
        "element type, but the shape is %s.",
        PrimitiveType_Name(operand_shape.element_type()));
  }

  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
                                                     operand_shape)) {
    return InvalidArgument(
        "The inputs should have the same element type for "
        "batch-norm-training, but the shape of offset factor is %s "
        "and the shape of operand is %s.",
        PrimitiveType_Name(offset_shape.element_type()),
        PrimitiveType_Name(operand_shape.element_type()));
  }

  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
                                                     operand_shape)) {
    return InvalidArgument(
        "The inputs should have the same element type for "
        "batch-norm-training, but the shape of scale factor is %s "
        "and the shape of operand is %s.",
        PrimitiveType_Name(scale_shape.element_type()),
        PrimitiveType_Name(operand_shape.element_type()));
  }

  const int64_t feature_count = operand_shape.dimensions(feature_index);
  Shape output_shape_for_mean_and_var =
      ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});

  if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
    return InvalidArgument(
        "The size of offset factor should be the same as feature count, "
        "but the size of offset factor is %d "
        "and the feature count is %d.",
        ShapeUtil::GetDimension(offset_shape, 0), feature_count);
  }

  if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
    return InvalidArgument(
        "The size of scale factor should be the same as feature count, "
        "but the size of scale factor is %d "
        "and the feature count is %d.",
        ShapeUtil::GetDimension(scale_shape, 0), feature_count);
  }

  return ShapeUtil::MakeTupleShape({operand_shape,
                                    output_shape_for_mean_and_var,
                                    output_shape_for_mean_and_var});
}
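
// Illustrative example: operand f32[10,5,7] with feature_index = 1 and
// scale/offset f32[5] infers the tuple (f32[10,5,7], f32[5], f32[5]).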
1303
InferBatchNormInferenceShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & offset_shape,const Shape & mean_shape,const Shape & variance_shape,int64_t feature_index)1304 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
1305 const Shape& operand_shape, const Shape& scale_shape,
1306 const Shape& offset_shape, const Shape& mean_shape,
1307 const Shape& variance_shape, int64_t feature_index) {
1308 TF_RETURN_IF_ERROR(
1309 ExpectArray(operand_shape, "operand of batch norm inference"));
1310 TF_RETURN_IF_ERROR(
1311 ExpectArray(offset_shape, "offset input of batch norm inference"));
1312 TF_RETURN_IF_ERROR(
1313 ExpectArray(scale_shape, "scale input of batch norm inference"));
1314
1315 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1316 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape));
1317 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1318 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1319 TF_RETURN_IF_ERROR(
1320 ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape));
1321
1322 if (feature_index >= operand_shape.rank()) {
1323 return InvalidArgument(
1324 "Expected feature_index of batch-norm-inference to be "
1325 "smaller than the rank of operand_shape; "
1326 "got feature_index %d, and rank %d.",
1327 feature_index, operand_shape.rank());
1328 }
1329
1330 if (feature_index < 0) {
1331 return InvalidArgument(
1332 "Expected feature_index of batch-norm-inference to "
1333 "be a non-negative number, got %d.",
1334 feature_index);
1335 }
1336
1337 if (operand_shape.rank() < 1) {
1338 return InvalidArgument(
1339 "Expected the rank of operand to "
1340 "batch-norm-inference to be at least 1; got %d.",
1341 operand_shape.rank());
1342 }
1343
1344 if (offset_shape.rank() != 1) {
1345 return InvalidArgument(
1346 "Offset input of batch-norm-inference must have"
1347 " rank 1, but has rank %d.",
1348 offset_shape.rank());
1349 }
1350
1351 if (scale_shape.rank() != 1) {
1352 return InvalidArgument(
1353 "Scale input of batch-norm-inference must have"
1354 " rank 1, but has rank %d.",
1355 scale_shape.rank());
1356 }
1357
1358 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1359 return InvalidArgument(
1360 "The operand to batch-norm-inference must have a floating point "
1361 "element type, but the shape is %s.",
1362 PrimitiveType_Name(operand_shape.element_type()));
1363 }
1364
1365 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1366 operand_shape)) {
1367 return InvalidArgument(
1368 "The inputs should have the same element type for "
1369 "batch-norm-inference, "
1370 "but the shape of offset factor is %s "
1371 "and the shape of operand is %s.",
1372 PrimitiveType_Name(offset_shape.element_type()),
1373 PrimitiveType_Name(operand_shape.element_type()));
1374 }
1375
1376 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1377 operand_shape)) {
1378 return InvalidArgument(
1379 "The inputs should have the same element type for "
1380 "batch-norm-inference, "
1381 "but the shape of scale factor is %s "
1382 "and the shape of operand is %s.",
1383 PrimitiveType_Name(scale_shape.element_type()),
1384 PrimitiveType_Name(operand_shape.element_type()));
1385 }
1386
1387 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1388 operand_shape)) {
1389 return InvalidArgument(
1390 "The inputs should have the same element type for "
1391 "batch-norm-inference, "
1392 "but the shape of mean is %s "
1393 "and the shape of operand is %s.",
1394 PrimitiveType_Name(mean_shape.element_type()),
1395 PrimitiveType_Name(operand_shape.element_type()));
1396 }
1397
1398 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
1399 operand_shape)) {
1400 return InvalidArgument(
1401 "The inputs should have the same element type for "
1402 "batch-norm-inference, "
1403 "but the shape of variance is %s "
1404 "and the shape of operand is %s.",
1405 PrimitiveType_Name(mean_shape.element_type()),
1406 PrimitiveType_Name(variance_shape.element_type()));
1407 }
1408
1409 const int64_t feature_count = operand_shape.dimensions(feature_index);
1410 Shape output_shape_for_mean_and_var =
1411 ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1412
1413 if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1414 return InvalidArgument(
1415 "The size of offset factor should be the same as feature count,"
1416 "but the size of offset factor is %d "
1417 "and the feature count is %d.",
1418 ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1419 }
1420
1421 if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1422 return InvalidArgument(
1423 "The size of scale factor should be the same as feature count,"
1424 "but the size of scale factor is %d "
1425 "and the feature count is %d.",
1426 ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1427 }
1428
1429 if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1430 return InvalidArgument(
1431 "The size of mean should be the same as feature count,"
1432 "but the size of mean is %d "
1433 "and the feature count is %d.",
1434 ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1435 }
1436
1437 if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
1438 return InvalidArgument(
1439 "The size of variance should be the same as feature count,"
1440 "but the size of variance is %d "
1441 "and the feature count is %d.",
1442 ShapeUtil::GetDimension(variance_shape, 0), feature_count);
1443 }
1444
1445 return operand_shape;
1446 }
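// Illustrative example (added for clarity; the shapes are hypothetical): for
// batch-norm-inference with operand f32[8,28,28,64] and feature_index 3,
// feature_count is 64, so scale, offset, mean, and variance must each be
// f32[64], and the inferred result shape equals the operand shape.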
1447
1448 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
1449 const Shape& operand_shape, const Shape& scale_shape,
1450 const Shape& mean_shape, const Shape& var_shape,
1451 const Shape& output_grad_shape, int64_t feature_index) {
1452 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad"));
1453 TF_RETURN_IF_ERROR(
1454 ExpectArray(scale_shape, "scale input of batch norm grad"));
1455 TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad"));
1456 TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad"));
1457 TF_RETURN_IF_ERROR(
1458 ExpectArray(output_grad_shape, "output_grad input of batch norm grad"));
1459
1460 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1461 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1462 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1463 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(var_shape));
1464 TF_RETURN_IF_ERROR(
1465 ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape));
1466
1467 if (feature_index >= operand_shape.rank()) {
1468 return InvalidArgument(
1469 "Expected feature_index of batch-norm-grad to be "
1470 "smaller than the rank of operand_shape; "
1471 "got feature_index %d, and rank %d.",
1472 feature_index, operand_shape.rank());
1473 }
1474
1475 if (operand_shape.rank() != output_grad_shape.rank()) {
1476 return InvalidArgument(
1477 "Expected operand_shape of batch-norm-grad to have the same rank as"
1478 " output_grad_shape; got rank(oprand_shape) %d, and"
1479 " rank(output_grad_shape) %d.",
1480 operand_shape.rank(), output_grad_shape.rank());
1481 }
1482
1483 if (mean_shape.rank() != 1) {
1484 return InvalidArgument(
1485 "Mean input of batch-norm-grad must have"
1486 " rank 1, but has rank %d.",
1487 mean_shape.rank());
1488 }
1489
1490 if (scale_shape.rank() != 1) {
1491 return InvalidArgument(
1492 "Scale input of batch-norm-grad must have"
1493 " rank 1, but has rank %d.",
1494 scale_shape.rank());
1495 }
1496
1497 if (var_shape.rank() != 1) {
1498 return InvalidArgument(
1499 "Var input of batch-norm-grad must have"
1500 " rank 1, but has rank %d.",
1501 var_shape.rank());
1502 }
1503
1504 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1505 return InvalidArgument(
1506 "The operand to batch-norm-grad must have a floating point "
1507 "element type, but the shape is %s.",
1508 PrimitiveType_Name(operand_shape.element_type()));
1509 }
1510
1511 if (!ShapeUtil::ElementIsFloating(output_grad_shape)) {
1512 return InvalidArgument(
1513 "The output_grad to batch-norm-grad must have a floating point "
1514 "element type, but the shape is %s.",
1515 PrimitiveType_Name(output_grad_shape.element_type()));
1516 }
1517
1518 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
1519 operand_shape)) {
1520 return InvalidArgument(
1521 "The inputs should have the same element type for batch-norm-grad, "
1522 "but the element type of output_grad is %s "
1523 "and the element type of operand is %s.",
1524 PrimitiveType_Name(output_grad_shape.element_type()),
1525 PrimitiveType_Name(operand_shape.element_type()));
1526 }
1527
1528 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1529 operand_shape)) {
1530 return InvalidArgument(
1531 "The inputs should have the same element type for batch-norm-grad, "
1532 "but the element type of scale factor is %s "
1533 "and the element type of operand is %s.",
1534 PrimitiveType_Name(scale_shape.element_type()),
1535 PrimitiveType_Name(operand_shape.element_type()));
1536 }
1537
1538 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1539 operand_shape)) {
1540 return InvalidArgument(
1541 "The inputs should have the same element type for batch-norm-grad, "
1542 "but the element type of mean is %s "
1543 "and the element type of operand is %s.",
1544 PrimitiveType_Name(mean_shape.element_type()),
1545 PrimitiveType_Name(operand_shape.element_type()));
1546 }
1547
1548 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
1549 operand_shape)) {
1550 return InvalidArgument(
1551 "The inputs should have the same element type for batch-norm-grad, "
1552 "but the element type of mean is %s "
1553 "and the element type of operand is %s.",
1554 PrimitiveType_Name(mean_shape.element_type()),
1555 PrimitiveType_Name(operand_shape.element_type()));
1556 }
1557
1558 const int64_t feature_count = operand_shape.dimensions(feature_index);
1559
1560 Shape feature_shape =
1561 ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1562
1563 if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1564 return InvalidArgument(
1565 "The size of mean should be the same as feature count,"
1566 "but the size of offset factor is %d "
1567 "and the feature count is %d.",
1568 ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1569 }
1570
1571 if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1572 return InvalidArgument(
1573 "The size of scale factor should be the same as feature count,"
1574 "but the size of scale factor is %d "
1575 "and the feature count is %d.",
1576 ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1577 }
1578
1579 if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) {
1580 return InvalidArgument(
1581 "The size of variance should be the same as feature count,"
1582 "but the size of variance is %d "
1583 "and the feature count is %d.",
1584 ShapeUtil::GetDimension(var_shape, 0), feature_count);
1585 }
1586
1587 // Verify operand_shape and output_grad_shape have same bounds.
1588 for (int64_t i = 0; i < operand_shape.rank(); ++i) {
1589 if (ShapeUtil::GetDimension(operand_shape, i) !=
1590 ShapeUtil::GetDimension(output_grad_shape, i)) {
1591 return InvalidArgument(
1592 "The bounds of operand shape should be the same as output_grad's,"
1593 "but the bound of operand_shape at dimension %d is %d "
1594 "and the bound of output_grad_shape is %d.",
1595 i, ShapeUtil::GetDimension(operand_shape, i),
1596 ShapeUtil::GetDimension(output_grad_shape, i));
1597 }
1598 }
1599
1600 return ShapeUtil::MakeTupleShape(
1601 {operand_shape, feature_shape, feature_shape});
1602 }
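// Illustrative example (added for clarity; the shapes are hypothetical): for
// batch-norm-grad with operand and output_grad of shape f32[8,28,28,64] and
// feature_index 3, the inferred result is the tuple
//   (f32[8,28,28,64], f32[64], f32[64])
// holding the operand gradient and the per-feature scale and offset
// gradients.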
1603
1604 /* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
1605 const Shape& lhs, const Shape& rhs, int64_t feature_group_count,
1606 int64_t batch_group_count, const Window& window,
1607 const ConvolutionDimensionNumbers& dnums,
1608 absl::optional<PrimitiveType> preferred_element_type) {
1609 TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
1610 TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
1611
1612 if (feature_group_count <= 0) {
1613 return InvalidArgument(
1614 "feature_group_count must be a positive number, got %d",
1615 feature_group_count);
1616 }
1617
1618 if (batch_group_count <= 0) {
1619 return InvalidArgument(
1620 "batch_group_count must be a positive number, got %d",
1621 batch_group_count);
1622 }
1623
1624 if (batch_group_count > 1 && feature_group_count > 1) {
1625 return InvalidArgument(
1626 "both batch_group_count %d and feature_group_count %d cannot be "
1627 "greater than 1",
1628 batch_group_count, feature_group_count);
1629 }
1630
1631 if (dnums.input_spatial_dimensions_size() !=
1632 dnums.kernel_spatial_dimensions_size()) {
1633 return InvalidArgument(
1634 "Both arguments to convolution must have same number of dimensions.\n"
1635 "Numbers: %s",
1636 dnums.DebugString());
1637 }
1638
1639 if (dnums.input_spatial_dimensions_size() !=
1640 dnums.output_spatial_dimensions_size()) {
1641 return InvalidArgument(
1642 "Both input and output of convolution must have same number of "
1643 "dimensions.\nNumbers: %s",
1644 dnums.DebugString());
1645 }
1646
1647 const int num_spatial_dims = dnums.input_spatial_dimensions_size();
1648 if (window.dimensions_size() != num_spatial_dims) {
1649 return InvalidArgument(
1650 "Window must have same number of dimensions as dimension numbers.\n"
1651 "Window: %s\nDimension numbers: %s.",
1652 window.DebugString(), dnums.DebugString());
1653 }
1654
1655 const int num_dims = num_spatial_dims + 2;
1656 if (lhs.rank() != num_dims) {
1657 return InvalidArgument(
1658 "The LHS argument to a convolution should have rank %d; lhs: %s.",
1659 num_dims, ShapeUtil::HumanString(lhs));
1660 }
1661 if (rhs.rank() != num_dims) {
1662 return InvalidArgument(
1663 "The RHS argument to a convolution should have rank %d; rhs: %s.",
1664 num_dims, ShapeUtil::HumanString(rhs));
1665 }
1666 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1667 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1668
1669 // Verifies that the input and window dimensions are a permutation of
1670 // the dimension numbers.
1671 std::vector<int64> input_dnums(num_dims);
1672 input_dnums[0] = dnums.input_batch_dimension();
1673 input_dnums[1] = dnums.input_feature_dimension();
1674 std::copy(dnums.input_spatial_dimensions().begin(),
1675 dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
1676 absl::c_sort(input_dnums);
1677
1678 std::vector<int64> window_dnums(num_dims);
1679 window_dnums[0] = dnums.kernel_input_feature_dimension();
1680 window_dnums[1] = dnums.kernel_output_feature_dimension();
1681 std::copy(dnums.kernel_spatial_dimensions().begin(),
1682 dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
1683 absl::c_sort(window_dnums);
1684
1685 std::vector<int64> output_dnums(num_dims);
1686 output_dnums[0] = dnums.output_batch_dimension();
1687 output_dnums[1] = dnums.output_feature_dimension();
1688 std::copy(dnums.output_spatial_dimensions().begin(),
1689 dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
1690 absl::c_sort(output_dnums);
1691
1692 std::vector<int64> expected_dnums(num_dims);
1693 std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
1694
1695 const auto in_range = [num_dims](int64_t i) {
1696 return 0 <= i && i < num_dims;
1697 };
1698 if (!absl::c_all_of(input_dnums, in_range) ||
1699 !absl::c_all_of(window_dnums, in_range) ||
1700 !absl::c_all_of(output_dnums, in_range)) {
1701 return InvalidArgument(
1702 "A dimension number is out of range in convolution: %s.",
1703 dnums.DebugString());
1704 }
1705
1706 if (input_dnums != expected_dnums) {
1707 return InvalidArgument(
1708 "Input dimensions of convolution must contain each dimension exactly "
1709 "once: %s.",
1710 dnums.DebugString());
1711 }
1712 if (window_dnums != expected_dnums) {
1713 return InvalidArgument(
1714 "Window dimensions of convolution must contain each dimension exactly "
1715 "once: %s.",
1716 dnums.DebugString());
1717 }
1718 if (output_dnums != expected_dnums) {
1719 return InvalidArgument(
1720 "Output dimensions of convolution must contain each dimension exactly "
1721 "once: %s.",
1722 dnums.DebugString());
1723 }
1724
1725 std::vector<int64> input_spatial_dims(num_spatial_dims);
1726 for (int i = 0; i < num_spatial_dims; ++i) {
1727 input_spatial_dims[i] = lhs.dimensions(dnums.input_spatial_dimensions(i));
1728 }
1729 const int64_t input_features =
1730 lhs.dimensions(dnums.input_feature_dimension());
1731 const int64_t input_batch = lhs.dimensions(dnums.input_batch_dimension());
1732
1733 std::vector<int64> kernel_spatial_dims(num_spatial_dims);
1734 for (int i = 0; i < num_spatial_dims; ++i) {
1735 kernel_spatial_dims[i] = rhs.dimensions(dnums.kernel_spatial_dimensions(i));
1736 }
1737 const int64_t kernel_input_features =
1738 rhs.dimensions(dnums.kernel_input_feature_dimension());
1739 const int64_t kernel_output_features =
1740 rhs.dimensions(dnums.kernel_output_feature_dimension());
1741
1742 if (kernel_output_features % batch_group_count != 0) {
1743 return InvalidArgument(
1744 "Expected output feature dimension size (value %d) to be a multiple of "
1745 "batch group count %d; got <conv>(%s, %s)\n"
1746 "Dimension numbers: {%s}.",
1747 kernel_output_features, batch_group_count, ShapeUtil::HumanString(lhs),
1748 ShapeUtil::HumanString(rhs), dnums.DebugString());
1749 }
1750
1751 if (input_features % feature_group_count != 0 ||
1752 input_features / feature_group_count != kernel_input_features) {
1753 return InvalidArgument(
1754 "Expected LHS feature dimension (value %d) to be a multiple of "
1755 "feature_group_count (value %d), and LHS feature dimension / "
1756 "feature_group_count = RHS feature dimension (value %d); "
1757 "got <conv>(%s, %s)\n"
1758 "Dimension numbers: {%s}.",
1759 input_features, feature_group_count, kernel_input_features,
1760 ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
1761 dnums.DebugString());
1762 }
1763
1764 if (kernel_output_features % feature_group_count > 0) {
1765 // A depthwise/grouped filter has the shape
1766 // [space0, .. spaceN, GROUP_SIZE, NUM_OUTPUT_FEATURES]. When
1767 // [space0, .. spaceN, GROUP_SIZE] is convolved with the input, a shape
1768 // [space0, .. spaceN, feature_group_count] is formed. Therefore, the output
1769 // feature count (which is equal to kernel output features) has to be a
1770 // multiple of feature_group_count.
1771 return InvalidArgument(
1772 "Expected output feature dimension (value %d) to be divisible by "
1773 "feature_group_count (value %d); "
1774 "got <conv>(%s, %s)\n"
1775 "Dimension numbers: {%s}.",
1776 kernel_output_features, feature_group_count,
1777 ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
1778 dnums.DebugString());
1779 }
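// Illustrative example (added for clarity; the values are hypothetical): with
// feature_group_count = 2, an input feature dimension of size 8 requires a
// kernel input feature dimension of size 8 / 2 = 4, and the kernel output
// feature dimension must be a multiple of 2.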
1780
1781 if (input_batch % batch_group_count != 0) {
1782 return InvalidArgument(
1783 "Expected input batch dimension (value %d) to be divisible by "
1784 "batch_group_count (value %d); "
1785 "got <conv>(%s, %s)\n"
1786 "Dimension numbers: {%s}.",
1787 input_batch, batch_group_count, ShapeUtil::HumanString(lhs),
1788 ShapeUtil::HumanString(rhs), dnums.DebugString());
1789 }
1790
1791 std::vector<int64> window_dims(num_spatial_dims);
1792 for (int i = 0; i < num_spatial_dims; ++i) {
1793 window_dims[i] = window.dimensions(i).size();
1794 }
1795 if (kernel_spatial_dims != window_dims) {
1796 return InvalidArgument(
1797 "Window dimensions do not match RHS shape:\n\t"
1798 "RHS shape: %s\n\t"
1799 "Window: {%s}\n\t"
1800 "Dimension numbers: {%s}.",
1801 ShapeUtil::HumanString(rhs), window.ShortDebugString(),
1802 dnums.ShortDebugString());
1803 }
1804
1805 Shape base_shape =
1806 ShapeUtil::MakeShape(lhs.element_type(), input_spatial_dims);
1807 TF_ASSIGN_OR_RETURN(
1808 Shape window_output_shape,
1809 InferWindowOutputShape(base_shape, window, lhs.element_type(),
1810 /*allow_negative_padding=*/true));
1811
1812 std::vector<int64> dimensions(num_dims);
1813 dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count;
1814 dimensions[dnums.output_feature_dimension()] = kernel_output_features;
1815
1816 for (int i = 0; i < num_spatial_dims; ++i) {
1817 dimensions[dnums.output_spatial_dimensions(i)] =
1818 window_output_shape.dimensions(i);
1819 }
1820 std::vector<bool> is_dynamic(num_dims);
1821 for (int i = 0; i < num_dims; i++) {
1822 if (lhs.is_dynamic_dimension(i)) {
1823 if (i == dnums.input_batch_dimension()) {
1824 is_dynamic[dnums.output_batch_dimension()] = true;
1825 } else if (i == dnums.input_feature_dimension()) {
1826 // Input feature dimension is a contracting dimension, which does not
1827 // affect the output dimension size, so there is nothing to do.
1828 } else {
1829 for (int64_t j = 0; j < dnums.output_spatial_dimensions_size(); ++j) {
1830 if (i == dnums.input_spatial_dimensions(j)) {
1831 // i is a spatial dimension, find corresponding output spatial
1832 // dimension.
1833 is_dynamic[dnums.output_spatial_dimensions(j)] = true;
1834 }
1835 }
1836 }
1837 }
1838 if (rhs.is_dynamic_dimension(i)) {
1839 if (i == dnums.kernel_input_feature_dimension()) {
1840 // Kernel feature dimension does not affect the output dimension size,
1841 // so there is nothing to do.
1842 } else if (i == dnums.kernel_output_feature_dimension()) {
1843 return InvalidArgument(
1844 "Dynamic output feature dim on convolution kernel is not "
1845 "supported: rhs shape is %s ",
1846 rhs.ToString());
1847 } else {
1848 for (int64_t j = 0; j < dnums.kernel_spatial_dimensions_size(); ++j) {
1849 if (i == dnums.kernel_spatial_dimensions(j)) {
1850 // i is a spatial dimension, find corresponding output spatial
1851 // dimension.
1852 is_dynamic[dnums.output_spatial_dimensions(j)] = true;
1853 }
1854 }
1855 }
1856 }
1857 }
1858 TF_ASSIGN_OR_RETURN(
1859 PrimitiveType type,
1860 MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
1861 preferred_element_type));
1862 return ShapeUtil::MakeShape(type, dimensions, is_dynamic);
1863 }
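// Illustrative example (added for clarity; the shapes and the NHWC/HWIO
// dimension-number layout are hypothetical): a convolution of lhs
// f32[1,10,10,4] with rhs f32[3,3,4,16], stride 1 and no padding, has spatial
// output sizes 10 - 3 + 1 = 8 per dimension, so the inferred output shape is
// f32[1,8,8,16].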
1864
1865 /* static */ StatusOr<Shape> ShapeInference::InferFftShape(
1866 const Shape& in, const FftType fft_type,
1867 const absl::Span<const int64> fft_length) {
1868 const int64_t fft_rank = fft_length.size();
1869 if (fft_rank < 1 || fft_rank > 3) {
1870 return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank);
1871 }
1872 #define RET_CHECK_RANK(x) \
1873 if (x.dimensions_size() < fft_rank) { \
1874 return InvalidArgument( \
1875 "FFT of rank %d requires input of at least " \
1876 "same rank; got input of rank %d", \
1877 fft_rank, x.dimensions_size()); \
1878 }
1879 switch (fft_type) {
1880 case FFT:
1881 case IFFT:
1882 if (!primitive_util::IsComplexType(in.element_type())) {
1883 return InvalidArgument("%s requires complex input type, found %s.",
1884 FftType_Name(fft_type),
1885 PrimitiveType_Name(in.element_type()));
1886 }
1887 RET_CHECK_RANK(in);
1888 return in;
1889 case RFFT: {
1890 if (in.element_type() != F32 && in.element_type() != F64) {
1891 return InvalidArgument("RFFT requires F32 or F64 input type, found %s.",
1892 PrimitiveType_Name(in.element_type()));
1893 }
1894 RET_CHECK_RANK(in);
1895 for (int i = 0; i < fft_rank; i++) {
1896 if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1897 fft_length[i]) {
1898 return InvalidArgument(
1899 "RFFT requires innermost dimensions match fft_length but "
1900 "dimension %d is %d and should be %d.",
1901 in.dimensions_size() - fft_rank + i,
1902 in.dimensions(in.dimensions_size() - fft_rank + i),
1903 fft_length[i]);
1904 }
1905 }
1906 Shape result = ShapeUtil::ChangeElementType(
1907 in, in.element_type() == F32 ? C64 : C128);
1908 // Preserve the size of zero-sized dimensions.
1909 if (fft_length[fft_rank - 1] != 0) {
1910 result.set_dimensions(result.dimensions_size() - 1,
1911 fft_length[fft_rank - 1] / 2 + 1);
1912 }
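// For example (illustrative): fft_length = {16} on an f32[..., 16] input
// yields a c64[..., 9] result, since 16 / 2 + 1 = 9.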
1913 return result;
1914 }
1915 case IRFFT: {
1916 if (!primitive_util::IsComplexType(in.element_type())) {
1917 return InvalidArgument("IRFFT requires complex input type, found %s.",
1918 PrimitiveType_Name(in.element_type()));
1919 }
1920 RET_CHECK_RANK(in);
1921 Shape result = ShapeUtil::ComplexComponentShape(in);
1922 for (int i = 0; i < fft_rank - 1; i++) {
1923 if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1924 fft_length[i]) {
1925 return InvalidArgument(
1926 "IRFFT requires all but one innermost dimensions match "
1927 "fft_length, but dimension %d is %d and should be %d.",
1928 in.dimensions_size() - fft_rank + i,
1929 in.dimensions(in.dimensions_size() - fft_rank + i),
1930 fft_length[i]);
1931 }
1932 }
1933 // The size of zero-sized dimensions is preserved.
1934 if ((in.dimensions(in.dimensions_size() - 1) != 0 ||
1935 fft_length[fft_rank - 1] != 0) &&
1936 in.dimensions(in.dimensions_size() - 1) !=
1937 fft_length[fft_rank - 1] / 2 + 1) {
1938 return InvalidArgument(
1939 "IRFFT requires innermost dimension matches fft_length/2+1, but "
1940 "dimension %d is %d and should be %d.",
1941 in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1),
1942 fft_length[fft_rank - 1] / 2 + 1);
1943 }
1944 result.set_dimensions(result.dimensions_size() - 1,
1945 fft_length[fft_rank - 1]);
1946 return result;
1947 }
1948 default:
1949 LOG(FATAL) << "Unexpected fft_type: " << fft_type;
1950 }
1951 #undef RET_CHECK_RANK
1952 }
1953
1954 /* static */ StatusOr<Shape> ShapeInference::InferTriangularSolveShape(
1955 const Shape& a, const Shape& b, const TriangularSolveOptions& options) {
1956 if ((!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) ||
1957 a.element_type() != b.element_type()) {
1958 return InvalidArgument(
1959 "Expected element types in shape to be floating or complex and "
1960 "identical for TriangularSolve; got %s and %s.",
1961 PrimitiveType_Name(a.element_type()),
1962 PrimitiveType_Name(b.element_type()));
1963 }
1964 if (a.rank() < 2) {
1965 return InvalidArgument(
1966 "The 'a' argument to TriangularSolve must have rank >= 2, got shape %s",
1967 a.ToString());
1968 }
1969 if (b.rank() != a.rank()) {
1970 return InvalidArgument(
1971 "Arguments to triangular solve must have equal rank; got %s and %s.",
1972 b.ToString(), a.ToString());
1973 }
1974 if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
1975 return InvalidArgument(
1976 "The two minor dimensions of 'a' must have equal size, got %s.",
1977 a.ToString());
1978 }
1979 if (a.dimensions(a.rank() - 1) !=
1980 b.dimensions(b.rank() - (options.left_side() ? 2 : 1))) {
1981 return InvalidArgument(
1982 "The shared dimension of 'a' and 'b' does not match, got shapes %s and "
1983 "%s",
1984 a.ToString(), b.ToString());
1985 }
1986 absl::Span<const int64> a_batch_dims(a.dimensions());
1987 absl::Span<const int64> b_batch_dims(b.dimensions());
1988 a_batch_dims.remove_suffix(2);
1989 b_batch_dims.remove_suffix(2);
1990 if (a_batch_dims != b_batch_dims) {
1991 return InvalidArgument(
1992 "The leading batch dimensions of the arguments to triangular solve "
1993 "must be equal; got %s and %s.",
1994 b.ToString(), a.ToString());
1995 }
1996 if (!TriangularSolveOptions_Transpose_IsValid(options.transpose_a()) ||
1997 options.transpose_a() == TriangularSolveOptions::TRANSPOSE_INVALID) {
1998 return InvalidArgument(
1999 "Invalid transpose option value for triangular solve (%d).\n",
2000 options.transpose_a());
2001 }
2002 return b;
2003 }
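// Illustrative example (added for clarity; the shapes are hypothetical): with
// left_side = true, a = f32[4,4] and b = f32[4,3] share the dimension of
// size 4, and the inferred result shape is that of b, f32[4,3].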
2004
2005 /* static */ StatusOr<Shape> ShapeInference::InferCholeskyShape(
2006 const Shape& a) {
2007 if (!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) {
2008 return InvalidArgument(
2009 "Expected element type in shape to be floating or complex for "
2010 "Cholesky; got %s.",
2011 PrimitiveType_Name(a.element_type()));
2012 }
2013 if (a.rank() < 2) {
2014 return InvalidArgument(
2015 "The 'a' argument to Cholesky must have rank >= 2, got shape %s",
2016 a.ToString());
2017 }
2018 if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
2019 return InvalidArgument(
2020 "The two minor dimensions of 'a' must have equal size, got %s.",
2021 a.ToString());
2022 }
2023 return a;
2024 }
2025
2026 /* static */ StatusOr<Shape> ShapeInference::InferAllGatherShape(
2027 absl::Span<const Shape* const> operand_shapes, int64_t all_gather_dimension,
2028 int64_t shard_count) {
2029 TF_RET_CHECK(all_gather_dimension >= 0);
2030 TF_RET_CHECK(shard_count > 0);
2031
2032 std::vector<Shape> output_shapes;
2033 output_shapes.reserve(operand_shapes.size());
2034 for (const Shape* operand_shape : operand_shapes) {
2035 TF_RET_CHECK(all_gather_dimension < operand_shape->rank());
2036 TF_RETURN_IF_ERROR(ExpectArray(*operand_shape, "operand of all-gather"));
2037
2038 Shape output_shape = *operand_shape;
2039 output_shape.set_dimensions(
2040 all_gather_dimension,
2041 shard_count * output_shape.dimensions(all_gather_dimension));
2042 output_shapes.push_back(output_shape);
2043 }
2044 if (output_shapes.size() == 1) {
2045 return output_shapes[0];
2046 }
2047 return ShapeUtil::MakeTupleShape(output_shapes);
2048 }
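// Illustrative example (added for clarity; the shapes are hypothetical): an
// all-gather of a single operand f32[8,16] with all_gather_dimension = 0 and
// shard_count = 4 infers the shape f32[32,16].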
2049
2050 /* static */ StatusOr<Shape> ShapeInference::InferAllGatherStartShape(
2051 absl::Span<const Shape* const> operand_shapes, int64_t all_gather_dimension,
2052 int64_t shard_count) {
2053 TF_ASSIGN_OR_RETURN(
2054 Shape ag_shape,
2055 InferAllGatherShape(operand_shapes, all_gather_dimension, shard_count));
2056 Shape input_shape;
2057 std::vector<Shape> op_shapes;
2058 op_shapes.reserve(operand_shapes.size());
2059 for (const Shape* shp : operand_shapes) {
2060 op_shapes.push_back(*shp);
2061 }
2062 if (op_shapes.size() == 1) {
2063 input_shape = op_shapes[0];
2064 } else {
2065 input_shape = ShapeUtil::MakeTupleShape(op_shapes);
2066 }
2067 return ShapeUtil::MakeTupleShape({input_shape, ag_shape});
2068 }
2069
2070 /* static */ StatusOr<Shape> ShapeInference::InferAllGatherDoneShape(
2071 const Shape& all_gather_start_shape) {
2072 return ShapeUtil::GetTupleElementShape(all_gather_start_shape, 1);
2073 }
2074
2075 /* static */ StatusOr<Shape> ShapeInference::InferAllReduceShape(
2076 absl::Span<const Shape* const> operand_shapes) {
2077 for (const Shape* operand_shape : operand_shapes) {
2078 TF_RETURN_IF_ERROR(
2079 ExpectArray(*operand_shape, "operand of cross replica sum"));
2080 }
2081 if (operand_shapes.size() == 1) {
2082 return *operand_shapes[0];
2083 }
2084 std::vector<Shape> operand_shape_values;
2085 for (const Shape* operand_shape : operand_shapes) {
2086 operand_shape_values.push_back(*operand_shape);
2087 }
2088 return ShapeUtil::MakeTupleShape(operand_shape_values);
2089 }
2090
2091 /* static */ StatusOr<Shape> ShapeInference::InferReduceScatterShape(
2092 absl::Span<const Shape* const> operand_shapes, int64_t scatter_dimension,
2093 int64_t shard_count) {
2094 TF_RET_CHECK(scatter_dimension >= 0);
2095 TF_RET_CHECK(shard_count > 0);
2096
2097 std::vector<Shape> output_shapes;
2098 output_shapes.reserve(operand_shapes.size());
2099 for (const Shape* operand_shape : operand_shapes) {
2100 TF_RET_CHECK(scatter_dimension < operand_shape->rank());
2101 TF_RETURN_IF_ERROR(
2102 ExpectArray(*operand_shape, "operand of reduce-scatter"));
2103
2104 int64_t scatter_dim_input_size =
2105 operand_shape->dimensions(scatter_dimension);
2106 if (scatter_dim_input_size % shard_count != 0) {
2107 return InvalidArgument(
2108 "ReduceScatter operand scatter dimension size %d must be "
2109 "dividable by shard_count "
2110 "%d.",
2111 scatter_dim_input_size, shard_count);
2112 }
2113
2114 Shape output_shape = *operand_shape;
2115 output_shape.set_dimensions(scatter_dimension,
2116 scatter_dim_input_size / shard_count);
2117 output_shapes.push_back(output_shape);
2118 }
2119
2120 if (output_shapes.size() == 1) {
2121 return output_shapes[0];
2122 }
2123 return ShapeUtil::MakeTupleShape(output_shapes);
2124 }
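// Illustrative example (added for clarity; the shapes are hypothetical): a
// reduce-scatter of a single operand f32[32,16] with scatter_dimension = 0
// and shard_count = 4 infers the shape f32[8,16].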
2125
2126 /* static */ StatusOr<Shape> ShapeInference::InferAllReduceStartShape(
2127 absl::Span<const Shape* const> operand_shapes) {
2128 TF_ASSIGN_OR_RETURN(Shape shape, InferAllReduceShape(operand_shapes));
2129
2130 return ShapeUtil::MakeTupleShape({shape, shape});
2131 }
2132
2133 /* static */ StatusOr<Shape> ShapeInference::InferAllReduceDoneShape(
2134 const Shape& operand_shape) {
2135 // The returned value from AllReduceDone is determined from
2136 // HloDataflowAnalysis::UpdateAllReduceStartValueSet(). The operand to
2137 // AllReduceDone is a tuple of two elements and this function selects
2138 // ShapeIndex {1} as the value forwarded.
2139 return ShapeUtil::GetTupleElementShape(operand_shape, 1);
2140 }
2141
2142 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllShape(
2143 const Shape& shape, int64_t split_dimension, int64_t concat_dimension,
2144 int64_t split_count) {
2145 TF_RET_CHECK(split_count > 0);
2146 if (split_dimension >= shape.rank() || split_dimension < 0) {
2147 return InvalidArgument(
2148 "AllToAll split_dimension %d is out-of-bounds in shape %s.",
2149 split_dimension, ShapeUtil::HumanString(shape));
2150 }
2151 if (concat_dimension >= shape.rank() || concat_dimension < 0) {
2152 return InvalidArgument(
2153 "AllToAll concat_dimension %d is out-of-bounds in shape %s.",
2154 concat_dimension, ShapeUtil::HumanString(shape));
2155 }
2156 if (shape.dimensions(split_dimension) % split_count != 0) {
2157 return InvalidArgument(
2158 "AllToAll split dimension size %d must be dividable by split_count "
2159 "%d.",
2160 shape.dimensions(split_dimension), split_count);
2161 }
2162 std::vector<int64> new_dimensions(shape.dimensions().begin(),
2163 shape.dimensions().end());
2164 new_dimensions[split_dimension] /= split_count;
2165 new_dimensions[concat_dimension] *= split_count;
2166 return ShapeUtil::MakeShape(shape.element_type(), new_dimensions);
2167 }
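// Illustrative example (added for clarity; the shapes are hypothetical): an
// all-to-all on f32[8,6] with split_dimension = 0, concat_dimension = 1, and
// split_count = 2 infers the shape f32[4,12] (8 / 2 = 4, 6 * 2 = 12).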
2168
2169 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
2170 absl::Span<const Shape* const> operand_shapes) {
2171 // An AllToAll HLO instruction receives N operands (with the same shape) and
2172 // returns a tuple that contains N array shapes.
2173 TF_RET_CHECK(!operand_shapes.empty());
2174 for (int i = 0; i < operand_shapes.size(); i++) {
2175 if (!Shape::Equal().IgnoreMemorySpaceInLayout()(*operand_shapes[0],
2176 *operand_shapes[i])) {
2177 return InvalidArgument(
2178 "HLO all-to-all has operands with different shapes: the 0th "
2179 "operand shape %s, but the %dth operand has shape %s.",
2180 ShapeUtil::HumanString(*operand_shapes[0]), i,
2181 ShapeUtil::HumanString(*operand_shapes[i]));
2182 }
2183 }
2184
2185 return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes);
2186 }
2187
2188 /* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteShape(
2189 absl::Span<const Shape* const> operand_shapes) {
2190 if (operand_shapes.size() == 1) {
2191 TF_RETURN_IF_ERROR(
2192 ExpectArray(*(operand_shapes[0]), "operand of collective-permute"));
2193 return *(operand_shapes[0]);
2194 } else {
2195 TF_RET_CHECK(operand_shapes.size() == 4);
2196 return *(operand_shapes[1]);
2197 }
2198 }
2199
2200 /* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteStartShape(
2201 absl::Span<const Shape* const> operand_shapes) {
2202 if (operand_shapes.size() == 1) {
2203 TF_RETURN_IF_ERROR(ExpectArray(*(operand_shapes[0]),
2204 "operand of collective-permute-start"));
2205 return ShapeUtil::MakeTupleShape(
2206 {*(operand_shapes[0]), *(operand_shapes[0]),
2207 ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
2208 } else {
2209 TF_RET_CHECK(operand_shapes.size() == 4);
2210 return ShapeUtil::MakeTupleShape(
2211 {*(operand_shapes[0]), *(operand_shapes[1]),
2212 ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {}),
2213 ShapeUtil::MakeShape(S32, {})});
2214 }
2215 }
2216
2217 /* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteDoneShape(
2218 const Shape& operand_shape) {
2219 TF_RET_CHECK(operand_shape.IsTuple());
2220 if (operand_shape.tuple_shapes_size() == 4) {
2221 return ShapeUtil::GetTupleElementShape(operand_shape, 0);
2222 } else {
2223 return ShapeUtil::GetTupleElementShape(operand_shape, 1);
2224 }
2225 }
2226
2227 /* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
2228 absl::Span<const Shape* const> arg_shapes,
2229 absl::Span<const int64> dimensions_to_reduce,
2230 const ProgramShape& to_apply) {
2231 if (arg_shapes.empty()) {
2232 return InvalidArgument("Reduce must have at least 2 arguments, has 0");
2233 }
2234 if (arg_shapes.size() % 2) {
2235 return InvalidArgument(
2236 "Reduce must have an even number of arguments, has %lu",
2237 arg_shapes.size());
2238 }
2239 int64_t num_reduced_args = arg_shapes.size() / 2;
2240 auto reduced_args = arg_shapes.subspan(0, num_reduced_args);
2241 // Check that all of the reduced tensors have the same dimensions. The element
2242 // types may be different.
2243 for (int64_t i = 1; i < num_reduced_args; ++i) {
2244 if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) {
2245 return InvalidArgument(
2246 "All reduced tensors must have the same dimension. Tensor 0 has "
2247 "shape %s, Tensor %d has shape %s",
2248 ShapeUtil::HumanString(*reduced_args[0]), i,
2249 ShapeUtil::HumanString(*reduced_args[i]));
2250 }
2251 }
2252 // Check that the dimensions to reduce are in-bounds for the given shape.
2253 // We've already verified all reduced tensors have the same dimensions, so it
2254 // doesn't matter which one we choose.
2255 const Shape& arg = *reduced_args[0];
2256 for (int64_t dimension : dimensions_to_reduce) {
2257 if (dimension >= arg.rank() || dimension < 0) {
2258 return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.",
2259 dimension, ShapeUtil::HumanString(arg));
2260 }
2261 }
2262
2263 auto init_values = arg_shapes.subspan(num_reduced_args, arg_shapes.size());
2264 std::vector<PrimitiveType> element_types;
2265 for (const Shape* arg : reduced_args) {
2266 element_types.push_back(arg->element_type());
2267 }
2268 TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types,
2269 num_reduced_args));
2270
2271 absl::flat_hash_set<int64> dimensions_to_reduce_set;
2272 for (int64_t dim_to_reduce : dimensions_to_reduce) {
2273 if (!dimensions_to_reduce_set.insert(dim_to_reduce).second) {
2274 return InvalidArgument("Duplicate reduction dimension: %d",
2275 dim_to_reduce);
2276 }
2277 }
2278
2279 std::vector<int64> new_dimensions;
2280 std::vector<bool> new_is_dynamic;
2281 for (int i = 0; i < arg.rank(); ++i) {
2282 if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) {
2283 new_dimensions.push_back(arg.dimensions(i));
2284 new_is_dynamic.push_back(arg.is_dynamic_dimension(i));
2285 }
2286 }
2287
2288 if (ShapeUtil::IsScalar(to_apply.result())) {
2289 return ShapeUtil::MakeShape(to_apply.result().element_type(),
2290 new_dimensions, new_is_dynamic);
2291 } else {
2292 std::vector<Shape> result_subshapes;
2293 for (const Shape& subshape : to_apply.result().tuple_shapes()) {
2294 result_subshapes.push_back(ShapeUtil::MakeShape(
2295 subshape.element_type(), new_dimensions, new_is_dynamic));
2296 }
2297 return ShapeUtil::MakeTupleShape(result_subshapes);
2298 }
2299 }
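// Illustrative example (added for clarity; the shapes are hypothetical):
// reducing an f32[10,20,30] argument over dimensions_to_reduce = {1} with a
// scalar-result reducer infers the shape f32[10,30].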
2300
2301 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2302 const Shape& operand_shape, const Shape& init_value_shape,
2303 const Window& window, const ProgramShape& to_apply_shape) {
2304 TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
2305 {operand_shape.element_type()},
2306 /*inputs=*/1));
2307 return InferReduceWindowShape(operand_shape, init_value_shape, window);
2308 }
2309
2310 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2311 absl::Span<const Shape* const> operands,
2312 absl::Span<const Shape* const> init_values, const Window& window,
2313 const ProgramShape& to_apply_shape) {
2314 auto number_of_input = operands.size();
2315 // Check that all of the reduced tensors have the same dimensions. The element
2316 // types may be different.
2317 for (int64_t i = 1; i < number_of_input; ++i) {
2318 if (!ShapeUtil::SameDimensions(*operands[0], *operands[i])) {
2319 return InvalidArgument(
2320 "All reduced tensors must have the same dimension. Tensor 0 has "
2321 "shape %s, Tensor %d has shape %s",
2322 ShapeUtil::HumanString(*operands[0]), i,
2323 ShapeUtil::HumanString(*operands[i]));
2324 }
2325 }
2326 std::vector<PrimitiveType> operand_element_type_vec;
2327 for (const Shape* s : operands) {
2328 operand_element_type_vec.push_back(s->element_type());
2329 }
2330 TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_values,
2331 operand_element_type_vec,
2332 /*inputs=*/number_of_input));
2333 std::vector<Shape> output_shape_vec;
2334 for (int i = 0; i < operands.size(); ++i) {
2335 TF_ASSIGN_OR_RETURN(
2336 auto cur_output_shape,
2337 InferReduceWindowShape(*operands[i], *init_values[i], window));
2338 output_shape_vec.push_back(cur_output_shape);
2339 }
2340 if (ShapeUtil::IsScalar(to_apply_shape.result())) {
2341 CHECK_EQ(output_shape_vec.size(), 1);
2342 return output_shape_vec[0];
2343 } else {
2344 return ShapeUtil::MakeTupleShape(output_shape_vec);
2345 }
2346 }
2347
2348 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2349 const Shape& operand_shape, const Shape& init_value_shape,
2350 const Window& window) {
2351 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window"));
2352 return InferWindowOutputShape(operand_shape, window,
2353 init_value_shape.element_type(),
2354 /*allow_negative_padding=*/false);
2355 }
2356
2357 /* static */ StatusOr<Shape> ShapeInference::InferSelectAndScatterShape(
2358 const Shape& operand_shape, const ProgramShape& select_shape,
2359 const Window& window, const Shape& source_shape,
2360 const Shape& init_value_shape, const ProgramShape& scatter_shape) {
2361 TF_RETURN_IF_ERROR(
2362 ExpectArray(operand_shape, "operand of select-and-scatter"));
2363
2364 // Check if the select function has a proper shape of (T,T) -> PRED.
2365 if (select_shape.parameters_size() != 2) {
2366 return InvalidArgument(
2367 "Select function must take 2 parameters, but "
2368 "takes %d parameter(s).",
2369 select_shape.parameters_size());
2370 }
2371 const Shape& select_result_shape = select_shape.result();
2372 if (!ShapeUtil::Compatible(select_result_shape,
2373 ShapeUtil::MakeShape(PRED, {}))) {
2374 return InvalidArgument("Select function must have rank-0 PRED result.");
2375 }
2376 const Shape& operand_element_shape =
2377 ShapeUtil::MakeShape(operand_shape.element_type(), {});
2378 if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2379 select_shape.parameters(0))) {
2380 return InvalidArgument(
2381 "Select function's first parameter shape currently must "
2382 "match the operand element shape, but got %s vs %s.",
2383 ShapeUtil::HumanString(select_shape.parameters(0)),
2384 ShapeUtil::HumanString(operand_element_shape));
2385 }
2386 if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2387 select_shape.parameters(1))) {
2388 return InvalidArgument(
2389 "Select function's second parameter shape currently must "
2390 "match the operand element shape, but got %s vs %s.",
2391 ShapeUtil::HumanString(select_shape.parameters(1)),
2392 ShapeUtil::HumanString(operand_element_shape));
2393 }
2394
2395 // Check if the scatter function has a proper shape as a reduction.
2396 TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape},
2397 {source_shape.element_type()},
2398 /*inputs=*/1));
2399
2400 // Check if the result shape of window operation matches the source shape.
2401 TF_ASSIGN_OR_RETURN(const Shape& window_result_shape,
2402 InferWindowOutputShape(operand_shape, window,
2403 operand_shape.element_type(),
2404 /*allow_negative_padding=*/false));
2405 if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape,
2406 window_result_shape)) {
2407 return InvalidArgument(
2408 "Source shape does not match the shape of window-reduced operand: "
2409 "source(%s), window-reduced operand(%s).",
2410 ShapeUtil::HumanString(source_shape),
2411 ShapeUtil::HumanString(window_result_shape));
2412 }
2413
2414 return operand_shape;
2415 }
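// Illustrative example (added for clarity; the shapes are hypothetical): for
// an operand f32[10,10] with a 3x3 window of stride 3 and no padding, the
// window-reduced shape is f32[3,3] ((10 - 3) / 3 + 1 = 3), so the source must
// be f32[3,3]; the inferred result shape is the operand shape, f32[10,10].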
2416
2417 /* static */ StatusOr<Shape> ShapeInference::InferGetDimensionSizeShape(
2418 const Shape& shape, int64_t dimension) {
2419 if (dimension < 0 || dimension >= shape.rank()) {
2420 return InvalidArgument("GetDimensionSize dimension out of bounds: %d.",
2421 dimension);
2422 }
2423
2424 // TODO(b/119580730): Remove this restriction when very large dimension size
2425 // is needed.
2426 if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
2427 return InvalidArgument(
2428 "GetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2429 "INT_MAX limit.",
2430 ShapeUtil::HumanString(shape), dimension);
2431 }
2432
2433 return ShapeUtil::MakeShape(S32, {});
2434 }
2435
2436 /* static */ StatusOr<Shape> ShapeInference::InferSetDimensionSizeShape(
2437 const Shape& shape, const Shape& val_shape, int64_t dimension) {
2438 if (dimension < 0 || dimension >= shape.rank()) {
2439 return InvalidArgument("SetDimensionSize dimension out of bounds: %d.",
2440 dimension);
2441 }
2442
2443 if (val_shape.rank() != 0 || val_shape.element_type() != S32) {
2444 return InvalidArgument(
2445 "SetDimensionSize's value has to be S32 scalar, got %s",
2446 val_shape.ToString());
2447 }
2448 // TODO(b/119580730): Remove this restriction when very large dimension size
2449 // is needed.
2450 if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
2451 return InvalidArgument(
2452 "SetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2453 "INT_MAX limit.",
2454 ShapeUtil::HumanString(shape), dimension);
2455 }
2456
2457 Shape result = shape;
2458 result.set_dynamic_dimension(dimension, true);
2459 return result;
2460 }
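// Illustrative example (added for clarity; the shapes are hypothetical):
// SetDimensionSize on shape f32[10], dimension 0, with an S32 scalar size
// operand infers the dynamic shape f32[<=10].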
2461
2462 /* static */ StatusOr<Window> ShapeInference::InferWindowFromDimensions(
2463 absl::Span<const int64> window_dimensions,
2464 absl::Span<const int64> window_strides,
2465 absl::Span<const std::pair<int64, int64>> padding,
2466 absl::Span<const int64> lhs_dilation,
2467 absl::Span<const int64> rhs_dilation) {
2468 const auto verify_size = [&](const size_t x, const char* x_name) {
2469 if (x == 0 || x == window_dimensions.size()) {
2470 return Status::OK();
2471 } else {
2472 return InvalidArgument(
2473 "%s", absl::StrCat(
2474 "Window has different number of window dimensions than of ",
2475 x_name,
2476 "\nNumber of window dimensions: ", window_dimensions.size(),
2477 "\nNumber of ", x_name, ": ", x, "\n"));
2478 }
2479 };
2480 TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides"));
2481 TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries"));
2482 TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors"));
2483 TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors"));
2484
2485 Window window;
2486 for (size_t i = 0; i < window_dimensions.size(); i++) {
2487 auto dim = window.add_dimensions();
2488 dim->set_size(window_dimensions[i]);
2489 if (!window_strides.empty()) {
2490 dim->set_stride(window_strides[i]);
2491 } else {
2492 dim->set_stride(1);
2493 }
2494 if (!padding.empty()) {
2495 dim->set_padding_low(padding[i].first);
2496 dim->set_padding_high(padding[i].second);
2497 } else {
2498 dim->set_padding_low(0);
2499 dim->set_padding_high(0);
2500 }
2501 if (!lhs_dilation.empty()) {
2502 dim->set_base_dilation(lhs_dilation[i]);
2503 } else {
2504 dim->set_base_dilation(1);
2505 }
2506 if (!rhs_dilation.empty()) {
2507 dim->set_window_dilation(rhs_dilation[i]);
2508 } else {
2509 dim->set_window_dilation(1);
2510 }
2511 dim->set_window_reversal(false);
2512 }
2513 return window;
2514 }
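// Illustrative usage (added for clarity; the values are hypothetical):
// passing window_dimensions = {3, 3} with empty strides, padding, and
// dilations yields a 3x3 window whose strides, base dilations, and window
// dilations default to 1 and whose padding defaults to 0.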
2515
2516 /* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
2517 const Shape& arg, absl::Span<const int64> starts,
2518 absl::Span<const int64> limits, absl::Span<const int64> strides) {
2519 auto error = [&](const string& message) {
2520 return InvalidArgument(
2521 "%s in slice operation; argument shape: %s; starts: {%s}; limits: "
2522 "{%s}; strides: {%s}.",
2523 message, ShapeUtil::HumanString(arg), StrJoin(starts, ","),
2524 StrJoin(limits, ","), StrJoin(strides, ","));
2525 };
2526 TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice"));
2527 VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}",
2528 ShapeUtil::HumanString(arg), StrJoin(starts, ", "),
2529 StrJoin(limits, ", "));
2530
2531 if (starts.size() != limits.size()) {
2532 return error(StrFormat("slice start and limit sizes differ: %u vs %u",
2533 starts.size(), limits.size()));
2534 }
2535
2536 if (starts.size() != strides.size()) {
2537 return error(StrFormat("slice start and strides sizes differ: %u vs %u",
2538 starts.size(), strides.size()));
2539 }
2540
2541 if (starts.size() != arg.rank()) {
2542 return InvalidArgument(
2543 "Slice index count does not match argument rank: %u vs %d.",
2544 starts.size(), arg.rank());
2545 }
2546
2547 std::vector<int64> sizes;
2548 for (int64_t dimension = 0; dimension < starts.size(); ++dimension) {
2549 int64_t start_index = starts[dimension];
2550 int64_t limit_index = limits[dimension];
2551 int64_t stride = strides[dimension];
2552 if (start_index < 0) {
2553 return InvalidArgument("Negative start index to slice: %d.", start_index);
2554 }
2555 if (limit_index > arg.dimensions(dimension)) {
2556 return error(
2557 StrFormat("limit index (%d) must be less than or equal to dimension "
2558 "size (%d)",
2559 limit_index, arg.dimensions(dimension)));
2560 }
2561 VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index);
2562 VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index);
2563 if (start_index > limit_index) {
2564 return error(
2565 StrFormat("limit index (%d) must be greater or equal to "
2566 "start index (%d) in slice with positive stride",
2567 limit_index, start_index));
2568 }
2569 if (stride <= 0) {
2570 return InvalidArgument("Stride (%d) must be positive.", stride);
2571 }
2572 sizes.push_back((limit_index - start_index + stride - 1) / stride);
2573 }
2574
2575 std::vector<bool> is_dynamic(arg.rank());
2576 for (int64_t i = 0; i < arg.dimensions_size(); ++i) {
2577 // Slicing 1 out of a dynamic dimension eliminates the dynamic dimension.
2578 if (sizes[i] == 1) {
2579 continue;
2580 }
2581 is_dynamic[i] = arg.is_dynamic_dimension(i);
2582 }
2583
2584 return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic);
2585 }
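// Illustrative example (added for clarity; the values are hypothetical):
// slicing f32[10] with starts = {2}, limits = {9}, strides = {3} yields
// (9 - 2 + 3 - 1) / 3 = 3 elements, so the inferred shape is f32[3].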
2586
2587 /* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
2588 const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
2589 absl::Span<const int64> slice_sizes, bool allow_scalar_indices) {
2590 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
2591 auto number_of_indices = start_index_shapes.size();
2592 // TODO(b/118437727): Remove this path.
2593 if (!allow_scalar_indices ||
2594 (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
2595 if (number_of_indices != 1) {
2596 return InvalidArgument(
2597 "Dynamic slice should have exactly 1 index operand, has %d.",
2598 number_of_indices);
2599 }
2600
2601 const Shape& start_indices_shape = start_index_shapes[0];
2602 VLOG(2) << StrFormat(
2603 "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
2604 ShapeUtil::HumanString(operand_shape),
2605 ShapeUtil::HumanString(start_indices_shape),
2606 StrJoin(slice_sizes, ", "));
2607
2608 TF_RETURN_IF_ERROR(
2609 ExpectArray(start_indices_shape, "start indices of dynamic slice"));
2610
2611 if (start_indices_shape.rank() != 1) {
2612 return InvalidArgument(
2613 "Dynamic slice start indices of rank %d must be rank1.",
2614 start_indices_shape.rank());
2615 }
2616
2617 if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2618 return InvalidArgument(
2619 "Dynamic slice start indices must be of integral type.");
2620 }
2621
2622 const int64_t start_num_dims = start_indices_shape.dimensions(0);
2623 if (operand_shape.rank() != start_num_dims) {
2624 return InvalidArgument(
2625 "Dynamic slice start number of dimensions %d (%s) must match rank "
2626 "%d of slice input (%s).",
2627 start_num_dims, ShapeUtil::HumanString(start_indices_shape),
2628 operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
2629 }
2630 } else {
2631 VLOG(2) << StrFormat("slicing shape %s a with slice_sizes={%s}",
2632 ShapeUtil::HumanString(operand_shape),
2633 StrJoin(slice_sizes, ", "));
2634
2635 if (operand_shape.rank() != number_of_indices) {
2636 return InvalidArgument(
2637 "Dynamic slice start number of dimensions %d must match rank "
2638 "%d of slice input (%s).",
2639 number_of_indices, operand_shape.rank(),
2640 ShapeUtil::HumanString(operand_shape));
2641 }
2642
2643 if (number_of_indices > 0) {
2644 const Shape& first_index_shape = start_index_shapes[0];
2645 if (!ShapeUtil::IsScalar(first_index_shape)) {
2646 return InvalidArgument("Dynamic slice indices must be scalar, not %s.",
2647 ShapeUtil::HumanString(first_index_shape));
2648 }
2649 if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
2650 return InvalidArgument(
2651 "Dynamic slice start indices must be of integral type.");
2652 }
2653 for (const Shape& index_shape : start_index_shapes) {
2654 if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
2655 return InvalidArgument(
2656 "Dynamic slice start indices must all have the same shape, got "
2657 "mismatching indices with shapes %s and %s.",
2658 ShapeUtil::HumanString(first_index_shape),
2659 ShapeUtil::HumanString(index_shape));
2660 }
2661 }
2662 }
2663 }
2664
2665 if (slice_sizes.size() != operand_shape.rank()) {
2666 return InvalidArgument(
2667 "Dynamic slice index count does not match argument rank: %u vs %d.",
2668 slice_sizes.size(), operand_shape.rank());
2669 }
2670
2671 for (int64_t dim = 0; dim < slice_sizes.size(); ++dim) {
2672 const int64_t input_dim_size = operand_shape.dimensions(dim);
2673 const int64_t slice_dim_size = slice_sizes[dim];
2674 if (slice_dim_size < 0) {
2675 return InvalidArgument("Negative size index to dynamic slice: %d.",
2676 slice_dim_size);
2677 }
2678 if (slice_dim_size > input_dim_size) {
2679 return InvalidArgument(
2680 "Slice dim size %d greater than dynamic slice dimension: %d.",
2681 slice_dim_size, input_dim_size);
2682 }
2683 VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size);
2684 }
2685
2686 return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
2687 }
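// Illustrative example (added for clarity; the shapes are hypothetical): a
// dynamic-slice of operand f32[10,10] with two scalar S32 start indices and
// slice_sizes = {2, 3} infers the shape f32[2,3].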
2688
2689 /* static */ StatusOr<Shape> ShapeInference::InferDynamicUpdateSliceShape(
2690 const Shape& operand_shape, const Shape& update_shape,
2691 absl::Span<const Shape> start_index_shapes, bool allow_scalar_indices) {
2692 TF_RETURN_IF_ERROR(
2693 ExpectArray(operand_shape, "operand of dynamic update slice"));
2694 TF_RETURN_IF_ERROR(
2695 ExpectArray(update_shape, "update of dynamic update slice"));
2696
2697 auto number_of_indices = start_index_shapes.size();
2698 // TODO(b/118437727): Remove this path.
2699 if (!allow_scalar_indices ||
2700 (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
2701 if (number_of_indices != 1) {
2702 return InvalidArgument(
2703 "Dynamic update slice should have exactly 1 index operand, has %d.",
2704 number_of_indices);
2705 }
2706 const Shape& start_indices_shape = start_index_shapes[0];
2707 TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape,
2708 "start indices of dynamic update slice"));
2709
2710 VLOG(2) << StrFormat(
2711 "updating slice of shape %s at dynamic start_indices %s with update "
2712 "shape %s",
2713 ShapeUtil::HumanString(operand_shape),
2714 ShapeUtil::HumanString(start_indices_shape),
2715 ShapeUtil::HumanString(update_shape));
2716
2717 if (start_indices_shape.rank() != 1) {
2718 return InvalidArgument(
2719 "Dynamic update slice start indices of rank %d must be rank1.",
2720 start_indices_shape.rank());
2721 }
2722
2723 if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2724 return InvalidArgument(
2725 "Dynamic update slice start indices must be of integral type.");
2726 }
2727
2728 const int64_t start_num_dims = start_indices_shape.dimensions(0);
2729 if (operand_shape.rank() != start_num_dims) {
2730 return InvalidArgument(
2731 "Dynamic update slice start number of dimensions %d (%s) must match "
2732 "rank %d of slice input (%s).",
2733 start_num_dims, ShapeUtil::HumanString(start_indices_shape),
2734 operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
2735 }
2736 } else {
2737 VLOG(2) << StrFormat("updating slice of shape %s with update shape %s",
2738 ShapeUtil::HumanString(operand_shape),
2739 ShapeUtil::HumanString(update_shape));
2740
2741 if (operand_shape.rank() != number_of_indices) {
2742 return InvalidArgument(
2743 "Dynamic update slice start number of dimensions %d must match "
2744 "rank %d of slice input (%s).",
2745 number_of_indices, operand_shape.rank(),
2746 ShapeUtil::HumanString(operand_shape));
2747 }
2748
2749 if (number_of_indices > 0) {
2750 const Shape& first_index_shape = start_index_shapes[0];
2751 if (!ShapeUtil::IsScalar(first_index_shape)) {
2752 return InvalidArgument(
2753 "Dynamic update slice indices must be scalar, not %s.",
2754 ShapeUtil::HumanString(first_index_shape));
2755 }
2756 if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
2757 return InvalidArgument(
2758 "Dynamic update slice start indices must be of integral type.");
2759 }
2760 for (const Shape& index_shape : start_index_shapes) {
2761 if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
2762 return InvalidArgument(
2763 "Dynamic update slice start indices must all have the same "
2764 "shape, got mismatching indices with shapes %s and %s.",
2765 ShapeUtil::HumanString(first_index_shape),
2766 ShapeUtil::HumanString(index_shape));
2767 }
2768 }
2769 }
2770 }
2771
2772 if (update_shape.rank() != operand_shape.rank()) {
2773 return InvalidArgument(
2774 "Dynamic update slice update rank does not match argument rank: "
2775 "%d vs %d.",
2776 update_shape.rank(), operand_shape.rank());
2777 }
2778
2779 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
2780 update_shape)) {
2781 return InvalidArgument(
2782 "Dynamic update slice update element type does not match argument. "
2783 "operand.element_type: %s vs update.element_type: %s.",
2784 PrimitiveType_Name(operand_shape.element_type()),
2785 PrimitiveType_Name(update_shape.element_type()));
2786 }
2787
2788 for (int64_t dim = 0; dim < operand_shape.rank(); ++dim) {
2789 const int64_t input_dim_size = operand_shape.dimensions(dim);
2790 const int64_t update_dim_size = update_shape.dimensions(dim);
2791 if (update_dim_size < 0) {
2792 return InvalidArgument(
2793 "Size index %d to dynamic update slice must be >= 0.",
2794 update_dim_size);
2795 }
2796 if (update_dim_size > input_dim_size) {
2797 return InvalidArgument(
2798 "Update dim size %d greater than dynamic slice dimension: %d.",
2799 update_dim_size, input_dim_size);
2800 }
2801 VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size);
2802 }
2803
2804 auto result_shape = operand_shape;
2805
2806   // If any dimension of the operand shape is dynamic, the corresponding
2807   // result dimension is also dynamic.
2808   // If the update shape is dynamic, only propagate the dynamic dimension to
2809   // the result if the update is a full update (update_shape[i] == operand_shape[i]).
2810 for (int64_t i = 0; i < update_shape.rank(); ++i) {
2811 if (operand_shape.is_dynamic_dimension(i)) {
2812 result_shape.set_dynamic_dimension(i, true);
2813 }
2814
2815 if (update_shape.is_dynamic_dimension(i) &&
2816 update_shape.dimensions(i) == operand_shape.dimensions(i)) {
2817       // When updating/replacing a full dimension, propagate the dynamic
2818       // dimension to the result.
2819 result_shape.set_dynamic_dimension(i, true);
2820 }
2821 }
2822
2823 return result_shape;
2824 }
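
// Illustrative example (hypothetical shapes): updating an f32[10,10] operand
// with an f32[3,4] update at scalar s32 start indices,
//
//   InferDynamicUpdateSliceShape(f32[10,10], f32[3,4], {s32[], s32[]},
//                                /*allow_scalar_indices=*/true)
//     => f32[10,10]
//
// The result always takes the operand's shape; only dynamic dimensions may be
// propagated from the update, as described above.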
2825
2826 /*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
2827 const Shape& operand_shape, absl::Span<const int64> dimensions) {
2828 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
2829 if (!AllUnique(dimensions)) {
2830     return InvalidArgument("A dimension number is duplicated in reverse.");
2831 }
2832 for (int64_t dimension : dimensions) {
2833 if (dimension >= operand_shape.rank() || dimension < 0) {
2834 return InvalidArgument(
2835 "One of the reverse dimensions (%d) is out-of-bounds in shape %s.",
2836 dimension, ShapeUtil::HumanString(operand_shape));
2837 }
2838 }
2839 return operand_shape;
2840 }
2841
2842 /* static */ StatusOr<Shape> ShapeInference::InferGetTupleElementShape(
2843 const Shape& arg, int64_t index) {
2844 if (!arg.IsTuple()) {
2845 return InvalidArgument(
2846 "Cannot infer shape: attempting to index into non-tuple: %s.",
2847 ShapeUtil::HumanString(arg));
2848 }
2849
2850 if (index < 0 || index >= arg.tuple_shapes_size()) {
2851 return InvalidArgument(
2852 "Cannot infer shape: attempt to index out of tuple bounds: %d "
2853 ">= %d in shape %s.",
2854 index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg));
2855 }
2856
2857 return arg.tuple_shapes(index);
2858 }
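
// Illustrative example (hypothetical shapes):
//
//   InferGetTupleElementShape((f32[2], s32[3,4]), /*index=*/1) => s32[3,4]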
2859
2860 /* static */ StatusOr<Shape> ShapeInference::InferWhileShape(
2861 const ProgramShape& condition, const ProgramShape& body,
2862 const Shape& init) {
2863 // Check the number of parameters for given computations.
2864 if (condition.parameters_size() != 1) {
2865     return InvalidArgument("Condition must take 1 argument; got %d.",
2866 condition.parameters_size());
2867 }
2868 if (body.parameters_size() != 1) {
2869     return InvalidArgument("Body must take 1 argument; got %d.",
2870 body.parameters_size());
2871 }
2872
2873 auto shape_string = [&]() {
2874 return StrFormat(
2875 "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition),
2876 ShapeUtil::HumanString(body), ShapeUtil::HumanString(init));
2877 };
2878
2879 // Check the shapes of computation parameters and return types.
2880 if (!ShapeUtil::Compatible(condition.result(),
2881 ShapeUtil::MakeShape(PRED, {}))) {
2882 return InvalidArgument("Condition must return a boolean; got %s.",
2883 shape_string());
2884 }
2885 if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) ||
2886 !ShapeUtil::Compatible(body.result(), body.parameters(0)) ||
2887 !ShapeUtil::Compatible(body.result(), init)) {
2888 return InvalidArgument(
2889         "The parameters of the condition and body, the result of the body, "
2890         "and init must all have the same shape; got %s.",
2891 shape_string());
2892 }
2893
2894 return init;
2895 }
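
// Illustrative example (hypothetical shapes): with a condition of signature
// (f32[5]) -> pred[] and a body of signature (f32[5]) -> f32[5],
//
//   InferWhileShape(condition, body, /*init=*/f32[5]) => f32[5]
//
// The loop-carried value keeps the shape of `init` on every iteration.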
2896
2897 /* static */ StatusOr<Shape> ShapeInference::InferConditionalShape(
2898 const Shape& branch_index,
2899 absl::Span<const ProgramShape> branch_computations,
2900 absl::Span<const Shape> branch_operands) {
2901 if (!ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(PRED, {})) &&
2902 !ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(S32, {}))) {
2903 return InvalidArgument("branch_index must be bool or int32; got %s.",
2904 ShapeUtil::HumanString(branch_index));
2905 }
2906 if (branch_index.element_type() == PRED) {
2907 TF_RET_CHECK(2 == branch_computations.size());
2908 } else {
2909 TF_RET_CHECK(!branch_computations.empty());
2910 }
2911 TF_RET_CHECK(branch_computations.size() == branch_operands.size());
2912 Shape result = branch_computations[0].result();
2913 for (int j = 0; j < branch_computations.size(); ++j) {
2914 if (branch_computations[j].parameters_size() != 1) {
2915 return InvalidArgument(
2916 "branch computation %d must take 1 argument; got %d.", j,
2917 branch_computations[j].parameters_size());
2918 }
2919 if (!ShapeUtil::Compatible(branch_computations[j].parameters(0),
2920 branch_operands[j])) {
2921 auto shape_string = [&]() {
2922 return StrFormat("operand: %s; computation: %s",
2923 ShapeUtil::HumanString(branch_operands[j]),
2924 ShapeUtil::HumanString(branch_computations[j]));
2925 };
2926 return InvalidArgument(
2927 "branch operand %d must match the shape of the only parameter of "
2928 "branch computation %d: got %s.",
2929 j, j, shape_string());
2930 }
2931
2932 if (!ShapeUtil::Compatible(branch_computations[0].result(),
2933 branch_computations[j].result())) {
2934 auto shape_string = [&]() {
2935 return StrFormat(
2936 "branch 0 computation result: %s; branch %d computation result: %s",
2937 ShapeUtil::HumanString(branch_computations[0].result()), j,
2938 ShapeUtil::HumanString(branch_computations[j].result()));
2939 };
2940 return InvalidArgument(
2941 "the result of branch 0 computation and branch %d computation must "
2942 "have the same shape: got %s.",
2943 j, shape_string());
2944 }
2945 }
2946   // For each subshape, if any of the branches is dynamic, the result is
2947   // dynamic:
2948 //
2949 // true_branch (s32[<=4])
2950 // false_branch (s32[4])
2951 //
2952 // Result is s32[<=4].
2953 ShapeUtil::ForEachMutableSubshape(
2954 &result, [&](Shape* subshape, const ShapeIndex& index) {
2955 if (!subshape->IsArray()) {
2956 return;
2957 }
2958 for (int j = 0; j < branch_computations.size(); ++j) {
2959 auto branch_subshape =
2960 ShapeUtil::GetSubshape(branch_computations[j].result(), index);
2961           for (int64_t i = 0; i < branch_subshape.rank(); ++i) {
2962 if (branch_subshape.is_dynamic_dimension(i)) {
2963 subshape->set_dynamic_dimension(i, true);
2964 }
2965 }
2966 }
2967 });
2968
2969 return result;
2970 }
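
// Illustrative example (hypothetical shapes): an s32[] branch index selecting
// among three branches, each of signature (f32[4]) -> f32[4],
//
//   InferConditionalShape(s32[], {b0, b1, b2}, {f32[4], f32[4], f32[4]})
//     => f32[4]
//
// Had one branch returned f32[<=4] instead, the subshape walk above would
// make the inferred result f32[<=4].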
2971
2972 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2973 const Shape& operand, absl::Span<const int64> broadcast_sizes) {
2974 TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
2975 for (int64_t size : broadcast_sizes) {
2976 if (size < 0) {
2977 return InvalidArgument("Broadcast with negative dimension size %d.",
2978 size);
2979 }
2980 }
2981
2982 std::vector<int64> dimensions(operand.dimensions_size() +
2983 broadcast_sizes.size());
2984 std::copy(broadcast_sizes.begin(), broadcast_sizes.end(), dimensions.begin());
2985 std::copy(operand.dimensions().begin(), operand.dimensions().end(),
2986 dimensions.begin() + broadcast_sizes.size());
2987
2988 Shape result = ShapeUtil::MakeShape(operand.element_type(), dimensions);
2989 for (int64_t i = 0; i < operand.dimensions_size(); ++i) {
2990 result.set_dynamic_dimension(broadcast_sizes.size() + i,
2991 operand.is_dynamic_dimension(i));
2992 }
2993 return result;
2994 }
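
// Illustrative example (hypothetical shapes): the broadcast sizes become the
// new major dimensions, ahead of the operand's own dimensions,
//
//   InferBroadcastShape(f32[2,3], /*broadcast_sizes=*/{4, 5}) => f32[4,5,2,3]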
2995
2996 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2997 const Shape& operand_shape, const Shape& output_shape,
2998 absl::Span<const int64> broadcast_dimensions) {
2999 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast"));
3000   TF_RETURN_IF_ERROR(ExpectArray(output_shape, "output of broadcast"));
3001 const int64_t operand_rank = operand_shape.rank();
3002 const int64_t output_rank = output_shape.rank();
3003 if (operand_rank > output_rank) {
3004 return InvalidArgument(
3005 "InDim style broadcast must be to an equal or higher ranked shape; "
3006 "operand rank: %lld; output rank: %lld",
3007 operand_rank, output_rank);
3008 }
3009 if (operand_rank != broadcast_dimensions.size()) {
3010 return InvalidArgument(
3011 "Size of broadcast_dimensions has to match operand's rank; operand "
3012 "rank: %lld, size of broadcast_dimensions %u.",
3013 operand_rank, broadcast_dimensions.size());
3014 }
3015 for (int64_t i = 0; i < operand_rank; i++) {
3016 if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] >= output_rank) {
3017       return InvalidArgument("Broadcast dimension %lld is out of bounds.",
3018 broadcast_dimensions[i]);
3019 }
3020 if (operand_shape.dimensions(i) !=
3021 output_shape.dimensions(broadcast_dimensions[i]) &&
3022 operand_shape.dimensions(i) != 1) {
3023 return InvalidArgument(
3024 "Input dimension should be either 1 or equal to the output dimension "
3025 "it is broadcasting into; the %lldth operand dimension is %lld, the "
3026 "%lldth output dimension is %lld.",
3027 i, operand_shape.dimensions(i), broadcast_dimensions[i],
3028 output_shape.dimensions(broadcast_dimensions[i]));
3029 }
3030 if (operand_shape.is_dynamic_dimension(i) !=
3031 output_shape.is_dynamic_dimension(broadcast_dimensions[i])) {
3032 return InvalidArgument(
3033 "Broadcast input and output dynamism mismatch: %s and %s",
3034 operand_shape.ToString(), output_shape.ToString());
3035 }
3036 // Make sure the broadcast dimensions are listed in a strictly increasing
3037 // order.
3038 if (i > 0 && broadcast_dimensions[i - 1] >= broadcast_dimensions[i]) {
3039 return InvalidArgument(
3040 "Broadcast dimensions order is wrong: %d comes after %d.",
3041           broadcast_dimensions[i], broadcast_dimensions[i - 1]);
3042 }
3043 }
3044
3045 return output_shape;
3046 }
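
// Illustrative example (hypothetical shapes): an "InDim" broadcast of an
// f32[3,1] operand into an f32[2,3,4] output with broadcast_dimensions {1, 2}
// type-checks and returns f32[2,3,4]: operand dimension 0 (size 3) matches
// output dimension 1, and operand dimension 1 (size 1) is stretched into
// output dimension 2 (size 4).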
3047
3048 /* static */ StatusOr<Shape> ShapeInference::InferDynamicReshapeShape(
3049 const Shape& operand, absl::Span<const Shape* const> dim_size_shapes,
3050 absl::Span<const int64> new_size_bounds,
3051 const std::vector<bool>& dims_are_dynamic) {
3052 if (new_size_bounds.size() != dims_are_dynamic.size()) {
3053 return InvalidArgument(
3054 "DynamicReshape has to have the same number of elements in new_sizes "
3055         "(%d) and dims_are_dynamic (%d).",
3056 new_size_bounds.size(), dims_are_dynamic.size());
3057 }
3058
3059 for (const Shape* dim_size_shape : dim_size_shapes) {
3060     if (dim_size_shape->element_type() != S32 || dim_size_shape->rank() != 0) {
3061       return InvalidArgument(
3062           "DynamicReshape's dim size has to be a scalar S32; got %s.",
3063           dim_size_shape->ToString());
3064 }
3065 }
3066
3067 Shape inferred_shape = ShapeUtil::MakeShape(
3068 operand.element_type(), new_size_bounds, dims_are_dynamic);
3069 if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
3070 return InvalidArgument(
3071 "Reshape operation has mismatched element counts: from=%d (%s) "
3072 "to=%d (%s).",
3073 ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
3074 ShapeUtil::ElementsIn(inferred_shape),
3075 ShapeUtil::HumanString(inferred_shape));
3076 }
3077 return inferred_shape;
3078 }
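
// Illustrative example (hypothetical shapes): reshaping an f32[6] operand
// where only the first output dimension is dynamic,
//
//   InferDynamicReshapeShape(f32[6], {&s32[], &s32[]},
//                            /*new_size_bounds=*/{2, 3},
//                            /*dims_are_dynamic=*/{true, false})
//     => f32[<=2,3]
//
// The element-count check compares bounds, so 2 * 3 must equal 6.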
3079
3080 /* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
3081 const Shape& operand, absl::Span<const int64> dimensions,
3082 absl::Span<const int64> new_sizes, int64_t inferred_dimension) {
3083 TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));
3084
3085 Shape inferred_shape =
3086 ShapeUtil::MakeShape(operand.element_type(), new_sizes);
3087 VLOG(3) << "Reshape inferred shape: "
3088 << ShapeUtil::HumanString(inferred_shape);
3089
3090 if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
3091 return InvalidArgument(
3092 "Reshape operation has mismatched element counts: from=%d (%s) "
3093 "to=%d (%s).",
3094 ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
3095 ShapeUtil::ElementsIn(inferred_shape),
3096 ShapeUtil::HumanString(inferred_shape));
3097 }
3098
3099 std::vector<int64> indices(operand.rank());
3100 std::iota(indices.begin(), indices.end(), 0);
3101 if (dimensions.size() != operand.rank() ||
3102 !std::is_permutation(dimensions.begin(), dimensions.end(),
3103 indices.begin())) {
3104 return InvalidArgument(
3105 "Reshape dimensions [%s] are not a permutation of the operand "
3106 "dimensions (operand shape is %s).",
3107 StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
3108 }
3109
3110 // Propagate dynamic dimension.
3111 auto common_factors = CommonFactors(operand.dimensions(), new_sizes);
3112 for (int64_t input_dim = 0; input_dim < operand.rank(); ++input_dim) {
3113 if (!operand.is_dynamic_dimension(input_dim)) {
3114 continue;
3115 }
3116
3117 string reshape_debug_str = absl::StrFormat(
3118 "output: %s, input: %s, input_dim: "
3119 "%lld",
3120 ShapeUtil::HumanString(inferred_shape), ShapeUtil::HumanString(operand),
3121 input_dim);
3122
3123 int64_t input_dim_start = -1;
3124 int64_t input_dim_end = -1;
3125 int64_t output_dim_start = -1;
3126 int64_t output_dim_end = -1;
3127 // Find common_factors that the input_dim belongs to.
3128 for (int64_t i = 0; i < common_factors.size() - 1; ++i) {
3129 auto start = common_factors[i];
3130 auto end = common_factors[i + 1];
3131 if (input_dim >= start.first && input_dim < end.first) {
3132 input_dim_start = start.first;
3133 input_dim_end = end.first;
3134 output_dim_start = start.second;
3135 output_dim_end = end.second;
3136 break;
3137 }
3138 }
3139 if ((input_dim_end - input_dim_start) > 1 &&
3140 (output_dim_end - output_dim_start) > 1) {
3141       // We don't support the case where a dynamic dimension is both combined
3142       // with and split into other dimensions:
3143 //
3144 // [x, yz]
3145 // | Reshape
3146 // [xy, z]
3147 return Unimplemented(
3148           "Dynamic input dimension to reshape that is both split and "
3149 "combined is not supported: %s",
3150 reshape_debug_str);
3151 }
3152
3153 for (auto common_factor : common_factors) {
3154 //
3155 // For reshapes like:
3156 // [<=5]
3157 // | Reshape
3158 // [1, 5]
3159 //
3160       // The input dynamic dimension can map to either output dimension.
3161 // However, the return value of common factors only returns
3162 // input: 5
3163 // output: 5
3164 //
3165       // We need to expand the common factor to include degenerate output
3166       // dimensions:
3167 // input: 5
3168 // output: 1, 5
3169 //
3170       // so that the logic later on can consider both dimensions as
3171       // candidates.
3172 if (common_factor.first == input_dim_start) {
3173 output_dim_start = std::min(output_dim_start, common_factor.second);
3174 }
3175 if (common_factor.first == input_dim_end) {
3176 output_dim_end = std::max(output_dim_end, common_factor.second);
3177 }
3178 }
3179
3180 // Calculate output dynamic reshape dimension.
3181 int64_t output_dynamic_dimension = -1;
3182
3183 if (operand.dimensions(input_dim) == 1 && !new_sizes.empty()) {
3184       // If the dynamic dimension has size 1, it can only be the most-major
3185       // or most-minor dimension.
3186 if (input_dim == 0) {
3187 output_dynamic_dimension = 0;
3188 }
3189 if (input_dim == operand.rank() - 1) {
3190 output_dynamic_dimension = new_sizes.size() - 1;
3191 }
3192
3193 if (output_dynamic_dimension == -1) {
3194 return Unimplemented(
3195             "Dynamic degenerate dimension that is neither most-minor nor "
3196             "most-major is not supported: %s",
3197 reshape_debug_str);
3198 }
3199 }
3200
3201 if (output_dynamic_dimension == -1 &&
3202 output_dim_end - output_dim_start == 1) {
3203 // Only one possible output dimension.
3204 output_dynamic_dimension = output_dim_start;
3205 }
3206 if (output_dynamic_dimension == -1 &&
3207 output_dim_end - output_dim_start > 1) {
3208       // Multiple output dimensions can be dynamic; use "inferred_dimension" to break the tie.
3209 output_dynamic_dimension = inferred_dimension;
3210 }
3211
3212 if (output_dynamic_dimension != -1) {
3213 // TODO(yunxing): Turn this into a CHECK.
3214 inferred_shape.set_dynamic_dimension(output_dynamic_dimension, true);
3215 } else {
3216 std::vector<int64> output_non_degenerated;
3217 for (int64_t i = output_dim_start; i < output_dim_end; ++i) {
3218 if (new_sizes[i] != 1) {
3219 output_non_degenerated.push_back(i);
3220 }
3221 }
3222 if (output_non_degenerated.size() == 1) {
3223 inferred_shape.set_dynamic_dimension(output_non_degenerated[0], true);
3224 }
3225 }
3226 }
3227
3228 return inferred_shape;
3229 }
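
// Illustrative example (hypothetical shapes) of the dynamic-dimension
// propagation above: reshaping f32[<=2,3] to new_sizes {6} maps input
// dimensions {0,1} to the single output dimension 0, which therefore absorbs
// the dynamism,
//
//   InferReshapeShape(f32[<=2,3], /*dimensions=*/{0, 1}, /*new_sizes=*/{6},
//                     /*inferred_dimension=*/-1)
//     => f32[<=6]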
3230
3231 /* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
3232 const Shape& operand, absl::Span<const int64> dimensions) {
3233 TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
3234
3235 if (dimensions.size() != operand.rank() || !IsPermutation(dimensions)) {
3236 return InvalidArgument(
3237 "Transpose dimensions [%s] are not a permutation of the operand "
3238 "dimensions (operand shape is %s).",
3239 StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
3240 }
3241
3242   // Permute(dimensions, input) computes output[dimensions[i]] = input[i].
3243   // However, we need output[i] = input[dimensions[i]], which is
3244   // Permute(Inverse(dimensions), input).
3245 return ShapeUtil::PermuteDimensions(dimensions, operand);
3246 }
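
// Illustrative example (hypothetical shapes): output dimension i is taken
// from operand dimension dimensions[i],
//
//   InferTransposeShape(f32[2,3,4], /*dimensions=*/{2, 0, 1}) => f32[4,2,3]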
3247
3248 /* static */ StatusOr<Shape> ShapeInference::InferClampShape(
3249 const Shape& min, const Shape& operand, const Shape& max) {
3250 TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min"));
3251 TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand"));
3252 TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max"));
3253
3254 if (!ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
3255 !ShapeUtil::CompatibleIgnoringFpPrecision(max, operand)) {
3256 return InvalidArgument(
3257 "Clamp with different shapes: %s, %s, %s.", ShapeUtil::HumanString(min),
3258 ShapeUtil::HumanString(operand), ShapeUtil::HumanString(max));
3259 }
3260 return operand;
3261 }
3262
3263 /* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
3264 const Shape& pred, const Shape& on_true, const Shape& on_false) {
3265 TF_RETURN_IF_ERROR(ExpectArray(pred, "select pred"));
3266 TF_RETURN_IF_ERROR(ExpectArray(on_true, "select on-true"));
3267 TF_RETURN_IF_ERROR(ExpectArray(on_false, "select on-false"));
3268
3269 if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) {
3270 return InvalidArgument(
3271 "Operands to select must be the same shape; got %s and %s.",
3272 ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
3273 }
3274 if (pred.element_type() != PRED) {
3275 return InvalidArgument(
3276 "Select's pred operand must have PRED element type; got %s.",
3277 ShapeUtil::HumanString(pred));
3278 }
3279 if (!Shape::Equal()
3280 .IgnoreElementType()
3281 .IgnoreLayout()
3282 .IgnoreDynamicDimension()(pred, on_true)) {
3283 return InvalidArgument(
3284 "Operands to select and predicate must be the same shape; got %s and "
3285 "%s.",
3286 ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(pred));
3287 }
3288
3289 return ShapeUtil::ChangeElementType(
3290 pred, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
3291 }
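
// Illustrative example (hypothetical shapes): a select between mixed
// floating-point precisions takes the higher-precision element type,
//
//   InferSelectShape(pred[2,3], bf16[2,3], f32[2,3]) => f32[2,3]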
3292
3293 /* static */ StatusOr<Shape> ShapeInference::InferTupleSelectShape(
3294 const Shape& pred, const Shape& on_true, const Shape& on_false) {
3295   // Select only defines the top-level buffer, so if it's a tuple, the two
3296   // inputs must match exactly.
3297 if (!ShapeUtil::Compatible(on_true, on_false)) {
3298 return InvalidArgument(
3299 "Operands to tuple-select must be the same shape; got %s and %s.",
3300 ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
3301 }
3302 if (pred.element_type() != PRED) {
3303 return InvalidArgument(
3304 "TupleSelect's pred operand must have PRED element type; got %s.",
3305 ShapeUtil::HumanString(pred));
3306 }
3307 if (!ShapeUtil::IsScalar(pred)) {
3308 return InvalidArgument(
3309 "TupleSelect operation with non-scalar predicate: %s.",
3310 ShapeUtil::HumanString(pred));
3311 }
3312 return on_true;
3313 }
3314
3315 /* static */ StatusOr<Shape> ShapeInference::InferCallShape(
3316 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply) {
3317 // The applied function's arity equals the number of arguments.
3318 if (arg_shapes.size() != to_apply.parameters_size()) {
3319 string computation_signature = ShapeUtil::HumanString(to_apply);
3320 string argument_shapes =
3321 StrJoin(arg_shapes, ", ", [](string* out, const Shape* shape) {
3322 absl::StrAppend(out, ShapeUtil::HumanString(*shape));
3323 });
3324 return InvalidArgument(
3325 "Call applied function arity must match number of arguments; got: "
3326 "arity: %d, arguments: %u; computation signature: %s; argument "
3327 "shapes: [%s].",
3328 to_apply.parameters_size(), arg_shapes.size(), computation_signature,
3329 argument_shapes);
3330 }
3331
3332 // All arguments must be compatible with the program shape.
3333 for (int i = 0; i < arg_shapes.size(); ++i) {
3334 const Shape& arg_shape = *arg_shapes[i];
3335 const Shape& param_shape = to_apply.parameters(i);
3336 if (!ShapeUtil::Compatible(arg_shape, param_shape)) {
3337 return InvalidArgument(
3338 "Call parameter must match argument; got parameter %d shape: %s, "
3339 "argument shape: %s.",
3340 i, ShapeUtil::HumanString(param_shape),
3341 ShapeUtil::HumanString(arg_shape));
3342 }
3343 }
3344
3345 return to_apply.result();
3346 }
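
// Illustrative example (hypothetical shapes): calling a computation of
// signature (f32[2], s32[]) -> f32[4] with matching argument shapes,
//
//   InferCallShape({&f32[2], &s32[]}, to_apply) => f32[4]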
3347
3348 static Status ValidateGatherDimensionNumbers(
3349 const Shape& input_shape, absl::Span<const int64> start_indices_shape,
3350 const GatherDimensionNumbers& dim_numbers) {
3351 if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
3352 return InvalidArgument(
3353 "Output window dimensions in gather op must be ascending; got: %s.",
3354 StrJoin(dim_numbers.offset_dims(), ", "));
3355 }
3356
3357 if (absl::c_adjacent_find(dim_numbers.offset_dims()) !=
3358 dim_numbers.offset_dims().end()) {
3359 return InvalidArgument(
3360 "Output window dimensions in gather op must not repeat; got: %s.",
3361 StrJoin(dim_numbers.offset_dims(), ", "));
3362 }
3363
3364 const int64_t output_offset_dim_count = dim_numbers.offset_dims_size();
3365 const int64_t output_shape_rank =
3366 output_offset_dim_count + start_indices_shape.size() - 1;
3367
3368 for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) {
3369 int64_t offset_dim = dim_numbers.offset_dims(i);
3370 if (offset_dim < 0 || offset_dim >= output_shape_rank) {
3371 return InvalidArgument(
3372           "Offset dimension %d in gather op is out of bounds; got %d, but "
3373           "should have been in [0,%d).",
3375 i, offset_dim, output_shape_rank);
3376 }
3377 }
3378
3379 if (dim_numbers.start_index_map_size() !=
3380 start_indices_shape[dim_numbers.index_vector_dim()]) {
3381 return InvalidArgument(
3382 "Gather op has %d elements in start_index_map and the "
3383 "bound of dimension index_vector_dim=%d of start_indices is "
3384 "%d. These two numbers must be equal.",
3385 dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(),
3386 start_indices_shape[dim_numbers.index_vector_dim()]);
3387 }
3388
3389 for (int i = 0; i < dim_numbers.start_index_map_size(); i++) {
3390 int64_t operand_dim_for_start_index_i = dim_numbers.start_index_map(i);
3391 if (operand_dim_for_start_index_i < 0 ||
3392 operand_dim_for_start_index_i >= input_shape.dimensions_size()) {
3393 return InvalidArgument(
3394 "Invalid start_index_map; domain is [0, %d), got: %d->%d.",
3395 input_shape.dimensions_size(), i, operand_dim_for_start_index_i);
3396 }
3397 }
3398
3399 std::vector<int64> sorted_start_index_map(
3400 dim_numbers.start_index_map().begin(),
3401 dim_numbers.start_index_map().end());
3402
3403 absl::c_sort(sorted_start_index_map);
3404
3405 if (absl::c_adjacent_find(sorted_start_index_map) !=
3406 sorted_start_index_map.end()) {
3407 return InvalidArgument(
3408 "Repeated dimensions are not allowed in start_index_map; "
3409 "got: %s.",
3410 StrJoin(dim_numbers.start_index_map(), ", "));
3411 }
3412
3413 for (int64_t collapsed_dim : dim_numbers.collapsed_slice_dims()) {
3414 if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
3415 return InvalidArgument(
3416 "Invalid collapsed_slice_dims set in gather op; valid range is [0, "
3417 "%d), got: %d.",
3418 input_shape.dimensions_size(), collapsed_dim);
3419 }
3420 }
3421
3422 if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) {
3423 return InvalidArgument(
3424         "collapsed_slice_dims in gather op must be sorted; got: %s.",
3425 StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
3426 }
3427
3428 if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
3429 dim_numbers.collapsed_slice_dims().end()) {
3430 return InvalidArgument(
3431 "Repeated dimensions not allowed in collapsed_slice_dims in gather op; "
3432 "got: %s.",
3433 StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
3434 }
3435
3436 return Status::OK();
3437 }
3438
3439 /*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
3440 const Shape& input_shape, const Shape& start_indices_shape,
3441 const GatherDimensionNumbers& gather_dim_numbers,
3442 absl::Span<const int64> slice_sizes) {
3443 TF_RETURN_IF_ERROR(
3444 ExpectArray(input_shape, "input tensor operand of gather op"));
3445 TF_RETURN_IF_ERROR(
3446 ExpectArray(start_indices_shape, "gather indices operand of gather op"));
3447
3448 if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
3449 return InvalidArgument(
3450 "Gather indices parameter must be an integral tensor; got %s.",
3451 ShapeUtil::HumanString(start_indices_shape));
3452 }
3453
3454 // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
3455   // index_vector_dim is rank(P). The bounds of this expanded shape are
3456   // stored in expanded_start_indices_shape.
3457
3458 if (start_indices_shape.dimensions_size() <
3459 gather_dim_numbers.index_vector_dim() ||
3460 gather_dim_numbers.index_vector_dim() < 0) {
3461 return InvalidArgument(
3462 "Gather index leaf dimension must be within [0, rank(start_indices) + "
3463 "1). rank(start_indices) is %d and gather index leaf dimension is "
3464 "%d.",
3465 start_indices_shape.dimensions_size(),
3466 gather_dim_numbers.index_vector_dim());
3467 }
3468
3469 std::vector<int64> expanded_start_indices_shape;
3470   // Also tracks whether each dimension of the expanded start indices is dynamic.
3471 std::vector<bool> expanded_start_indices_shape_dynamic_dimensions;
3472 expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size());
3473 expanded_start_indices_shape_dynamic_dimensions.reserve(
3474 start_indices_shape.dimensions_size());
3475 absl::c_copy(start_indices_shape.dimensions(),
3476 std::back_inserter(expanded_start_indices_shape));
3477 absl::c_copy(
3478 start_indices_shape.dynamic_dimensions(),
3479 std::back_inserter(expanded_start_indices_shape_dynamic_dimensions));
3480 if (expanded_start_indices_shape.size() ==
3481 gather_dim_numbers.index_vector_dim()) {
3482 expanded_start_indices_shape.push_back(1);
3483 expanded_start_indices_shape_dynamic_dimensions.push_back(false);
3484 }
3485
3486 TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
3487 input_shape, expanded_start_indices_shape, gather_dim_numbers));
3488
3489 if (slice_sizes.size() != input_shape.dimensions_size()) {
3490 return InvalidArgument(
3491 "Gather op must have one slice size for every input dimension; got: "
3492 "len(slice_sizes)=%lu, input_shape.rank=%d.",
3493 slice_sizes.size(), input_shape.dimensions_size());
3494 }
3495
3496 if (slice_sizes.size() !=
3497 gather_dim_numbers.offset_dims_size() +
3498 gather_dim_numbers.collapsed_slice_dims_size()) {
3499 return InvalidArgument(
3500         "All components of the offset index in a gather op must either be an "
3501         "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
3502 "output_slice_sizes=%s, collapsed_slice_dims=%s.",
3503 slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","),
3504 StrJoin(gather_dim_numbers.collapsed_slice_dims(), ","));
3505 }
3506
3507 for (int i = 0; i < slice_sizes.size(); i++) {
3508 int64_t slice_size = slice_sizes[i];
3509 int64_t corresponding_input_size = input_shape.dimensions(i);
3510 if (slice_size < 0 || slice_size > corresponding_input_size) {
3511 return InvalidArgument(
3512 "Slice size at index %d in gather op is out of range, must be "
3513 "within [0, %d), got %d.",
3514 i, corresponding_input_size + 1, slice_size);
3515 }
3516 }
3517
3518 for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) {
3519 if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] > 1) {
3520 return InvalidArgument(
3521 "Gather op can only collapse slice dims with bound 1 or 0, but bound "
3522 "is %d for index %d at position %d.",
3523 slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)],
3524 gather_dim_numbers.collapsed_slice_dims(i), i);
3525 }
3526 }
3527
3528 int64_t result_rank = gather_dim_numbers.offset_dims_size() +
3529 (expanded_start_indices_shape.size() - 1);
3530 int64_t offset_dims_seen = 0;
3531 int64_t gather_dims_seen = 0;
3532 std::vector<int64> output_dim_bounds;
3533 output_dim_bounds.reserve(result_rank);
3534
3535 std::vector<bool> output_dim_is_dynamic;
3536 output_dim_is_dynamic.reserve(result_rank);
3537 for (int64_t i = 0; i < result_rank; i++) {
3538 int64_t current_bound;
3539 bool dim_dynamic = false;
3540 bool is_window_index =
3541 absl::c_binary_search(gather_dim_numbers.offset_dims(), i);
3542 if (is_window_index) {
3543 while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(),
3544 offset_dims_seen)) {
3545 offset_dims_seen++;
3546 }
3547       // Gathering an entire dynamic dimension creates a dynamic dimension.
3548 //
3549 // e.g.,:
3550 //
3551 // gather(input: [1,<=2,1], slice_sizes={1,2,1})
3552 //
3553 // creates
3554 //
3555 // [<=2, 1]
3556 if (slice_sizes[offset_dims_seen] ==
3557 input_shape.dimensions(offset_dims_seen)) {
3558 dim_dynamic = input_shape.is_dynamic_dimension(offset_dims_seen);
3559 }
3560 current_bound = slice_sizes[offset_dims_seen++];
3561 } else {
3562 if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) {
3563 gather_dims_seen++;
3564 }
3565 // Forward dynamic dimensions from indices.
3566 dim_dynamic =
3567 expanded_start_indices_shape_dynamic_dimensions[gather_dims_seen];
3568
3569 current_bound = expanded_start_indices_shape[gather_dims_seen++];
3570 }
3571 output_dim_is_dynamic.push_back(dim_dynamic);
3572 output_dim_bounds.push_back(current_bound);
3573 }
3574
3575 return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds,
3576 output_dim_is_dynamic);
3577 }
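
// Illustrative example (hypothetical shapes and dimension numbers): gathering
// from an f32[3,4,5] input with start indices s32[6,2] and
//
//   index_vector_dim=1, start_index_map={0,1}, collapsed_slice_dims={0},
//   offset_dims={1,2}, slice_sizes={1,2,5}
//
// yields f32[6,2,5]: dimension 0 is the batch dimension taken from the start
// indices, and dimensions 1 and 2 are the non-collapsed slice dimensions.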
3578
3579 namespace {
3580
3581 Status ValidateScatterDimensionNumbers(
3582 const Shape& operand_shape, absl::Span<const int64> scatter_indices_shape,
3583 const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
3584 // Validate update_window_dims in ScatterDimensionNumbers.
3585 if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {
3586 return InvalidArgument(
3587 "update_window_dims in scatter op must be sorted; got: %s.",
3588 StrJoin(dim_numbers.update_window_dims(), ", "));
3589 }
3590 if (absl::c_adjacent_find(dim_numbers.update_window_dims()) !=
3591 dim_numbers.update_window_dims().end()) {
3592 return InvalidArgument(
3593 "update_window_dims in scatter op must not repeat; got: %s.",
3594 StrJoin(dim_numbers.update_window_dims(), ", "));
3595 }
3596 const int64_t updates_rank = updates_shape.rank();
3597 for (int64_t window_dim : dim_numbers.update_window_dims()) {
3598 if (window_dim < 0 || window_dim >= updates_rank) {
3599 return InvalidArgument(
3600 "Invalid update_window_dims set in scatter op; valid range is [0, "
3601           "%d), got: %d.",
3602 updates_rank, window_dim);
3603 }
3604 }
3605
3606 // Validate inserted_window_dims in ScatterDimensionNumbers.
3607 if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) {
3608 return InvalidArgument(
3609 "inserted_window_dims in scatter op must be sorted; got: %s.",
3610 StrJoin(dim_numbers.inserted_window_dims(), ", "));
3611 }
3612 if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) !=
3613 dim_numbers.inserted_window_dims().end()) {
3614 return InvalidArgument(
3615 "inserted_window_dims in scatter op must not repeat; got: %s.",
3616 StrJoin(dim_numbers.inserted_window_dims(), ", "));
3617 }
3618 for (int64_t inserted_dim : dim_numbers.inserted_window_dims()) {
3619 if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) {
3620 return InvalidArgument(
3621 "Invalid inserted_window_dims set in scatter op; valid range is [0, "
3622 "%d), got: %d.",
3623 operand_shape.dimensions_size(), inserted_dim);
3624 }
3625 }
3626
3627 // Validate window size.
3628 auto window_size = dim_numbers.update_window_dims_size() +
3629 dim_numbers.inserted_window_dims_size();
3630 if (window_size != operand_shape.rank()) {
3631 return InvalidArgument(
3632 "Scatter op has window of size %d; doesn't match operand of rank %d.",
3633 window_size, operand_shape.rank());
3634 }
3635
3636 // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers.
3637 if (dim_numbers.scatter_dims_to_operand_dims_size() !=
3638 scatter_indices_shape[dim_numbers.index_vector_dim()]) {
3639 return InvalidArgument(
3640 "Scatter op has %d elements in scatter_dims_to_operand_dims and the "
3641 "bound of dimension index_vector_dim=%d of scatter_indices is %d. "
3642 "These two numbers must be equal.",
3643 dim_numbers.scatter_dims_to_operand_dims_size(),
3644 dim_numbers.index_vector_dim(),
3645 scatter_indices_shape[dim_numbers.index_vector_dim()]);
3646 }
3647 for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) {
3648 int64_t scatter_dim_to_operand_dim =
3649 dim_numbers.scatter_dims_to_operand_dims(i);
3650 if (scatter_dim_to_operand_dim < 0 ||
3651 scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) {
3652 return InvalidArgument(
3653 "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), "
3654 "got: %d->%d.",
3655 operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim);
3656 }
3657 }
3658 std::vector<int64> sorted_scatter_dims_to_operand_dims(
3659 dim_numbers.scatter_dims_to_operand_dims().begin(),
3660 dim_numbers.scatter_dims_to_operand_dims().end());
3661 absl::c_sort(sorted_scatter_dims_to_operand_dims);
3662 if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
3663 sorted_scatter_dims_to_operand_dims.end()) {
3664 return InvalidArgument(
3665 "Repeated dimensions not allowed in scatter_dims_to_operand_dims; "
3666 "got: %s.",
3667 StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", "));
3668 }
3669
3670 return Status::OK();
3671 }
3672
3673 } // namespace
3674
3675 /*static*/ StatusOr<Shape> ShapeInference::InferScatterShape(
3676 const Shape& operand_shape, const Shape& scatter_indices_shape,
3677 const Shape& updates_shape, const ProgramShape& to_apply_shape,
3678 const ScatterDimensionNumbers& scatter_dim_numbers) {
3679 TF_RETURN_IF_ERROR(
3680 ExpectArray(operand_shape, "operand tensor of scatter op"));
3681 TF_RETURN_IF_ERROR(
3682 ExpectArray(scatter_indices_shape, "scatter indices of scatter op"));
3683 TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op"));
3684
3685 if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) {
3686 return InvalidArgument(
3687 "Scatter indices parameter must be an integral tensor; got %s.",
3688 ShapeUtil::HumanString(scatter_indices_shape));
3689 }
3690
3691 if (scatter_indices_shape.dimensions_size() <
3692 scatter_dim_numbers.index_vector_dim() ||
3693 scatter_dim_numbers.index_vector_dim() < 0) {
3694 return InvalidArgument(
3695 "Scatter index leaf dimension must be within [0, rank(scatter_indices)"
3696 " + 1). rank(scatter_indices) is %d and scatter index leaf dimension "
3697 "is %d.",
3698 scatter_indices_shape.dimensions_size(),
3699 scatter_dim_numbers.index_vector_dim());
3700 }
3701
3702 // Check if the update computation has a proper shape as a reduction.
3703 const Shape init_value_shape =
3704 ShapeUtil::MakeShape(operand_shape.element_type(), {});
3705 TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
3706 {updates_shape.element_type()},
3707 /*inputs=*/1));
3708
3709 std::vector<int64> expanded_scatter_indices_shape =
3710 SpanToVector(scatter_indices_shape.dimensions());
3711 if (expanded_scatter_indices_shape.size() ==
3712 scatter_dim_numbers.index_vector_dim()) {
3713 expanded_scatter_indices_shape.push_back(1);
3714 }
3715
3716 int64_t expected_updates_rank = expanded_scatter_indices_shape.size() - 1 +
3717 scatter_dim_numbers.update_window_dims_size();
3718 if (updates_shape.rank() != expected_updates_rank) {
3719 return InvalidArgument("Updates tensor must be of rank %d; got %d.",
3720 expected_updates_rank, updates_shape.rank());
3721 }
3722
3723 TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers(
3724 operand_shape, expanded_scatter_indices_shape, updates_shape,
3725 scatter_dim_numbers));
3726
3727 int64_t inserted_dims_seen = 0;
3728 std::vector<int64> max_update_slice_sizes;
3729 for (int i = 0; i < operand_shape.dimensions_size(); ++i) {
3730 if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() &&
3731 scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) {
3732 ++inserted_dims_seen;
3733 } else {
3734 max_update_slice_sizes.push_back(operand_shape.dimensions(i));
3735 }
3736 }
3737 for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) {
3738 auto update_window_dim = scatter_dim_numbers.update_window_dims(i);
3739 if (updates_shape.dimensions(update_window_dim) >
3740 max_update_slice_sizes[i]) {
3741 return InvalidArgument(
3742 "Bounds of the window dimensions of updates must not exceed the "
3743 "bounds of the corresponding dimensions of operand. For dimension "
3744 "%d, updates bound is %d, operand bound is %d.",
3745 update_window_dim, updates_shape.dimensions(update_window_dim),
3746 max_update_slice_sizes[i]);
3747 }
3748 }
3749
3750 int64_t scatter_dims_seen = 0;
3751 for (int64_t i = 0; i < updates_shape.rank(); ++i) {
3752 bool is_update_window_dim =
3753 absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i);
3754 if (is_update_window_dim) {
3755 continue;
3756 }
3757 if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) {
3758 ++scatter_dims_seen;
3759 }
3760 if (updates_shape.dimensions(i) !=
3761 expanded_scatter_indices_shape[scatter_dims_seen]) {
3762 return InvalidArgument(
3763           "Bounds of the scatter dimensions of updates must be the same as the "
3764 "bounds of the corresponding dimensions of scatter indices. For "
3765 "scatter dimension %d, updates bound is %d, scatter_indices "
3766 "bound is %d.",
3767 i, updates_shape.dimensions(i),
3768 expanded_scatter_indices_shape[scatter_dims_seen]);
3769 }
3770 ++scatter_dims_seen;
3771 }
3772
3773 return operand_shape;
3774 }
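
// Illustrative example (hypothetical shapes and dimension numbers): scattering
// two rows into an f32[3,4] operand with indices s32[2,1], updates f32[2,4],
// and
//
//   index_vector_dim=1, update_window_dims={1}, inserted_window_dims={0},
//   scatter_dims_to_operand_dims={0}
//
// passes the checks above and returns the operand shape, f32[3,4]; scatter
// never changes the operand's shape.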
3775
3776 } // namespace xla
3777