1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <numeric>
21 #include <set>
22 #include <string>
23
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 #include "absl/strings/str_join.h"
29 #include "absl/strings/string_view.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/compiler/xla/window_util.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/lib/math/math_util.h"
38 #include "tensorflow/core/platform/logging.h"
39 #include "tensorflow/core/platform/protobuf.h"
40
41 namespace xla {
42 namespace {
43
44 using absl::StrFormat;
45 using absl::StrJoin;
46
47 // Returns true if no element is present in slice more than once.
48 bool AllUnique(absl::Span<const int64> slice) {
49 return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
50 }
51
52 Status ExpectArray(const Shape& shape, absl::string_view op_type) {
53 if (!shape.IsArray()) {
54 return InvalidArgument("Expected array argument for %s, but got %s.",
55 string(op_type), ShapeUtil::HumanString(shape));
56 }
57 return Status::OK();
58 }
59
60 Status VerifyReducerShape(const ProgramShape& reducer_shape,
61 absl::Span<const Shape* const> init_value_shapes,
62 absl::Span<const PrimitiveType> input_element_types,
63 int64 inputs) {
64 if (reducer_shape.parameters_size() != inputs * 2) {
65 return InvalidArgument(
66 "Reduction function must take %d parameters, but "
67 "takes %d parameter(s).",
68 inputs * 2, reducer_shape.parameters_size());
69 }
70
71 const Shape& accumulator_shape = reducer_shape.result();
72 std::vector<const Shape*> accumulator_subshapes;
73 if (accumulator_shape.IsArray()) {
74 if (inputs != 1) {
75 return InvalidArgument(
76 "Reduction function must produce a tuple with %d elements, but "
77 "produces a scalar",
78 inputs);
79 }
80 accumulator_subshapes.push_back(&accumulator_shape);
81 } else if (accumulator_shape.IsTuple()) {
82 if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) {
83 return InvalidArgument(
84 "Reduction function must produce a tuple with %d elements, but has "
85 "%d elements",
86 inputs, ShapeUtil::TupleElementCount(accumulator_shape));
87 }
88 for (const Shape& element_shape : accumulator_shape.tuple_shapes()) {
89 accumulator_subshapes.push_back(&element_shape);
90 }
91 } else {
92 return InvalidArgument(
93 "Reduction function must produce a scalar or tuple of scalars, but has "
94 "shape: %s",
95 ShapeUtil::HumanString(accumulator_shape));
96 }
97
98 for (const Shape* element_shape : accumulator_subshapes) {
99 if (element_shape->rank() != 0) {
100 return InvalidArgument(
101 "Reduction function must return a scalar or tuple of scalars but "
102 "returns shape: %s",
103 ShapeUtil::HumanString(accumulator_shape));
104 }
105 }
106
107 for (int64 i = 0; i < inputs; ++i) {
108 // Check that the accumulator can be passed in as the first argument.
109 // Note: comparing here and below with Compatible since we don't care about
110 // layout in scalars - see b/26668201 for a longer-term vision.
111 if (!ShapeUtil::Compatible(*accumulator_subshapes[i],
112 reducer_shape.parameters(i))) {
113 return InvalidArgument(
114 "Reduction function's %d-th parameter shape differs from the "
115 "result shape: %s vs %s",
116 i, ShapeUtil::HumanString(reducer_shape.parameters(i)),
117 ShapeUtil::HumanString(*accumulator_subshapes[i]));
118 }
119 // Check that init_value's shapes are suitable for reducer_shape.
120 if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i],
121 *init_value_shapes[i])) {
122 return InvalidArgument(
123 "Reduction function's accumulator shape at index %d differs from "
124 "the init_value shape: %s vs %s",
125 i, ShapeUtil::HumanString(*accumulator_subshapes[i]),
126 ShapeUtil::HumanString(*init_value_shapes[i]));
127 }
128 // Check that the inputs can be passed in as the non-accumulator arguments.
129 const Shape input_element_shape =
130 ShapeUtil::MakeShape(input_element_types[i], {});
131 if (!ShapeUtil::CompatibleIgnoringFpPrecision(
132 input_element_shape, reducer_shape.parameters(inputs + i))) {
133 return InvalidArgument(
134 "Reduction function's %d-th parameter shape differs from the "
135           "input element type: %s vs %s",
136 inputs + i,
137 ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
138 ShapeUtil::HumanString(input_element_shape));
139 }
140 // Check that the accumulator and inputs to the reducer function match.
141 // If the accumulator is scalar, it must have the same type as the inputs
142 // (up to fp precision). If it is a tuple, then the k-th element of the
143     // tuple must have the same type as the k-th input (again, up to fp
144     // precision).
145 if (!ShapeUtil::CompatibleIgnoringFpPrecision(
146 *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) {
147 return InvalidArgument(
148 "Reduction function's %d-th parameter shape must "
149 "match the result shape, but got %s vs %s.",
150 inputs + i,
151 ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
152 ShapeUtil::HumanString(*accumulator_subshapes[i]));
153 }
154 }
155
156 return Status::OK();
157 }
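// Illustrative example (editorial sketch; the shapes below are hypothetical):
// a reducer for a single-input F32 Reduce must have program shape
// (f32[], f32[]) -> f32[]; with a scalar f32 init value,
// VerifyReducerShape(reducer, {&init}, {F32}, /*inputs=*/1) returns OK.
// A two-input variadic reduce over f32 and s32 data would instead need
// (f32[], s32[], f32[], s32[]) -> (f32[], s32[]): accumulators first, then the
// per-input element parameters, with the result tuple holding the accumulators.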
158
159 StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
160 const Window& window,
161 PrimitiveType element_type,
162 bool allow_negative_padding) {
163 if (window.dimensions_size() != base_shape.rank()) {
164 return InvalidArgument(
165         "Window has %d dimensions, but base shape has rank %d.",
166 window.dimensions_size(), base_shape.rank());
167 }
168
169 std::vector<int64> output_dimensions(window.dimensions_size());
170 std::vector<bool> output_is_dynamic(window.dimensions_size());
171 for (int64 i = 0; i < window.dimensions_size(); ++i) {
172 const auto& dim = window.dimensions(i);
173 if (dim.size() <= 0) {
174 return InvalidArgument("Window %s has a non-positive dimension.",
175 window.DebugString());
176 }
177 if (dim.stride() <= 0) {
178 return InvalidArgument("Window %s has a non-positive stride.",
179 window.DebugString());
180 }
181 if (!allow_negative_padding && dim.padding_low() < 0) {
182 return InvalidArgument("Window %s has a negative low padding.",
183 window.DebugString());
184 }
185 if (!allow_negative_padding && dim.padding_high() < 0) {
186 return InvalidArgument("Window %s has a negative high padding.",
187 window.DebugString());
188 }
189 if (dim.base_dilation() < 1) {
190 return InvalidArgument(
191 "Window %s has a non-positive base area dilation factor.",
192 window.DebugString());
193 }
194 if (dim.window_dilation() < 1) {
195 return InvalidArgument(
196 "Window %s has a non-positive window dilation factor.",
197 window.DebugString());
198 }
199
200 if (base_shape.is_dynamic_dimension(i) &&
201 !window_util::IsTrivialWindowDimension(dim)) {
202 return Unimplemented(
203           "Dynamic shape is not supported for non-trivial window: %s",
204 window_util::ToString(window));
205 }
206
207 const int64 dilated_base = window_util::DilatedBound(
208 ShapeUtil::GetDimension(base_shape, i), dim.base_dilation());
209 const int64 padded_dilated_base =
210 dim.padding_low() + dilated_base + dim.padding_high();
211 const int64 dilated_window =
212 window_util::DilatedBound(dim.size(), dim.window_dilation());
213
214 output_dimensions[i] = window_util::StridedBound(
215 padded_dilated_base, dilated_window, dim.stride());
216 output_is_dynamic[i] = base_shape.is_dynamic_dimension(i);
217 }
218
219 return ShapeUtil::MakeValidatedShape(element_type, output_dimensions,
220 output_is_dynamic);
221 }
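// Worked example of the window arithmetic above (editorial sketch with
// hypothetical numbers, assuming window_util::DilatedBound/StridedBound
// implement the usual convolution bound formulas): for a base dimension of 5
// with base_dilation=1, padding_low=padding_high=1, window size 3,
// window_dilation=2 and stride=2:
//   dilated_base        = (5 - 1) * 1 + 1 = 5
//   padded_dilated_base = 1 + 5 + 1       = 7
//   dilated_window      = (3 - 1) * 2 + 1 = 5
//   output dimension    = (7 - 5) / 2 + 1 = 2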
222
223 } // namespace
224
225 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
226 HloOpcode opcode, const HloInstruction* operand) {
227 return InferUnaryOpShape(opcode, operand->shape());
228 }
229
230 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
231 HloOpcode opcode, const Shape& shape) {
232 // There is no copy operation at the proto level, so handle copy explicitly.
233 // A domain shape is the same as the input one.
234 if (opcode == HloOpcode::kCopy || opcode == HloOpcode::kDomain) {
235 return shape;
236 }
237
238 TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation"));
239
240 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
241 switch (opcode) {
242 case HloOpcode::kFloor:
243 case HloOpcode::kCeil:
244 case HloOpcode::kRoundNearestAfz:
245 if (!ShapeUtil::ElementIsFloating(shape)) {
246 return InvalidArgument(
247 "Expected element type in shape to be floating for %s operation; "
248 "got %s.",
249 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
250 }
251 return shape;
252 case HloOpcode::kCos:
253 case HloOpcode::kSin:
254 case HloOpcode::kExp:
255 case HloOpcode::kExpm1:
256 case HloOpcode::kLog:
257 case HloOpcode::kLog1p:
258 case HloOpcode::kRsqrt:
259 case HloOpcode::kSqrt:
260 case HloOpcode::kTanh:
261 if (!ShapeUtil::ElementIsFloating(shape) &&
262 !ShapeUtil::ElementIsComplex(shape)) {
263 return InvalidArgument(
264 "Expected element type in shape to be floating or complex for %s "
265 "operation; got %s.",
266 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
267 }
268 return shape;
269 case HloOpcode::kReal:
270 case HloOpcode::kImag:
271 if (ShapeUtil::ElementIsComplex(shape)) {
272 return ShapeUtil::ComplexComponentShape(shape);
273 } else if (ShapeUtil::ElementIsFloating(shape)) {
274 return shape;
275 } else {
276 return InvalidArgument(
277 "Expected element type in shape to be floating or complex for "
278 "%s operation; got %s.",
279 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
280 }
281 case HloOpcode::kAbs:
282 if (ShapeUtil::ElementIsComplex(shape)) {
283 return ShapeUtil::ChangeElementType(
284 shape, primitive_util::ComplexComponentType(shape.element_type()));
285 } else if (ShapeUtil::ElementIsSigned(shape)) {
286 return shape;
287 } else {
288 return InvalidArgument(
289 "Expected element type in shape to be floating or complex for "
290 "%s operation; got %s.",
291 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
292 }
293 case HloOpcode::kClz:
294 if (!ShapeUtil::ElementIsIntegral(shape)) {
295 return InvalidArgument(
296 "Expected an integral element type in argument to Clz "
297 "operation; got %s.",
298 PrimitiveType_Name(shape.element_type()));
299 }
300 return shape;
301 case HloOpcode::kNegate:
302 if (!ShapeUtil::ElementIsIntegral(shape) &&
303 !ShapeUtil::ElementIsFloating(shape) &&
304 !ShapeUtil::ElementIsComplex(shape)) {
305 return InvalidArgument(
306 "Expected element type in shape to be integral, floating or "
307 "complex for %s operation; got %s.",
308 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
309 }
310 return shape;
311 case HloOpcode::kPopulationCount:
312 if (!ShapeUtil::ElementIsIntegral(shape)) {
313 return InvalidArgument(
314 "Expected an integral element type in argument to PopulationCount "
315 "operation; got %s.",
316 PrimitiveType_Name(shape.element_type()));
317 }
318 return shape;
319 case HloOpcode::kSign:
320 if (!ShapeUtil::ElementIsSigned(shape) &&
321 !ShapeUtil::ElementIsComplex(shape)) {
322 return InvalidArgument(
323 "Expected element type in shape to be signed or complex for "
324 "%s operation; got %s.",
325 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
326 }
327 return shape;
328
329 case HloOpcode::kNot:
330 if (shape.element_type() != PRED &&
331 !primitive_util::IsIntegralType(shape.element_type())) {
332 return InvalidArgument(
333 "Expected pred or an integral element type in argument to Not "
334 "operation; got %s.",
335 PrimitiveType_Name(shape.element_type()));
336 }
337 return shape;
338
339 case HloOpcode::kIsFinite:
340 if (!ShapeUtil::ElementIsFloating(shape)) {
341 return InvalidArgument(
342 "Expected element type in shape to be floating "
343 "point for IsFinite "
344 "operation; got %s.",
345 PrimitiveType_Name(shape.element_type()));
346 }
347 return ShapeUtil::ChangeElementType(shape, PRED);
348
349 default:
350 return InvalidArgument(
351 "Unknown operation for unary shape inference: \"%s\".",
352 HloOpcodeString(opcode));
353 }
354 }
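// Usage sketch (editorial; shapes are hypothetical): element types flow
// through unchanged except where noted above, e.g.
//   InferUnaryOpShape(HloOpcode::kAbs, c64[4])      => f32[4]  (|z| is real)
//   InferUnaryOpShape(HloOpcode::kIsFinite, f32[4]) => pred[4]
//   InferUnaryOpShape(HloOpcode::kNegate, s32[2,2]) => s32[2,2]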
355
356 /* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
357 absl::Span<const Shape* const> arg_shapes, const int64 dimension) {
358 if (arg_shapes.empty()) {
359 return InvalidArgument("Concatenate expects at least one argument.");
360 }
361 if (dimension < 0 || dimension >= arg_shapes[0]->rank()) {
362 return InvalidArgument("Concatenate dimension out of bounds: %d.",
363 dimension);
364 }
365 const Shape* arg_shape = nullptr;
366 PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
367 for (const Shape* shape : arg_shapes) {
368 TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation"));
369 if (!arg_shape) {
370 arg_shape = shape;
371 element_type = arg_shape->element_type();
372 continue;
373 }
374 if (arg_shape->rank() != shape->rank()) {
375 return InvalidArgument(
376 "Cannot concatenate arrays with different ranks: %d (%s) vs %d "
377 "(%s).",
378 arg_shape->rank(), ShapeUtil::HumanString(*arg_shape), shape->rank(),
379 ShapeUtil::HumanString(*shape));
380 }
381 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
382 return InvalidArgument(
383 "Cannot concatenate arrays with different element types: %s vs %s.",
384 PrimitiveType_Name(arg_shape->element_type()),
385 PrimitiveType_Name(shape->element_type()));
386 }
387 for (int64 dimension_number = 0; dimension_number < arg_shape->rank();
388 ++dimension_number) {
389 if (arg_shape->dimensions(dimension_number) !=
390 shape->dimensions(dimension_number)) {
391 if (dimension_number == dimension) {
392 continue; // It's okay to differ in the dimension we're
393 // concatenating.
394 }
395 return InvalidArgument(
396 "Cannot concatenate arrays that differ in dimensions other than "
397 "the one being concatenated (the other array dimensions must be "
398 "the same): %s vs %s in dimension %d.",
399 ShapeUtil::HumanString(*arg_shape), ShapeUtil::HumanString(*shape),
400 dimension);
401 }
402 }
403 element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
404 }
405
406 std::vector<int64> new_dimensions(arg_shape->dimensions().begin(),
407 arg_shape->dimensions().end());
408 for (size_t i = 1; i < arg_shapes.size(); ++i) {
409 new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
410 }
411
412 Shape result = ShapeUtil::MakeShape(element_type, new_dimensions);
413
414 // Set dynamic dimensions if any input has dynamic dimension.
415 for (const Shape* shape : arg_shapes) {
416 for (int64 i = 0; i < shape->dimensions_size(); ++i) {
417 if (shape->is_dynamic_dimension(i)) {
418 result.set_dynamic_dimension(i, true);
419 }
420 }
421 }
422 return result;
423 }
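// Example (editorial sketch with hypothetical shapes): concatenating f32[2,3]
// and f32[2,5] along dimension 1 yields f32[2,8]. Mixing f32[2,3] with
// bf16[2,5] is also accepted, and the result takes the higher-precision
// element type (f32).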
424
425 /* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
426 const Shape& operand_shape, PrimitiveType new_element_type) {
427 auto old_element_type = operand_shape.element_type();
428 if (primitive_util::IsComplexType(old_element_type) &&
429 !primitive_util::IsComplexType(new_element_type)) {
430 return Unimplemented(
431 "Conversion from complex to real type %s => %s is not implemented.",
432 ShapeUtil::HumanString(operand_shape),
433 PrimitiveType_Name(new_element_type));
434 }
435 if (!operand_shape.IsArray() ||
436 !primitive_util::IsArrayType(new_element_type)) {
437 // Note: we may want to support tuple conversions via this operation in the
438 // future, by recursing into the tuple elements to check all sub-conversions
439 // are valid. For now we just reject them, though.
440 return InvalidArgument(
441 "Convert does not allow non-arrays, so cannot convert from %s to %s.",
442 ShapeUtil::HumanString(operand_shape),
443 PrimitiveType_Name(new_element_type));
444 }
445
446 return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
447 }
448
449 /* static */ StatusOr<Shape> ShapeInference::InferBitcastConvertShape(
450 const Shape& operand_shape, PrimitiveType new_element_type) {
451 auto old_element_type = operand_shape.element_type();
452 if (primitive_util::IsComplexType(old_element_type) !=
453 primitive_util::IsComplexType(new_element_type)) {
454     return InvalidArgument("Cannot bitcast between real and complex types: %s => %s.",
455 ShapeUtil::HumanString(operand_shape),
456 PrimitiveType_Name(new_element_type));
457 }
458 if (!operand_shape.IsArray() ||
459 !primitive_util::IsArrayType(new_element_type)) {
460 // Note: we may want to support tuple conversions via this operation in the
461 // future, by recursing into the tuple elements to check all sub-conversions
462 // are valid. For now we just reject them, though.
463 return InvalidArgument(
464 "Cannot convert from or to tuple type; requested conversion: %s => %s.",
465 ShapeUtil::HumanString(operand_shape),
466 PrimitiveType_Name(new_element_type));
467 }
468 if (primitive_util::BitWidth(old_element_type) !=
469 primitive_util::BitWidth(new_element_type)) {
470 return InvalidArgument(
471 "Cannot bitcast types with different bit-widths: %s => %s.",
472 PrimitiveType_Name(old_element_type),
473 PrimitiveType_Name(new_element_type));
474 }
475
476 return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
477 }
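// Example (editorial sketch): BitcastConvert f32[4] => s32[4] is accepted
// (both element types are 32 bits wide), while f32[4] => f16[4] is rejected
// because the bit widths differ; Convert (above) would instead accept
// f32 => f16 as a value conversion.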
478
479 /* static */ StatusOr<Shape> ShapeInference::InferReducePrecisionShape(
480 const Shape& operand_shape, const int exponent_bits,
481 const int mantissa_bits) {
482 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
483 return InvalidArgument(
484 "Expected element type in shape to be floating point for "
485 "ReducePrecision operation; got %s.",
486 PrimitiveType_Name(operand_shape.element_type()));
487 }
488 if (exponent_bits < 1) {
489 // One exponent bit is necessary to distinguish 0 from infinity. Having
490 // no exponent bits doesn't produce a sensible number, so we require at
491 // least one.
492 return InvalidArgument("Expected exponent_bits >= 1; got %d.",
493 exponent_bits);
494 }
495 if (mantissa_bits < 0) {
496 // A number with no mantissa bits is still meaningful, however.
497 return InvalidArgument("Expected non-negative mantissa_bits; got %d.",
498 mantissa_bits);
499 }
500 return operand_shape;
501 }
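// Example (editorial sketch): ReducePrecision on an f32 operand with
// exponent_bits=5 and mantissa_bits=10 emulates f16 rounding behavior; the
// inferred shape is simply the operand shape, since only the stored values
// change, not the type or dimensions.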
502
503 /* static */ StatusOr<Shape> ShapeInference::InferPadShape(
504 const Shape& operand_shape, const Shape& padding_value_shape,
505 const PaddingConfig& padding_config) {
506 if (!operand_shape.IsArray()) {
507 return InvalidArgument(
508 "Pad operation does not support tuple-shape operands.");
509 }
510 if (!ShapeUtil::IsScalar(padding_value_shape)) {
511 return InvalidArgument(
512 "Pad operation does not support non-scalar padding values.");
513 }
514 if (operand_shape.rank() != padding_config.dimensions_size()) {
515 return InvalidArgument(
516 "The rank of the operand and the padding configuration do not match: "
517 "%s vs %s.",
518 ShapeUtil::HumanString(operand_shape),
519 padding_config.ShortDebugString());
520 }
521 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
522 padding_value_shape)) {
523 return InvalidArgument(
524 "The element types of the operands to Pad do not match.");
525 }
526 if (absl::c_any_of(padding_config.dimensions(),
527 [](const PaddingConfig::PaddingConfigDimension& p) {
528 return p.interior_padding() < 0;
529 })) {
530 return InvalidArgument("Interior padding cannot be negative: %s",
531 padding_config.ShortDebugString());
532 }
533
534 if (!padding_value_shape.is_static()) {
535 return InvalidArgument("Dynamic padding value is not supported");
536 }
537
538 std::vector<int64> dimensions(operand_shape.rank());
539 std::vector<bool> is_dynamic(operand_shape.rank());
540 for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
541 const auto& p = padding_config.dimensions(i);
542 if (operand_shape.is_dynamic_dimension(i) && p.edge_padding_high() != 0 &&
543 p.edge_padding_low() != 0 && p.interior_padding() != 0) {
544 return InvalidArgument(
545 "Dynamic dimension on padding dimension is not supported.");
546 }
547 dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() +
548 p.edge_padding_high() +
549 std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
550 p.interior_padding();
551 if (dimensions[i] < 0) {
552       return InvalidArgument("Padding results in negative size for dimension %d.",
553 i);
554 }
555 is_dynamic[i] = operand_shape.is_dynamic_dimension(i);
556 }
557
558 return ShapeUtil::MakeShape(
559 ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
560 dimensions, is_dynamic);
561 }
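// Worked example of the size formula above (editorial, hypothetical numbers):
// an operand dimension of 4 with edge_padding_low=1, edge_padding_high=2 and
// interior_padding=1 produces
//   4 + 1 + 2 + max(4 - 1, 0) * 1 = 10.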
562
563 // Current DotDimensionNumbers Requirements:
564 //
565 // Contracting Dimensions:
566 // *) Same number of contracting dimensions on both lhs and rhs.
567 // *) Contracting dimension size must be the same on both lhs and rhs.
568 //
569 // Batch Dimensions:
570 // *) Same number of batch dimensions on both lhs and rhs.
571 // *) Same batch dimension sizes on both lhs and rhs.
572 //
573
574 namespace {
575
576 Status ValidateDotDimensionNumbers(
577 const Shape& lhs, const Shape& rhs,
578 const DotDimensionNumbers& dimension_numbers) {
579 // Check that dimension numbers are in range.
580 auto dims_in_range = [](const int64 rank,
581 absl::Span<const int64> contracting_dims,
582 absl::Span<const int64> batch_dims) -> bool {
583 auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
584 return absl::c_all_of(contracting_dims, in_range) &&
585 absl::c_all_of(batch_dims, in_range);
586 };
587
588 absl::Span<const int64> lhs_contracting_dimensions =
589 AsInt64Slice(dimension_numbers.lhs_contracting_dimensions());
590 absl::Span<const int64> rhs_contracting_dimensions =
591 AsInt64Slice(dimension_numbers.rhs_contracting_dimensions());
592 absl::Span<const int64> lhs_batch_dimensions =
593 AsInt64Slice(dimension_numbers.lhs_batch_dimensions());
594 absl::Span<const int64> rhs_batch_dimensions =
595 AsInt64Slice(dimension_numbers.rhs_batch_dimensions());
596
597 if (!dims_in_range(lhs.rank(), lhs_contracting_dimensions,
598 lhs_batch_dimensions) ||
599 !dims_in_range(rhs.rank(), rhs_contracting_dimensions,
600 rhs_batch_dimensions)) {
601 return InvalidArgument("A dimension number is out of range in Dot: %s.",
602 dimension_numbers.DebugString());
603 }
604
605 // Check that dimension numbers are unique.
606 auto dims_unique = [](absl::Span<const int64> contracting_dims,
607 absl::Span<const int64> batch_dims) -> bool {
608 absl::flat_hash_set<int64> dim_set;
609 auto is_unique = [&dim_set](int64 i) -> bool {
610 return dim_set.insert(i).second;
611 };
612 return absl::c_all_of(contracting_dims, is_unique) &&
613 absl::c_all_of(batch_dims, is_unique);
614 };
615
616 if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
617 !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) {
618 return InvalidArgument("A dimension number is not unique in Dot: %s.",
619 dimension_numbers.DebugString());
620 }
621
622 return Status::OK();
623 }
624
625 } // namespace
626
627 /* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
628 const Shape& lhs, const Shape& rhs,
629 const DotDimensionNumbers& dimension_numbers) {
630 TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot"));
631 TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
632
633 auto fail = [lhs, rhs](const string& addendum) -> Status {
634 string message =
635 StrFormat("Cannot infer shape for dot operation: %s <dot> %s.",
636 ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
637 if (!addendum.empty()) {
638 message += " " + addendum;
639 }
640 return InvalidArgument("%s", message);
641 };
642
643 // Check if both element types are the same.
644 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
645 return fail("Element types do not match.");
646 }
647
648 // Validate basic properties of dot dimension numbers.
649 TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers));
650
651 // Check that number of contracting dimensions match.
652 if (dimension_numbers.lhs_contracting_dimensions_size() !=
653 dimension_numbers.rhs_contracting_dimensions_size()) {
654 return fail(
655 "Must specify the same number of contracting dimensions for lhs and "
656 "rhs.");
657 }
658 // Check that contracting dimension sizes match.
659 for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
660 ++i) {
661 const int64 lhs_contracting_dimension =
662 dimension_numbers.lhs_contracting_dimensions(i);
663 const int64 rhs_contracting_dimension =
664 dimension_numbers.rhs_contracting_dimensions(i);
665 if (lhs.dimensions(lhs_contracting_dimension) !=
666 rhs.dimensions(rhs_contracting_dimension)) {
667 return fail("Contracting dimension sizes do not match.");
668 }
669 }
670
671 // Check that number of batch dimensions match.
672 if (dimension_numbers.lhs_batch_dimensions_size() !=
673 dimension_numbers.rhs_batch_dimensions_size()) {
674     return fail("Must specify the same number of batch dimensions for lhs and rhs.");
675 }
676
677 // Check that batch dimension numbers and sizes match.
678 for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
679 if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
680 rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
681 return fail("Batch dimension sizes must match for lhs/rhs.");
682 }
683 }
684
685 // The ranks of lhs and rhs are decremented by 1 respectively due to the
686 // contraction, and added for the rank of the result. When an input tensor is
687 // a scalar, its contribution to the rank of the result is 0.
688   // Generate the result dimensions in order: batch dimensions, then remaining
689   // lhs dimensions, then remaining rhs dimensions (contracted dims excluded).
690 std::vector<int64> dimensions;
691 std::vector<bool> is_dynamic;
692 for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
693 dimensions.push_back(lhs.dimensions(lhs_dim));
694 is_dynamic.push_back(lhs.is_dynamic_dimension(lhs_dim));
695 }
696 for (int64 i = 0; i < lhs.rank(); i++) {
697 if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
698 i) &&
699 !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
700 dimensions.push_back(lhs.dimensions(i));
701 is_dynamic.push_back(lhs.is_dynamic_dimension(i));
702 }
703 }
704 for (int64 i = 0; i < rhs.rank(); i++) {
705 if (!absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
706 i) &&
707 !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
708 dimensions.push_back(rhs.dimensions(i));
709 is_dynamic.push_back(rhs.is_dynamic_dimension(i));
710 }
711 }
712 Shape result = ShapeUtil::MakeShape(
713 ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions, is_dynamic);
714
715 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
716 VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
717 return result;
718 }
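// Example (editorial sketch with hypothetical shapes): a batched matmul
//   lhs = f32[8,128,256], rhs = f32[8,256,64]
//   lhs_batch_dimensions={0},       rhs_batch_dimensions={0}
//   lhs_contracting_dimensions={2}, rhs_contracting_dimensions={1}
// infers f32[8,128,64]: batch dimensions first, then the surviving lhs
// dimension (128), then the surviving rhs dimension (64).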
719
720 /* static */ StatusOr<Shape>
721 ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
722 const Shape& lhs,
723 const Shape& rhs) {
724 TF_RET_CHECK(lhs.rank() == rhs.rank());
725
726 // The shapes have to be compatible. That is, if some dimension d has a
727 // different size in the two shapes, one of them has to be 1 (a "degenerate"
728 // dimension). In that case, the output shape has the non-1 dimension size
729 // from the lhs/rhs pair in every index.
730 std::vector<int64> output_dimensions(lhs.rank());
731 std::vector<bool> output_dimensions_is_dynamic(lhs.rank());
732 for (int64 i = 0; i < lhs.rank(); ++i) {
733 if (lhs.dimensions(i) == rhs.dimensions(i)) {
734 output_dimensions[i] = lhs.dimensions(i);
735 } else if (lhs.dimensions(i) == 1) {
736 output_dimensions[i] = rhs.dimensions(i);
737 } else if (rhs.dimensions(i) == 1) {
738 output_dimensions[i] = lhs.dimensions(i);
739 } else {
740 return InvalidArgument(
741 "Binary op %s with incompatible shapes: %s and %s.",
742 HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
743 ShapeUtil::HumanString(rhs));
744 }
745 }
746
747 // Merge dynamic dimensions from two shapes.
748 for (int64 i = 0; i < rhs.rank(); ++i) {
749 if (rhs.is_dynamic_dimension(i) || lhs.is_dynamic_dimension(i)) {
750 output_dimensions_is_dynamic[i] = true;
751 }
752 }
753
754 return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
755 output_dimensions, output_dimensions_is_dynamic);
756 }
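// Example (editorial sketch): for same-rank operands f32[2,1,4] and
// f32[2,3,1], every mismatching dimension has size 1 on one side, so the
// inferred shape is f32[2,3,4].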
757
758 /* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
759 const Shape& smaller_shape, const Shape& larger_shape,
760 absl::Span<const int64> broadcast_dimensions) {
761 if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
762 // Reject "magic" inference for binops on different shapes, requiring
763 // the user to provide an explicit broadcast dimension in this case.
764 // See b/25177275 for more details.
765 return InvalidArgument("Automatic shape inference not supported: %s and %s",
766 ShapeUtil::HumanString(smaller_shape),
767 ShapeUtil::HumanString(larger_shape));
768 } else if (broadcast_dimensions.size() != smaller_shape.rank()) {
769 return InvalidArgument(
770 "Size of broadcast_dimensions has to match lower-rank operand's "
771 "rank; "
772         "lower-rank operand's rank is %d, size of broadcast_dimensions is "
773 "%u.",
774 smaller_shape.rank(), broadcast_dimensions.size());
775 }
776
777 // broadcast_dimensions is a sequence of dimensions; its length is equal to
778 // the rank of the lower-rank operand. The lower-rank operand's dimensions
779 // have to be compatible with the higher-rank operand's dimensions at indices
780 // specified by broadcast_dimensions. Here compatible means the dimension
781 // sizes are equal or in one of the shapes the dimension size is
782 // one. Examples:
783 //
784 // smaller_shape larger_shape broadcast_dimensions output_shape
785 // [] [2, 3] {} [2, 3]
786 // [3] [4, 3] {1} [4, 3]
787 // [2, 3] [2, 3, 4] {0, 1} [2, 3, 4]
788 // [2, 1] [2, 3, 4] {0, 2} [2, 3, 1]
789 // [2, 3] [2, 1, 4] {0, 1} [2, 3, 4]
790 //
791 // The column output_shape may not be the final shape of the XLA
792 // operation. After the "InDim" broadcasting implemented in this function
793 // expands the rank, degenerate-dimension broadcasting (implemented in
794 // InferDegenerateDimensionBroadcastShape) broadcasts dimensions of size one
795 // up to match the dimension size of the other operand. For example, consider
796 // the row in the table above with a smaller_shape of [2, 1]. The shape
797 // returned by this function is [2, 3, 1] (output_shape) however, the result
798 // shape of the XLA operation is [2, 3, 4] after degenerate-dimension
799 // broadcasting.
800 //
801 // Invalid broadcasts:
802 //
803 // smaller_shape=[3], larger_shape=[4, 3], broadcast_dimensions={0}
804 // Reason: Dimension zero** of larger_shape (size 4) is not compatible with
805   // dimension zero of smaller_shape (size 3). **Zero here comes from the value
806 // in broadcast_dimensions.
807 //
808 // smaller_shape=[2, 1], larger_shape=[2, 3, 4], broadcast_dimensions={1, 2}
809 // Reason: Dimension one of larger_shape (size 3) is not compatible with
810   // dimension zero of smaller_shape (size 2).
811
812 // The output shape is initially the larger_shape. Sizes of dimensions
813 // specified in broadcast_dimensions are then changed to match the
814 // corresponding dimension size in smaller_shape.
815 Shape output_shape(larger_shape);
816 output_shape.set_element_type(
817 ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape));
818
819 for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
820 int64 dimension_to_match = broadcast_dimensions.at(i);
821 if (dimension_to_match < 0) {
822 return InvalidArgument(
823 "Broadcast dimension number (%d) cannot be negative.",
824 dimension_to_match);
825 }
826 if (dimension_to_match >= larger_shape.dimensions_size()) {
827 return InvalidArgument(
828 "Broadcast dimension number (%d) too large; higher-rank "
829 "operand has rank %d.",
830 dimension_to_match, larger_shape.dimensions_size());
831 }
832 int64 small_dimension_size = smaller_shape.dimensions(i);
833 int64 large_dimension_size = larger_shape.dimensions(dimension_to_match);
834 bool small_is_dynamic = smaller_shape.is_dynamic_dimension(i);
835 bool large_is_dynamic =
836 larger_shape.is_dynamic_dimension(dimension_to_match);
837 // Dimension sizes must be compatible: match or be degenerate (degenerate
838 // case is handled by degenerate dimension broadcasting which occurs after
839 // InDim broadcasting).
840 if (small_dimension_size != large_dimension_size &&
841 small_dimension_size != 1 && large_dimension_size != 1) {
842 return InvalidArgument(
843 "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i,
844 small_dimension_size, large_dimension_size,
845 ShapeUtil::HumanString(smaller_shape),
846 ShapeUtil::HumanString(larger_shape));
847 }
848 if (small_is_dynamic != large_is_dynamic) {
849 if (small_dimension_size == large_dimension_size ||
850 (small_dimension_size == 1 && !small_is_dynamic) ||
851 (large_dimension_size == 1 && !large_is_dynamic)) {
852 // Do nothing. It's OK when the size-1 dimension is not static.
853 } else {
854 return InvalidArgument(
855 "Broadcast dimension %d dynamism mismatch: %s and %s.", i,
856 ShapeUtil::HumanString(smaller_shape),
857 ShapeUtil::HumanString(larger_shape));
858 }
859 }
860 // Make sure the broadcast dimensions are listed in a strictly increasing
861 // order.
862 if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) {
863 return InvalidArgument(
864 "Broadcast dimensions order is wrong: %d comes after %d.",
865 dimension_to_match, broadcast_dimensions.at(i - 1));
866 }
867
868 output_shape.set_dimensions(dimension_to_match, small_dimension_size);
869 output_shape.set_dynamic_dimension(dimension_to_match, small_is_dynamic);
870 }
871
872 return output_shape;
873 }
874
875 /* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
876 HloOpcode operation, const Shape& lhs, const Shape& rhs,
877 absl::Span<const int64> broadcast_dimensions) {
878 TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
879 TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
880
881 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
882 return InvalidArgument(
883 "Binary op %s with different element types: %s and %s.",
884 HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
885 ShapeUtil::HumanString(rhs));
886 }
887
888 if (lhs.rank() == rhs.rank()) {
889 std::vector<int64> identity_dims(lhs.rank());
890 std::iota(identity_dims.begin(), identity_dims.end(), 0);
891 if (!broadcast_dimensions.empty() &&
892 broadcast_dimensions != identity_dims) {
893 return InvalidArgument(
894 "Broadcast dimensions field must either be not set or be the "
895 "identity on binary operations with operands of the same rank.");
896 }
897 }
898
899 if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
900 // If the shapes are the same other than layout, the output shape is the
901 // same (elementwise op).
902 Shape result = ShapeUtil::ChangeElementType(
903 lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
904
905 for (int64 i = 0; i < rhs.rank(); ++i) {
906 if (rhs.is_dynamic_dimension(i)) {
907 result.set_dynamic_dimension(i, true);
908 }
909 }
910
911 return result;
912
913 } else if (lhs.rank() == rhs.rank()) {
914 return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
915 } else {
916 // Ranks do not match, so perform InDim broadcasting using
917 // broadcast_dimensions. Scalar broadcasting is a special case of this.
918 const Shape& larger_shape = lhs.rank() > rhs.rank() ? lhs : rhs;
919 const Shape& smaller_shape = lhs.rank() > rhs.rank() ? rhs : lhs;
920
921 // After InDim broadcasting, perform degenerate dimensions broadcasting.
922 TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape,
923 InferInDimBroadcastShape(smaller_shape, larger_shape,
924 broadcast_dimensions));
925
926 return InferDegenerateDimensionBroadcastShape(
927 operation, indim_broadcast_shape, larger_shape);
928 }
929 }
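// Example of the two-step broadcast above (editorial sketch): adding f32[3]
// to f32[2,3] with broadcast_dimensions={1} maps dimension 0 of the smaller
// operand onto dimension 1 of the larger one, so InDim broadcasting yields
// [2,3] and the subsequent degenerate-dimension step is a no-op; the result
// is f32[2,3].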
930
931 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
932 HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) {
933 return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(),
934 /*broadcast_dimensions=*/{});
935 }
936
937 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
938 HloOpcode opcode, const Shape& lhs, const Shape& rhs,
939 absl::Span<const int64> broadcast_dimensions) {
940 VLOG(2) << StrFormat(
941 "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
942 HloOpcodeString(opcode), ShapeUtil::HumanStringWithLayout(lhs),
943 ShapeUtil::HumanStringWithLayout(rhs),
944 StrJoin(broadcast_dimensions, ", "));
945
946 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
947 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
948
949 TF_RETURN_IF_ERROR(ExpectArray(
950 lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode))));
951 TF_RETURN_IF_ERROR(ExpectArray(
952 rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode))));
953 switch (opcode) {
954 case HloOpcode::kMaximum:
955 case HloOpcode::kMinimum:
956 return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
957 broadcast_dimensions);
958
959 case HloOpcode::kSubtract:
960 case HloOpcode::kAdd:
961 case HloOpcode::kAtan2:
962 case HloOpcode::kPower:
963 case HloOpcode::kDivide:
964 case HloOpcode::kRemainder:
965 case HloOpcode::kMultiply:
966 case HloOpcode::kShiftLeft:
967 case HloOpcode::kShiftRightArithmetic:
968 case HloOpcode::kShiftRightLogical:
969 if (lhs.element_type() == PRED || rhs.element_type() == PRED) {
970 return InvalidArgument(
971 "Expected element type in shape to be arithmetic type for "
972 "operation %s; got PRED.",
973 HloOpcodeString(opcode));
974 }
975 return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
976 broadcast_dimensions);
977
978 case HloOpcode::kComplex: {
979 if (!ShapeUtil::ElementIsFloating(lhs)) {
980 return InvalidArgument(
981 "Expected element type in shape to be floating for complex compose "
982 "operation; got %s.",
983 PrimitiveType_Name(lhs.element_type()));
984 }
985 TF_ASSIGN_OR_RETURN(const Shape& shape,
986 InferElementwiseBinaryOpShape(opcode, lhs, rhs,
987 broadcast_dimensions));
988 if (lhs.element_type() == F32 && rhs.element_type() == F32) {
989 return ShapeUtil::ChangeElementType(shape, C64);
990 } else if (lhs.element_type() == F64 && rhs.element_type() == F64) {
991 return ShapeUtil::ChangeElementType(shape, C128);
992 } else {
993 return Unimplemented("Complex component type is not implemented.");
994 }
995 }
996 case HloOpcode::kAnd:
997 case HloOpcode::kOr:
998 case HloOpcode::kXor:
999 if (lhs.element_type() != PRED &&
1000 !primitive_util::IsIntegralType(lhs.element_type())) {
1001 return InvalidArgument(
1002             "Expected pred or integral type in argument to and/or/xor operation; "
1003 "got %s.",
1004 PrimitiveType_Name(lhs.element_type()));
1005 }
1006 return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
1007 broadcast_dimensions);
1008 case HloOpcode::kCompare: {
1009 TF_ASSIGN_OR_RETURN(const Shape& shape,
1010 InferElementwiseBinaryOpShape(opcode, lhs, rhs,
1011 broadcast_dimensions));
1012 return ShapeUtil::ChangeElementType(shape, PRED);
1013 }
1014 default:
1015 return Unimplemented(
1016 "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.",
1017 HloOpcodeString(opcode), lhs.ShortDebugString(),
1018 rhs.ShortDebugString());
1019 }
1020 }
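// Example (editorial sketch): kCompare on f32[4] operands infers pred[4];
// kComplex on f32[4] operands infers c64[4]; kAdd on pred operands is
// rejected by the arithmetic-type check above.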
1021
1022 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1023 HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs,
1024 const HloInstruction* ehs) {
1025 return InferTernaryOpShape(opcode, lhs->shape(), rhs->shape(), ehs->shape());
1026 }
1027
1028 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1029 HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) {
1030 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1031 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1032 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs));
1033 switch (opcode) {
1034 case HloOpcode::kClamp:
1035 return InferClampShape(lhs, rhs, ehs);
1036 case HloOpcode::kSelect:
1037 return InferSelectShape(lhs, rhs, ehs);
1038 case HloOpcode::kTupleSelect:
1039 return InferTupleSelectShape(lhs, rhs, ehs);
1040 default:
1041 return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1042 }
1043 }
1044
1045 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1046 HloOpcode opcode, absl::Span<const HloInstruction* const> operands) {
1047 std::vector<const Shape*> operand_shapes;
1048 operand_shapes.reserve(operands.size());
1049 for (const HloInstruction* operand : operands) {
1050 operand_shapes.push_back(&operand->shape());
1051 }
1052 return InferVariadicOpShape(opcode, operand_shapes);
1053 }
1054
1055 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1056 HloOpcode opcode, absl::Span<const Shape* const> operand_shapes) {
1057 for (const Shape* shape : operand_shapes) {
1058 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
1059 }
1060 switch (opcode) {
1061 case HloOpcode::kTuple: {
1062 Shape result = ShapeUtil::MakeTupleShape({});
1063 result.mutable_tuple_shapes()->reserve(operand_shapes.size());
1064 for (const Shape* shape : operand_shapes) {
1065 ShapeUtil::AppendShapeToTuple(*shape, &result);
1066 }
1067 return result;
1068 }
1069 case HloOpcode::kSort: {
1070 if (operand_shapes.size() == 1) {
1071 return *operand_shapes[0];
1072 } else {
1073 for (int64 operand = 1; operand < operand_shapes.size(); ++operand) {
1074 if (!ShapeUtil::SameDimensions(*operand_shapes[0],
1075 *operand_shapes[operand])) {
1076 return InvalidArgument(
1077 "Sort keys and values dimensions must match. "
1078                 "Keys shape is: %s\nValues shape (operand index %lld) is: %s",
1079 ShapeUtil::HumanString(*operand_shapes[0]), operand,
1080 ShapeUtil::HumanString(*operand_shapes[operand]));
1081 }
1082 }
1083 std::vector<Shape> operand_shape_values;
1084 for (const Shape* operand_shape : operand_shapes) {
1085 operand_shape_values.push_back(*operand_shape);
1086 }
1087 return ShapeUtil::MakeTupleShape(operand_shape_values);
1088 }
1089 return InvalidArgument("Unexpected number of operands for sort");
1090 }
1091 default:
1092 return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1093 }
1094 }
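// Example (editorial sketch): a two-operand kSort with keys f32[10] and
// values s32[10] infers the tuple (f32[10], s32[10]); a single-operand sort
// simply returns the operand shape.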
1095
1096 /* static */ StatusOr<Shape> ShapeInference::InferMapShape(
1097 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
1098 absl::Span<const int64> dimensions) {
1099 if (arg_shapes.empty()) {
1100 return InvalidArgument("Map expects at least one argument.");
1101 }
1102
1103 // All arguments must have the same shape.
1104 const Shape* arg_shape = arg_shapes[0];
1105 for (size_t i = 1; i < arg_shapes.size(); ++i) {
1106 TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map"));
1107
1108 if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
1109 continue;
1110 }
1111 if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
1112 *arg_shape)) {
1113 if (ShapeUtil::IsScalar(*arg_shapes[i])) {
1114 continue;
1115 }
1116 if (ShapeUtil::IsScalar(*arg_shape)) {
1117 arg_shape = arg_shapes[i];
1118 continue;
1119 }
1120 }
1121
1122 std::vector<string> pieces;
1123 for (const Shape* shape : arg_shapes) {
1124 pieces.push_back(ShapeUtil::HumanString(*shape));
1125 }
1126 return InvalidArgument(
1127 "Map operation requires all operands to have the same shape; got: "
1128 "%s.",
1129 StrJoin(pieces, ", "));
1130 }
1131
1132 // Check that dimensions.size == arg_shape.dimensions_size() (we currently
1133 // only support mapping across all dimensions: i.e. scalar map functions).
1134 if (dimensions.size() != arg_shape->dimensions_size()) {
1135 return InvalidArgument(
1136 "Map applied to a subset of dimensions currently not supported: "
1137 "arg_dimension_size: %d, requested_map_dimensions_size: %u.",
1138 arg_shape->dimensions_size(), dimensions.size());
1139 }
1140
1141 // Check that requested map dimensions numbers are monotonically increasing.
1142 for (int i = 0; i < dimensions.size(); ++i) {
1143 if (dimensions[i] != i) {
1144 return InvalidArgument(
1145 "Map requires monotonically increasing dimension numbers; got: %s.",
1146 StrJoin(dimensions, ", "));
1147 }
1148 }
1149
1150 // The applied function's arity equals the number of arguments.
1151 if (arg_shapes.size() != to_apply.parameters_size()) {
1152 return InvalidArgument(
1153 "Map applied function arity must match number of arguments; got: "
1154 "arity: %d, arguments: %u.",
1155 to_apply.parameters_size(), arg_shapes.size());
1156 }
1157
1158 // The parameters should all be scalars, and the output too.
1159 const Shape& output_shape = to_apply.result();
1160 if (!ShapeUtil::IsScalar(output_shape)) {
1161 return InvalidArgument(
1162 "Mapped computation's result has to be a scalar; got: %s.",
1163 ShapeUtil::HumanString(output_shape));
1164 }
1165
1166 for (int i = 0; i < to_apply.parameters_size(); ++i) {
1167 const Shape& parameter_shape = to_apply.parameters(i);
1168
1169 if (!ShapeUtil::IsScalar(parameter_shape)) {
1170 return InvalidArgument(
1171 "Mapped computation's parameter has to be a scalar; "
1172 "got parameter %d shape: %s.",
1173 i, ShapeUtil::HumanString(parameter_shape));
1174 }
1175
1176 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
1177 *arg_shape)) {
1178 return InvalidArgument(
1179 "Mapped computation's parameter type has to match argument element "
1180 "type; got parameter %d shape: %s, argument shape: %s.",
1181 i, ShapeUtil::HumanString(parameter_shape),
1182 ShapeUtil::HumanString(*arg_shape));
1183 }
1184 }
1185
1186 return ShapeUtil::MakeShape(output_shape.element_type(),
1187 AsInt64Slice(arg_shape->dimensions()));
1188 }
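// Example (editorial sketch with hypothetical shapes): mapping a scalar
// computation (f32[], f32[]) -> f32[] over two f32[10] operands with
// dimensions={0} infers f32[10]; the applied computation's arity must equal
// the number of map operands, and all parameters and the result must be
// scalars.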
1189
1190 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormTrainingShape(
1191 const Shape& operand_shape, const Shape& scale_shape,
1192 const Shape& offset_shape, int64 feature_index) {
1193 TF_RETURN_IF_ERROR(
1194 ExpectArray(operand_shape, "operand of batch norm training"));
1195 TF_RETURN_IF_ERROR(
1196 ExpectArray(offset_shape, "offset input of batch norm training"));
1197 TF_RETURN_IF_ERROR(
1198 ExpectArray(scale_shape, "scale input of batch norm training"));
1199
1200 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
1201 Status::OK());
1202 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) ==
1203 Status::OK());
1204 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
1205 Status::OK());
1206
1207 if (feature_index >= operand_shape.rank()) {
1208 return InvalidArgument(
1209 "Expected feature_index of batch-norm-training to be "
1210 "smaller than the rank of operand_shape; "
1211 "got feature_index %d, and rank %d.",
1212 feature_index, operand_shape.rank());
1213 }
1214
1215 if (feature_index < 0) {
1216 return InvalidArgument(
1217 "Expected feature_index of batch-norm-training to "
1218 "be a non-negative number, got %d.",
1219 feature_index);
1220 }
1221
1222 if (operand_shape.rank() < 1) {
1223 return InvalidArgument(
1224 "Expected the rank of operand to "
1225 "batch-norm-training to be at least 1; got %d.",
1226 operand_shape.rank());
1227 }
1228
1229 if (offset_shape.rank() != 1) {
1230 return InvalidArgument(
1231 "Offset input of batch-norm-training must have"
1232 " rank 1, but has rank %d.",
1233 offset_shape.rank());
1234 }
1235
1236 if (scale_shape.rank() != 1) {
1237 return InvalidArgument(
1238 "Scale input of batch-norm-training must have"
1239 " rank 1, but has rank %d.",
1240 scale_shape.rank());
1241 }
1242
1243 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1244 return InvalidArgument(
1245 "The operand to batch-norm-training must have a floating point "
1246 "element type, but the shape is %s.",
1247 PrimitiveType_Name(operand_shape.element_type()));
1248 }
1249
1250 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1251 operand_shape)) {
1252 return InvalidArgument(
1253 "The inputs should have the same element type for batch-norm-training, "
1254 "but the shape of offset factor is %s "
1255 "and the shape of operand is %s.",
1256 PrimitiveType_Name(offset_shape.element_type()),
1257 PrimitiveType_Name(operand_shape.element_type()));
1258 }
1259
1260 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1261 operand_shape)) {
1262 return InvalidArgument(
1263 "The inputs should have the same element type for batch-norm-training, "
1264 "but the shape of scale factor is %s "
1265 "and the shape of operand is %s.",
1266 PrimitiveType_Name(scale_shape.element_type()),
1267 PrimitiveType_Name(operand_shape.element_type()));
1268 }
1269
1270 const int64 feature_count = operand_shape.dimensions(feature_index);
1271 Shape output_shape_for_mean_and_var =
1272 ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1273
1274 if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1275 return InvalidArgument(
1276         "The size of offset factor should be the same as feature count, "
1277 "but the size of offset factor is %d "
1278 "and the feature count is %d.",
1279 ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1280 }
1281
1282 if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1283 return InvalidArgument(
1284         "The size of scale factor should be the same as feature count, "
1285 "but the size of scale factor is %d "
1286 "and the feature count is %d.",
1287 ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1288 }
1289
1290 return ShapeUtil::MakeTupleShape({operand_shape,
1291 output_shape_for_mean_and_var,
1292 output_shape_for_mean_and_var});
1293 }
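// Example (editorial sketch): for an NHWC operand f32[8,32,32,64] with
// feature_index=3 and scale/offset of shape f32[64], the inferred shape is
// the tuple (f32[8,32,32,64], f32[64], f32[64]): the normalized output, the
// batch mean, and the batch variance.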
1294
1295 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
1296 const Shape& operand_shape, const Shape& scale_shape,
1297 const Shape& offset_shape, const Shape& mean_shape,
1298 const Shape& variance_shape, int64 feature_index) {
1299 TF_RETURN_IF_ERROR(
1300 ExpectArray(operand_shape, "operand of batch norm inference"));
1301 TF_RETURN_IF_ERROR(
1302 ExpectArray(offset_shape, "offset input of batch norm inference"));
1303 TF_RETURN_IF_ERROR(
1304 ExpectArray(scale_shape, "scale input of batch norm inference"));
1305
1306 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1307 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape));
1308 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1309 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1310 TF_RETURN_IF_ERROR(
1311 ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape));
1312
1313 if (feature_index >= operand_shape.rank()) {
1314 return InvalidArgument(
1315 "Expected feature_index of batch-norm-inference to be "
1316 "smaller than the rank of operand_shape; "
1317 "got feature_index %d, and rank %d.",
1318 feature_index, operand_shape.rank());
1319 }
1320
1321 if (feature_index < 0) {
1322 return InvalidArgument(
1323 "Expected feature_index of batch-norm-inference to "
1324 "be a non-negative number, got %d.",
1325 feature_index);
1326 }
1327
1328 if (operand_shape.rank() < 1) {
1329 return InvalidArgument(
1330 "Expected the rank of operand to "
1331 "batch-norm-inference to be at least 1; got %d.",
1332 operand_shape.rank());
1333 }
1334
1335 if (offset_shape.rank() != 1) {
1336 return InvalidArgument(
1337 "Offset input of batch-norm-inference must have"
1338 " rank 1, but has rank %d.",
1339 offset_shape.rank());
1340 }
1341
1342 if (scale_shape.rank() != 1) {
1343 return InvalidArgument(
1344 "Scale input of batch-norm-inference must have"
1345 " rank 1, but has rank %d.",
1346 scale_shape.rank());
1347 }
1348
1349 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1350 return InvalidArgument(
1351 "The operand to batch-norm-inference must have a floating point "
1352 "element type, but the shape is %s.",
1353 PrimitiveType_Name(operand_shape.element_type()));
1354 }
1355
1356 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1357 operand_shape)) {
1358 return InvalidArgument(
1359 "The inputs should have the same element type for "
1360 "batch-norm-inference, "
1361 "but the shape of offset factor is %s "
1362 "and the shape of operand is %s.",
1363 PrimitiveType_Name(offset_shape.element_type()),
1364 PrimitiveType_Name(operand_shape.element_type()));
1365 }
1366
1367 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1368 operand_shape)) {
1369 return InvalidArgument(
1370 "The inputs should have the same element type for "
1371 "batch-norm-inference, "
1372 "but the shape of scale factor is %s "
1373 "and the shape of operand is %s.",
1374 PrimitiveType_Name(scale_shape.element_type()),
1375 PrimitiveType_Name(operand_shape.element_type()));
1376 }
1377
1378 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1379 operand_shape)) {
1380 return InvalidArgument(
1381 "The inputs should have the same element type for "
1382 "batch-norm-inference, "
1383 "but the shape of mean is %s "
1384 "and the shape of operand is %s.",
1385 PrimitiveType_Name(mean_shape.element_type()),
1386 PrimitiveType_Name(operand_shape.element_type()));
1387 }
1388
1389 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
1390 operand_shape)) {
1391 return InvalidArgument(
1392 "The inputs should have the same element type for "
1393 "batch-norm-inference, "
1394 "but the shape of variance is %s "
1395 "and the shape of operand is %s.",
1396         PrimitiveType_Name(variance_shape.element_type()),
1397         PrimitiveType_Name(operand_shape.element_type()));
1398 }
1399
1400 const int64 feature_count = operand_shape.dimensions(feature_index);
1401 Shape output_shape_for_mean_and_var =
1402 ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1403
1404 if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1405 return InvalidArgument(
1406         "The size of offset factor should be the same as feature count, "
1407 "but the size of offset factor is %d "
1408 "and the feature count is %d.",
1409 ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1410 }
1411
1412 if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1413 return InvalidArgument(
1414         "The size of scale factor should be the same as feature count, "
1415 "but the size of scale factor is %d "
1416 "and the feature count is %d.",
1417 ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1418 }
1419
1420 if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1421 return InvalidArgument(
1422         "The size of mean should be the same as feature count, "
1423 "but the size of mean is %d "
1424 "and the feature count is %d.",
1425 ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1426 }
1427
1428 if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
1429 return InvalidArgument(
1430         "The size of variance should be the same as feature count, "
1431 "but the size of variance is %d "
1432 "and the feature count is %d.",
1433 ShapeUtil::GetDimension(variance_shape, 0), feature_count);
1434 }
1435
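  // For intuition only (hypothetical shapes): an NCHW operand f32[16,32,8,8]
  // with feature_index=1 has feature_count 32, so scale, offset, mean, and
  // variance must each be f32[32]; the inferred result is the operand shape
  // itself, f32[16,32,8,8].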
1436 return operand_shape;
1437 }
1438
1439 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
1440 const Shape& operand_shape, const Shape& scale_shape,
1441 const Shape& mean_shape, const Shape& var_shape,
1442 const Shape& output_grad_shape, int64 feature_index) {
1443 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad"));
1444 TF_RETURN_IF_ERROR(
1445 ExpectArray(scale_shape, "scale input of batch norm grad"));
1446 TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad"));
1447 TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad"));
1448 TF_RETURN_IF_ERROR(
1449 ExpectArray(output_grad_shape, "output_grad input of batch norm grad"));
1450
1451 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1452 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1453 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1454 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(var_shape));
1455 TF_RETURN_IF_ERROR(
1456 ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape));
1457
1458 if (feature_index >= operand_shape.rank()) {
1459 return InvalidArgument(
1460 "Expected feature_index of batch-norm-grad to be "
1461 "smaller than the rank of operand_shape; "
1462 "got feature_index %d, and rank %d.",
1463 feature_index, operand_shape.rank());
1464 }
1465
1466 if (operand_shape.rank() != output_grad_shape.rank()) {
1467 return InvalidArgument(
1468 "Expected operand_shape of batch-norm-grad to have the same rank as"
1469         " output_grad_shape; got rank(operand_shape) %d, and"
1470 " rank(output_grad_shape) %d.",
1471 operand_shape.rank(), output_grad_shape.rank());
1472 }
1473
1474 if (mean_shape.rank() != 1) {
1475 return InvalidArgument(
1476 "Mean input of batch-norm-grad must have"
1477 " rank 1, but has rank %d.",
1478 mean_shape.rank());
1479 }
1480
1481 if (scale_shape.rank() != 1) {
1482 return InvalidArgument(
1483 "Scale input of batch-norm-grad must have"
1484 " rank 1, but has rank %d.",
1485 scale_shape.rank());
1486 }
1487
1488 if (var_shape.rank() != 1) {
1489 return InvalidArgument(
1490 "Var input of batch-norm-grad must have"
1491 " rank 1, but has rank %d.",
1492 var_shape.rank());
1493 }
1494
1495 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1496 return InvalidArgument(
1497 "The operand to batch-norm-grad must have a floating point "
1498 "element type, but the shape is %s.",
1499 PrimitiveType_Name(operand_shape.element_type()));
1500 }
1501
1502 if (!ShapeUtil::ElementIsFloating(output_grad_shape)) {
1503 return InvalidArgument(
1504 "The output_grad to batch-norm-grad must have a floating point "
1505 "element type, but the shape is %s.",
1506 PrimitiveType_Name(output_grad_shape.element_type()));
1507 }
1508
1509 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
1510 operand_shape)) {
1511 return InvalidArgument(
1512 "The inputs should have the same element type for batch-norm-grad, "
1513 "but the element type of output_grad is %s "
1514 "and the element type of operand is %s.",
1515 PrimitiveType_Name(output_grad_shape.element_type()),
1516 PrimitiveType_Name(operand_shape.element_type()));
1517 }
1518
1519 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1520 operand_shape)) {
1521 return InvalidArgument(
1522 "The inputs should have the same element type for batch-norm-grad, "
1523 "but the element type of scale factor is %s "
1524 "and the element type of operand is %s.",
1525 PrimitiveType_Name(scale_shape.element_type()),
1526 PrimitiveType_Name(operand_shape.element_type()));
1527 }
1528
1529 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1530 operand_shape)) {
1531 return InvalidArgument(
1532 "The inputs should have the same element type for batch-norm-grad, "
1533 "but the element type of mean is %s "
1534 "and the element type of operand is %s.",
1535 PrimitiveType_Name(mean_shape.element_type()),
1536 PrimitiveType_Name(operand_shape.element_type()));
1537 }
1538
1539 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
1540 operand_shape)) {
1541 return InvalidArgument(
1542 "The inputs should have the same element type for batch-norm-grad, "
1543         "but the element type of variance is %s "
1544         "and the element type of operand is %s.",
1545         PrimitiveType_Name(var_shape.element_type()),
1546 PrimitiveType_Name(operand_shape.element_type()));
1547 }
1548
1549 const int64 feature_count = operand_shape.dimensions(feature_index);
1550
1551 Shape feature_shape =
1552 ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1553
1554 if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1555 return InvalidArgument(
1556         "The size of mean should be the same as feature count, "
1557         "but the size of mean is %d "
1558 "and the feature count is %d.",
1559 ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1560 }
1561
1562 if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1563 return InvalidArgument(
1564         "The size of scale factor should be the same as feature count, "
1565 "but the size of scale factor is %d "
1566 "and the feature count is %d.",
1567 ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1568 }
1569
1570 if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) {
1571 return InvalidArgument(
1572         "The size of variance should be the same as feature count, "
1573 "but the size of variance is %d "
1574 "and the feature count is %d.",
1575 ShapeUtil::GetDimension(var_shape, 0), feature_count);
1576 }
1577
1578 // Verify operand_shape and output_grad_shape have same bounds.
1579 for (int64 i = 0; i < operand_shape.rank(); ++i) {
1580 if (ShapeUtil::GetDimension(operand_shape, i) !=
1581 ShapeUtil::GetDimension(output_grad_shape, i)) {
1582 return InvalidArgument(
1583           "The bounds of operand shape should be the same as output_grad's, "
1584 "but the bound of operand_shape at dimension %d is %d "
1585 "and the bound of output_grad_shape is %d.",
1586 i, ShapeUtil::GetDimension(operand_shape, i),
1587 ShapeUtil::GetDimension(output_grad_shape, i));
1588 }
1589 }
1590
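  // For intuition only (hypothetical shapes): with operand f32[16,32,8,8] and
  // feature_index=1, feature_shape is f32[32], so the inferred result is the
  // tuple (f32[16,32,8,8], f32[32], f32[32]), i.e. the gradients with respect
  // to the operand, scale, and offset.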
1591 return ShapeUtil::MakeTupleShape(
1592 {operand_shape, feature_shape, feature_shape});
1593 }
1594
1595 /* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
1596 const Shape& lhs, const Shape& rhs, int64 feature_group_count,
1597 int64 batch_group_count, const Window& window,
1598 const ConvolutionDimensionNumbers& dnums) {
1599 TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
1600 TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
1601
1602 if (feature_group_count <= 0) {
1603 return InvalidArgument(
1604 "feature_group_count must be a positive number, got %d",
1605 feature_group_count);
1606 }
1607
1608 if (batch_group_count <= 0) {
1609 return InvalidArgument(
1610 "batch_group_count must be a positive number, got %d",
1611 batch_group_count);
1612 }
1613
1614 if (batch_group_count > 1 && feature_group_count > 1) {
1615 return InvalidArgument(
1616         "batch_group_count %d and feature_group_count %d cannot both be "
1617         "greater than 1.",
1618 batch_group_count, feature_group_count);
1619 }
1620
1621 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
1622 return InvalidArgument(
1623 "Convolution with different element types: %s and %s.",
1624 ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
1625 }
1626 if (dnums.input_spatial_dimensions_size() !=
1627 dnums.kernel_spatial_dimensions_size()) {
1628 return InvalidArgument(
1629         "Both arguments to convolution must have the same number of dimensions.\n"
1630 "Numbers: %s",
1631 dnums.DebugString());
1632 }
1633
1634 if (dnums.input_spatial_dimensions_size() !=
1635 dnums.output_spatial_dimensions_size()) {
1636 return InvalidArgument(
1637         "Both input and output of convolution must have the same number of "
1638 "dimensions.\nNumbers: %s",
1639 dnums.DebugString());
1640 }
1641
1642 const int num_spatial_dims = dnums.input_spatial_dimensions_size();
1643 if (window.dimensions_size() != num_spatial_dims) {
1644 return InvalidArgument(
1645 "Window must have same number of dimensions as dimension numbers.\n"
1646 "Window: %s\nDimension numbers: %s.",
1647 window.DebugString(), dnums.DebugString());
1648 }
1649
1650 const int num_dims = num_spatial_dims + 2;
1651 if (lhs.rank() != num_dims) {
1652 return InvalidArgument(
1653 "The LHS argument to a convolution should have rank %d; lhs: %s.",
1654 num_dims, ShapeUtil::HumanString(lhs));
1655 }
1656 if (rhs.rank() != num_dims) {
1657 return InvalidArgument(
1658 "The RHS argument to a convolution should have rank %d; rhs: %s.",
1659 num_dims, ShapeUtil::HumanString(rhs));
1660 }
1661 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1662 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1663
1664 // Verifies that the input and window dimensions are a permutation of
1665 // the dimension numbers.
1666 std::vector<int64> input_dnums(num_dims);
1667 input_dnums[0] = dnums.input_batch_dimension();
1668 input_dnums[1] = dnums.input_feature_dimension();
1669 std::copy(dnums.input_spatial_dimensions().begin(),
1670 dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
1671 absl::c_sort(input_dnums);
1672
1673 std::vector<int64> window_dnums(num_dims);
1674 window_dnums[0] = dnums.kernel_input_feature_dimension();
1675 window_dnums[1] = dnums.kernel_output_feature_dimension();
1676 std::copy(dnums.kernel_spatial_dimensions().begin(),
1677 dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
1678 absl::c_sort(window_dnums);
1679
1680 std::vector<int64> output_dnums(num_dims);
1681 output_dnums[0] = dnums.output_batch_dimension();
1682 output_dnums[1] = dnums.output_feature_dimension();
1683 std::copy(dnums.output_spatial_dimensions().begin(),
1684 dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
1685 absl::c_sort(output_dnums);
1686
1687 std::vector<int64> expected_dnums(num_dims);
1688 std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
1689
1690 const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
1691 if (!absl::c_all_of(input_dnums, in_range) ||
1692 !absl::c_all_of(window_dnums, in_range) ||
1693 !absl::c_all_of(output_dnums, in_range)) {
1694 return InvalidArgument(
1695 "A dimension number is out of range in convolution: %s.",
1696 dnums.DebugString());
1697 }
1698
1699 if (input_dnums != expected_dnums) {
1700 return InvalidArgument(
1701 "Input dimensions of convolution must contain each dimension exactly "
1702 "once: %s.",
1703 dnums.DebugString());
1704 }
1705 if (window_dnums != expected_dnums) {
1706 return InvalidArgument(
1707 "Window dimensions of convolution must contain each dimension exactly "
1708 "once: %s.",
1709 dnums.DebugString());
1710 }
1711 if (output_dnums != expected_dnums) {
1712 return InvalidArgument(
1713 "Output dimensions of convolution must contain each dimension exactly "
1714 "once: %s.",
1715 dnums.DebugString());
1716 }
1717
1718 std::vector<int64> input_spatial_dims(num_spatial_dims);
1719 for (int i = 0; i < num_spatial_dims; ++i) {
1720 input_spatial_dims[i] = lhs.dimensions(dnums.input_spatial_dimensions(i));
1721 }
1722 const int64 input_features = lhs.dimensions(dnums.input_feature_dimension());
1723 const int64 input_batch = lhs.dimensions(dnums.input_batch_dimension());
1724
1725 std::vector<int64> kernel_spatial_dims(num_spatial_dims);
1726 for (int i = 0; i < num_spatial_dims; ++i) {
1727 kernel_spatial_dims[i] = rhs.dimensions(dnums.kernel_spatial_dimensions(i));
1728 }
1729 const int64 kernel_input_features =
1730 rhs.dimensions(dnums.kernel_input_feature_dimension());
1731 const int64 kernel_output_features =
1732 rhs.dimensions(dnums.kernel_output_feature_dimension());
1733
1734 if (kernel_output_features % batch_group_count != 0) {
1735 return InvalidArgument(
1736 "Expected output feature dimension size (value %d) to be a multiple of "
1737 "batch group count %d; got <conv>(%s, %s)\n"
1738 "Dimension numbers: {%s}.",
1739 kernel_output_features, batch_group_count, ShapeUtil::HumanString(lhs),
1740 ShapeUtil::HumanString(rhs), dnums.DebugString());
1741 }
1742
1743 if (input_features % feature_group_count != 0 ||
1744 input_features / feature_group_count != kernel_input_features) {
1745 return InvalidArgument(
1746 "Expected LHS feature dimension (value %d) to be a multiple of "
1747 "feature_group_count (value %d), and LHS feature dimension / "
1748 "feature_group_count = RHS feature dimension (value %d); "
1749 "got <conv>(%s, %s)\n"
1750 "Dimension numbers: {%s}.",
1751 input_features, feature_group_count, kernel_input_features,
1752 ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
1753 dnums.DebugString());
1754 }
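  // For intuition only (hypothetical values): with input_features = 12 and
  // feature_group_count = 3, the check above requires kernel_input_features
  // to be 12 / 3 = 4.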
1755
1756 if (kernel_output_features % feature_group_count > 0) {
1757 // A depthwise/grouped filter has the shape
1758 // [space0, .. spaceN, GROUP_SIZE, NUM_OUTPUT_FEATURES]. When
1759 // [space0, .. spaceN, GROUP_SIZE] is convolved with the input, a shape
1760 // [space0, .. spaceN, feature_group_count] is formed. Therefore, the output
1761 // feature count (which is equal to kernel output features) has to be a
1762 // multiple of feature_group_count.
1763 return InvalidArgument(
1764 "Expected output feature dimension (value %d) to be divisible by "
1765 "feature_group_count (value %d); "
1766 "got <conv>(%s, %s)\n"
1767 "Dimension numbers: {%s}.",
1768 kernel_output_features, feature_group_count,
1769 ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
1770 dnums.DebugString());
1771 }
1772
1773 if (input_batch % batch_group_count != 0) {
1774 return InvalidArgument(
1775 "Expected input batch dimension (value %d) to be divisible by "
1776 "batch_group_count (value %d); "
1777 "got <conv>(%s, %s)\n"
1778 "Dimension numbers: {%s}.",
1779 input_batch, batch_group_count, ShapeUtil::HumanString(lhs),
1780 ShapeUtil::HumanString(rhs), dnums.DebugString());
1781 }
1782
1783 std::vector<int64> window_dims(num_spatial_dims);
1784 for (int i = 0; i < num_spatial_dims; ++i) {
1785 window_dims[i] = window.dimensions(i).size();
1786 }
1787 if (kernel_spatial_dims != window_dims) {
1788 return InvalidArgument(
1789 "Window dimensions do not match RHS shape:\n\t"
1790 "RHS shape: %s\n\t"
1791 "Window: {%s}\n\t"
1792 "Dimension numbers: {%s}.",
1793 ShapeUtil::HumanString(rhs), window.ShortDebugString(),
1794 dnums.ShortDebugString());
1795 }
1796
1797 Shape base_shape =
1798 ShapeUtil::MakeShape(lhs.element_type(), input_spatial_dims);
1799 TF_ASSIGN_OR_RETURN(
1800 Shape window_output_shape,
1801 InferWindowOutputShape(base_shape, window, lhs.element_type(),
1802 /*allow_negative_padding=*/true));
1803
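  // For intuition only (hypothetical shapes, NHWC/HWIO dimension numbers):
  // lhs f32[8,32,32,16] convolved with rhs f32[3,3,16,64] using a 3x3 window,
  // stride 1, no padding, and no dilation yields f32[8,30,30,64]: batch
  // 8 / batch_group_count 1, spatial (32 - 3) + 1 = 30, and feature equal to
  // kernel_output_features 64.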
1804 std::vector<int64> dimensions(num_dims);
1805 dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count;
1806 dimensions[dnums.output_feature_dimension()] = kernel_output_features;
1807
1808 for (int i = 0; i < num_spatial_dims; ++i) {
1809 dimensions[dnums.output_spatial_dimensions(i)] =
1810 window_output_shape.dimensions(i);
1811 }
1812 std::vector<bool> is_dynamic(num_dims);
1813 for (int i = 0; i < num_dims; i++) {
1814 if (lhs.is_dynamic_dimension(i)) {
1815 if (i == dnums.input_batch_dimension()) {
1816 is_dynamic[dnums.output_batch_dimension()] = true;
1817 } else if (i == dnums.input_feature_dimension()) {
1818 // Input feature dimension is a contracting dimension, which does not
1819 // affect the output dimension size. So we need to do nothing.
1820 } else {
1821 return InvalidArgument(
1822 "Dynamic Spatial Convolution is not supported: lhs shape is %s ",
1823 lhs.ToString());
1824 }
1825 }
1826 if (rhs.is_dynamic_dimension(i)) {
1827 if (i == dnums.kernel_input_feature_dimension()) {
1828 // Kernel feature dimension does not affect the output dimension size.
1829 // So we need to do nothing.
1830 } else {
1831 return InvalidArgument(
1832 "Dynamic Spatial Convolution is not supported: rhs shape is %s ",
1833 rhs.ToString());
1834 }
1835 }
1836 }
1837 return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
1838 dimensions, is_dynamic);
1839 }
1840
1841 /* static */ StatusOr<Shape> ShapeInference::InferFftShape(
1842 const Shape& in, const FftType fft_type,
1843 const absl::Span<const int64> fft_length) {
1844 const int64 fft_rank = fft_length.size();
1845 if (fft_rank < 1 || fft_rank > 3) {
1846 return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank);
1847 }
1848 #define RET_CHECK_RANK(x) \
1849 if (x.dimensions_size() < fft_rank) { \
1850 return InvalidArgument( \
1851 "FFT of rank %d requires input of at least " \
1852 "same rank; got input of rank %d", \
1853 fft_rank, x.dimensions_size()); \
1854 }
1855 switch (fft_type) {
1856 case FFT:
1857 case IFFT:
1858 if (in.element_type() != C64) {
1859 return InvalidArgument("%s requires complex input type, found %s.",
1860 FftType_Name(fft_type),
1861 PrimitiveType_Name(in.element_type()));
1862 }
1863 RET_CHECK_RANK(in);
1864 return in;
1865 case RFFT: {
1866 if (in.element_type() != F32) {
1867 return InvalidArgument("RFFT requires F32 input type, found %s.",
1868 PrimitiveType_Name(in.element_type()));
1869 }
1870 RET_CHECK_RANK(in);
1871 for (int i = 0; i < fft_rank; i++) {
1872 if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1873 fft_length[i]) {
1874 return InvalidArgument(
1875               "RFFT requires innermost dimensions to match fft_length, but "
1876 "dimension %d is %d and should be %d.",
1877 in.dimensions_size() - fft_rank + i,
1878 in.dimensions(in.dimensions_size() - fft_rank + i),
1879 fft_length[i]);
1880 }
1881 }
1882 Shape result = ShapeUtil::ChangeElementType(in, C64);
1883 // Preserve the size of zero-sized dimensions.
1884 if (fft_length[fft_rank - 1] != 0) {
1885 result.set_dimensions(result.dimensions_size() - 1,
1886 fft_length[fft_rank - 1] / 2 + 1);
1887 }
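      // For intuition only (hypothetical shape): an F32[5,16] input with
      // fft_length={16} produces C64[5,9], since 16 / 2 + 1 = 9.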
1888 return result;
1889 }
1890 case IRFFT: {
1891 if (in.element_type() != C64) {
1892 return InvalidArgument("IRFFT requires C64 input type, found %s.",
1893 PrimitiveType_Name(in.element_type()));
1894 }
1895 RET_CHECK_RANK(in);
1896 Shape result = ShapeUtil::ComplexComponentShape(in);
1897 for (int i = 0; i < fft_rank - 1; i++) {
1898 if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1899 fft_length[i]) {
1900 return InvalidArgument(
1901             "IRFFT requires all but the innermost dimension to match "
1902 "fft_length, but dimension %d is %d and should be %d.",
1903 in.dimensions_size() - fft_rank + i,
1904 in.dimensions(in.dimensions_size() - fft_rank + i),
1905 fft_length[i]);
1906 }
1907 }
1908 // The size of zero-sized dimensions is preserved.
1909 if ((in.dimensions(in.dimensions_size() - 1) != 0 ||
1910 fft_length[fft_rank - 1] != 0) &&
1911 in.dimensions(in.dimensions_size() - 1) !=
1912 fft_length[fft_rank - 1] / 2 + 1) {
1913 return InvalidArgument(
1914           "IRFFT requires the innermost dimension to match fft_length/2+1, but "
1915 "dimension %d is %d and should be %d.",
1916 in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1),
1917 fft_length[fft_rank - 1] / 2 + 1);
1918 }
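      // For intuition only (hypothetical shape): a C64[5,9] input with
      // fft_length={16} produces F32[5,16]; the innermost dimension is
      // restored from 16 / 2 + 1 = 9 back to 16.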
1919 result.set_dimensions(result.dimensions_size() - 1,
1920 fft_length[fft_rank - 1]);
1921 return result;
1922 }
1923 default:
1924 LOG(FATAL) << "Unexpected fft_type: " << fft_type;
1925 }
1926 #undef RET_CHECK_RANK
1927 }
1928
1929 /* static */ StatusOr<Shape> ShapeInference::InferTriangularSolveShape(
1930 const Shape& a, const Shape& b, const TriangularSolveOptions& options) {
1931 if ((!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) ||
1932 a.element_type() != b.element_type()) {
1933 return InvalidArgument(
1934 "Expected element types in shape to be floating or complex and "
1935 "identical for TriangularSolve; got %s and %s.",
1936 PrimitiveType_Name(a.element_type()),
1937 PrimitiveType_Name(b.element_type()));
1938 }
1939 if (a.rank() < 2) {
1940 return InvalidArgument(
1941 "The 'a' argument to TriangularSolve must have rank >= 2, got shape %s",
1942 a.ToString());
1943 }
1944 if (b.rank() != a.rank()) {
1945 return InvalidArgument(
1946 "Arguments to triangular solve must have equal rank; got %s and %s.",
1947 b.ToString(), a.ToString());
1948 }
1949 if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
1950 return InvalidArgument(
1951 "The two minor dimensions of 'a' must have equal size, got %s.",
1952 a.ToString());
1953 }
1954 if (a.dimensions(a.rank() - 1) !=
1955 b.dimensions(b.rank() - (options.left_side() ? 2 : 1))) {
1956 return InvalidArgument(
1957 "The shared dimension of 'a' and 'b' does not match, got shapes %s and "
1958 "%s",
1959 a.ToString(), b.ToString());
1960 }
1961 absl::Span<const int64> a_batch_dims(a.dimensions());
1962 absl::Span<const int64> b_batch_dims(b.dimensions());
1963 a_batch_dims.remove_suffix(2);
1964 b_batch_dims.remove_suffix(2);
1965 if (a_batch_dims != b_batch_dims) {
1966 return InvalidArgument(
1967 "The leading batch dimensions of the arguments to triangular solve "
1968 "must be equal; got %s and %s.",
1969 b.ToString(), a.ToString());
1970 }
1971 if (!TriangularSolveOptions_Transpose_IsValid(options.transpose_a()) ||
1972 options.transpose_a() == TriangularSolveOptions::TRANSPOSE_INVALID) {
1973 return InvalidArgument(
1974 "Invalid transpose option value for triangular solve (%d).\n",
1975 options.transpose_a());
1976 }
1977 return b;
1978 }
1979
1980 /* static */ StatusOr<Shape> ShapeInference::InferCholeskyShape(
1981 const Shape& a) {
1982 if (!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) {
1983 return InvalidArgument(
1984 "Expected element type in shape to be floating or complex for "
1985 "Cholesky; got %s.",
1986 PrimitiveType_Name(a.element_type()));
1987 }
1988 if (a.rank() < 2) {
1989 return InvalidArgument(
1990 "The 'a' argument to Cholesky must have rank >= 2, got shape %s",
1991 a.ToString());
1992 }
1993 if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
1994 return InvalidArgument(
1995 "The two minor dimensions of 'a' must have equal size, got %s.",
1996 a.ToString());
1997 }
1998 return a;
1999 }
2000
2001 /* static */ StatusOr<Shape> ShapeInference::InferAllReduceShape(
2002 absl::Span<const Shape* const> operand_shapes) {
2003 for (const Shape* operand_shape : operand_shapes) {
2004 TF_RETURN_IF_ERROR(
2005 ExpectArray(*operand_shape, "operand of cross replica sum"));
2006 }
2007 if (operand_shapes.size() == 1) {
2008 return *operand_shapes[0];
2009 }
2010 std::vector<Shape> operand_shape_values;
2011 for (const Shape* operand_shape : operand_shapes) {
2012 operand_shape_values.push_back(*operand_shape);
2013 }
2014 return ShapeUtil::MakeTupleShape(operand_shape_values);
2015 }
2016
2017 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllShape(
2018 const Shape& shape, int64 split_dimension, int64 concat_dimension,
2019 int64 split_count) {
2020 TF_RET_CHECK(split_count > 0);
2021 if (split_dimension >= shape.rank() || split_dimension < 0) {
2022 return InvalidArgument(
2023 "AllToAll split_dimension %d is out-of-bounds in shape %s.",
2024 split_dimension, ShapeUtil::HumanString(shape));
2025 }
2026 if (concat_dimension >= shape.rank() || concat_dimension < 0) {
2027 return InvalidArgument(
2028 "AllToAll concat_dimension %d is out-of-bounds in shape %s.",
2029 concat_dimension, ShapeUtil::HumanString(shape));
2030 }
2031 if (shape.dimensions(split_dimension) % split_count != 0) {
2032 return InvalidArgument(
2033         "AllToAll split dimension size %d must be divisible by split_count "
2034 "%d.",
2035 shape.dimensions(split_dimension), split_count);
2036 }
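  // For intuition only (hypothetical shape): f32[4,16] with split_dimension=1,
  // concat_dimension=0, and split_count=4 becomes f32[16,4]: dimension 1
  // shrinks to 16 / 4 = 4 and dimension 0 grows to 4 * 4 = 16.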
2037 std::vector<int64> new_dimensions(shape.dimensions().begin(),
2038 shape.dimensions().end());
2039 new_dimensions[split_dimension] /= split_count;
2040 new_dimensions[concat_dimension] *= split_count;
2041 return ShapeUtil::MakeShape(shape.element_type(), new_dimensions);
2042 }
2043
2044 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
2045 absl::Span<const Shape* const> operand_shapes) {
2046 // An Alltoall HLO instruction receives N operands (with the same shape) and
2047 // returns a tuple that contains N array shapes.
2048 TF_RET_CHECK(!operand_shapes.empty());
2049 for (int i = 0; i < operand_shapes.size(); i++) {
2050 if (!ShapeUtil::Equal(*operand_shapes[0], *operand_shapes[i])) {
2051 return InvalidArgument(
2052 "HLO all-to-all has operands with different shapes: the 0th "
2053           "operand has shape %s, but the %dth operand has shape %s.",
2054 ShapeUtil::HumanString(*operand_shapes[0]), i,
2055 ShapeUtil::HumanString(*operand_shapes[i]));
2056 }
2057 }
2058
2059 return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes);
2060 }
2061
2062 /* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteShape(
2063 const Shape& shape) {
2064 TF_RET_CHECK(shape.IsArray());
2065 return shape;
2066 }
2067
2068 /* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
2069 absl::Span<const Shape* const> arg_shapes,
2070 absl::Span<const int64> dimensions_to_reduce,
2071 const ProgramShape& to_apply) {
2072 if (arg_shapes.empty()) {
2073 return InvalidArgument("Reduce must have at least 2 arguments, has 0");
2074 }
2075 if (arg_shapes.size() % 2) {
2076 return InvalidArgument(
2077 "Reduce must have an even number of arguments, has %lu",
2078 arg_shapes.size());
2079 }
2080 int64 num_reduced_args = arg_shapes.size() / 2;
2081
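  // For intuition only (hypothetical shapes): arg_shapes {f32[10,20],
  // s32[10,20], f32[], s32[]} describes two reduced operands followed by two
  // init values; reducing dimension 0 with a reducer that returns a (f32, s32)
  // tuple infers the tuple shape (f32[20], s32[20]).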
2082 auto reduced_args = arg_shapes.subspan(0, num_reduced_args);
2083 // Check that all of the reduced tensors have the same dimensions. The element
2084 // types may be different.
2085 for (int64 i = 1; i < num_reduced_args; ++i) {
2086 if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) {
2087 return InvalidArgument(
2088           "All reduced tensors must have the same dimensions. Tensor 0 has "
2089 "shape %s, Tensor %d has shape %s",
2090 ShapeUtil::HumanString(*reduced_args[0]), i,
2091 ShapeUtil::HumanString(*reduced_args[i]));
2092 }
2093 }
2094
2095 // Check that the dimensions to reduce are in-bounds for the given shape.
2096 // We've already verified all reduced tensors have the same dimensions, so it
2097 // doesn't matter which one we choose.
2098 const Shape& arg = *reduced_args[0];
2099 for (int64 dimension : dimensions_to_reduce) {
2100 if (dimension >= arg.rank() || dimension < 0) {
2101 return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.",
2102 dimension, ShapeUtil::HumanString(arg));
2103 }
2104 }
2105
2106 auto init_values = arg_shapes.subspan(num_reduced_args, arg_shapes.size());
2107 std::vector<PrimitiveType> element_types;
2108 for (const Shape* arg : reduced_args) {
2109 element_types.push_back(arg->element_type());
2110 }
2111 TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types,
2112 num_reduced_args));
2113
2114 std::set<int64> dimensions_to_reduce_set(dimensions_to_reduce.begin(),
2115 dimensions_to_reduce.end());
2116 std::vector<int64> new_dimensions;
2117 std::vector<bool> new_is_dynamic;
2118 for (int i = 0; i < arg.rank(); ++i) {
2119 if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) {
2120 new_dimensions.push_back(arg.dimensions(i));
2121 new_is_dynamic.push_back(arg.is_dynamic_dimension(i));
2122 }
2123 }
2124
2125 if (ShapeUtil::IsScalar(to_apply.result())) {
2126 return ShapeUtil::MakeShape(to_apply.result().element_type(),
2127 new_dimensions, new_is_dynamic);
2128 } else {
2129 std::vector<Shape> result_subshapes;
2130 for (const Shape& subshape : to_apply.result().tuple_shapes()) {
2131 result_subshapes.push_back(ShapeUtil::MakeShape(
2132 subshape.element_type(), new_dimensions, new_is_dynamic));
2133 }
2134 return ShapeUtil::MakeTupleShape(result_subshapes);
2135 }
2136 }
2137
2138 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2139 const Shape& operand_shape, const Shape& init_value_shape,
2140 const Window& window, const ProgramShape& to_apply_shape) {
2141 TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
2142 {operand_shape.element_type()},
2143 /*inputs=*/1));
2144 return InferReduceWindowShape(operand_shape, init_value_shape, window);
2145 }
2146
2147 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2148 const Shape& operand_shape, const Shape& init_value_shape,
2149 const Window& window) {
2150 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window"));
2151 return InferWindowOutputShape(operand_shape, window,
2152 init_value_shape.element_type(),
2153 /*allow_negative_padding=*/false);
2154 }
2155
2156 /* static */ StatusOr<Shape> ShapeInference::InferSelectAndScatterShape(
2157 const Shape& operand_shape, const ProgramShape& select_shape,
2158 const Window& window, const Shape& source_shape,
2159 const Shape& init_value_shape, const ProgramShape& scatter_shape) {
2160 TF_RETURN_IF_ERROR(
2161 ExpectArray(operand_shape, "operand of select-and-scatter"));
2162
2163 // Check if the select function has a proper shape of (T,T) -> PRED.
2164 if (select_shape.parameters_size() != 2) {
2165 return InvalidArgument(
2166 "Select function must take 2 parameters, but "
2167 "takes %d parameter(s).",
2168 select_shape.parameters_size());
2169 }
2170 const Shape& select_result_shape = select_shape.result();
2171 if (!ShapeUtil::Compatible(select_result_shape,
2172 ShapeUtil::MakeShape(PRED, {}))) {
2173 return InvalidArgument("Select function must have rank-0 PRED result.");
2174 }
2175 const Shape& operand_element_shape =
2176 ShapeUtil::MakeShape(operand_shape.element_type(), {});
2177 if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2178 select_shape.parameters(0))) {
2179 return InvalidArgument(
2180 "Select function's first parameter shape currently must "
2181 "match the operand element shape, but got %s vs %s.",
2182 ShapeUtil::HumanString(select_shape.parameters(0)),
2183 ShapeUtil::HumanString(operand_element_shape));
2184 }
2185 if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2186 select_shape.parameters(1))) {
2187 return InvalidArgument(
2188 "Select function's second parameter shape currently must "
2189 "match the operand element shape, but got %s vs %s.",
2190 ShapeUtil::HumanString(select_shape.parameters(1)),
2191 ShapeUtil::HumanString(operand_element_shape));
2192 }
2193
2194 // Check if the scatter function has a proper shape as a reduction.
2195 TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape},
2196 {source_shape.element_type()},
2197 /*inputs=*/1));
2198
2199 // Check if the result shape of window operation matches the source shape.
2200 TF_ASSIGN_OR_RETURN(const Shape& window_result_shape,
2201 InferWindowOutputShape(operand_shape, window,
2202 operand_shape.element_type(),
2203 /*allow_negative_padding=*/false));
2204 if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape,
2205 window_result_shape)) {
2206 return InvalidArgument(
2207 "Source shape does not match the shape of window-reduced operand: "
2208 "source(%s), window-reduced operand(%s).",
2209 ShapeUtil::HumanString(source_shape),
2210 ShapeUtil::HumanString(window_result_shape));
2211 }
2212
2213 return operand_shape;
2214 }
2215
2216 /* static */ StatusOr<Shape> ShapeInference::InferGetDimensionSizeShape(
2217 const Shape& shape, int64 dimension) {
2218 if (dimension < 0 || dimension >= shape.rank()) {
2219 return InvalidArgument("GetDimensionSize dimension out of bounds: %d.",
2220 dimension);
2221 }
2222
2223 // TODO(b/119580730): Remove this restriction when very large dimension size
2224 // is needed.
2225 if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
2226 return InvalidArgument(
2227 "GetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2228 "INT_MAX limit.",
2229 ShapeUtil::HumanString(shape), dimension);
2230 }
2231
2232 return ShapeUtil::MakeShape(S32, {});
2233 }
2234
2235 /* static */ StatusOr<Shape> ShapeInference::InferSetDimensionSizeShape(
2236 const Shape& shape, int64 dimension) {
2237 if (dimension < 0 || dimension >= shape.rank()) {
2238 return InvalidArgument("SetDimensionSize dimension out of bounds: %d.",
2239 dimension);
2240 }
2241
2242 // TODO(b/119580730): Remove this restriction when very large dimension size
2243 // is needed.
2244 if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
2245 return InvalidArgument(
2246 "SetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2247 "INT_MAX limit.",
2248 ShapeUtil::HumanString(shape), dimension);
2249 }
2250
2251 Shape result = shape;
2252 result.set_dynamic_dimension(dimension, true);
2253 return result;
2254 }
2255
2256 /* static */ StatusOr<Window> ShapeInference::InferWindowFromDimensions(
2257 absl::Span<const int64> window_dimensions,
2258 absl::Span<const int64> window_strides,
2259 absl::Span<const std::pair<int64, int64>> padding,
2260 absl::Span<const int64> lhs_dilation,
2261 absl::Span<const int64> rhs_dilation) {
2262 const auto verify_size = [&](const size_t x, const char* x_name) {
2263 if (x == 0 || x == window_dimensions.size()) {
2264 return Status::OK();
2265 } else {
2266 return InvalidArgument(
2267 "%s", absl::StrCat(
2268 "Window has different number of window dimensions than of ",
2269 x_name,
2270 "\nNumber of window dimensions: ", window_dimensions.size(),
2271 "\nNumber of ", x_name, ": ", x, "\n"));
2272 }
2273 };
2274 TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides"));
2275 TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries"));
2276 TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors"));
2277 TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors"));
2278
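  // For intuition only (hypothetical arguments): window_dimensions={3, 3} with
  // every other span left empty yields a 3x3 window with stride 1, no padding,
  // no dilation, and no reversal, per the defaults applied below.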
2279 Window window;
2280 for (size_t i = 0; i < window_dimensions.size(); i++) {
2281 auto dim = window.add_dimensions();
2282 dim->set_size(window_dimensions[i]);
2283 if (!window_strides.empty()) {
2284 dim->set_stride(window_strides[i]);
2285 } else {
2286 dim->set_stride(1);
2287 }
2288 if (!padding.empty()) {
2289 dim->set_padding_low(padding[i].first);
2290 dim->set_padding_high(padding[i].second);
2291 } else {
2292 dim->set_padding_low(0);
2293 dim->set_padding_high(0);
2294 }
2295 if (!lhs_dilation.empty()) {
2296 dim->set_base_dilation(lhs_dilation[i]);
2297 } else {
2298 dim->set_base_dilation(1);
2299 }
2300 if (!rhs_dilation.empty()) {
2301 dim->set_window_dilation(rhs_dilation[i]);
2302 } else {
2303 dim->set_window_dilation(1);
2304 }
2305 dim->set_window_reversal(false);
2306 }
2307 return window;
2308 }
2309
2310 /* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
2311 const Shape& arg, absl::Span<const int64> starts,
2312 absl::Span<const int64> limits, absl::Span<const int64> strides) {
2313 auto error = [&](const string& message) {
2314 return InvalidArgument(
2315 "%s in slice operation; argument shape: %s; starts: {%s}; limits: "
2316 "{%s}; strides: {%s}.",
2317 message, ShapeUtil::HumanString(arg), StrJoin(starts, ","),
2318 StrJoin(limits, ","), StrJoin(strides, ","));
2319 };
2320 TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice"));
2321 VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}",
2322 ShapeUtil::HumanString(arg), StrJoin(starts, ", "),
2323 StrJoin(limits, ", "));
2324
2325 if (starts.size() != limits.size()) {
2326 return error(StrFormat("slice start and limit sizes differ: %u vs %u",
2327 starts.size(), limits.size()));
2328 }
2329
2330 if (starts.size() != strides.size()) {
2331 return error(StrFormat("slice start and strides sizes differ: %u vs %u",
2332 starts.size(), strides.size()));
2333 }
2334
2335 if (starts.size() != arg.rank()) {
2336 return InvalidArgument(
2337 "Slice index count does not match argument rank: %u vs %d.",
2338 starts.size(), arg.rank());
2339 }
2340
2341 std::vector<int64> sizes;
2342 for (int64 dimension = 0; dimension < starts.size(); ++dimension) {
2343 int64 start_index = starts[dimension];
2344 int64 limit_index = limits[dimension];
2345 int64 stride = strides[dimension];
2346 if (start_index < 0) {
2347 return InvalidArgument("Negative start index to slice: %d.", start_index);
2348 }
2349 if (limit_index > arg.dimensions(dimension)) {
2350 return error(
2351 StrFormat("limit index (%d) must be less than or equal to dimension "
2352 "size (%d)",
2353 limit_index, arg.dimensions(dimension)));
2354 }
2355 VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index);
2356 VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index);
2357 if (start_index > limit_index) {
2358 return error(
2359 StrFormat("limit index (%d) must be greater or equal to "
2360 "start index (%d) in slice with positive stride",
2361 limit_index, start_index));
2362 }
2363 if (stride <= 0) {
2364 return InvalidArgument("Stride (%d) must be positive.", stride);
2365 }
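    // The size formula below is a ceiling division. For intuition only
    // (hypothetical values): start 2, limit 10, stride 3 gives
    // (10 - 2 + 3 - 1) / 3 = 3 elements (indices 2, 5, and 8).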
2366 sizes.push_back((limit_index - start_index + stride - 1) / stride);
2367 }
2368
2369 std::vector<bool> is_dynamic(arg.rank());
2370 for (int64 i = 0; i < arg.dimensions_size(); ++i) {
2371 // Slicing 1 out of a dynamic dimension eliminates the dynamic dimension.
2372 if (sizes[i] == 1) {
2373 continue;
2374 }
2375 is_dynamic[i] = arg.is_dynamic_dimension(i);
2376 }
2377
2378 return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic);
2379 }
2380
2381 /* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
2382 const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
2383 absl::Span<const int64> slice_sizes, bool allow_scalar_indices) {
2384 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
2385 auto number_of_indices = start_index_shapes.size();
2386 // TODO(b/118437727): Remove this path.
2387 if (!allow_scalar_indices ||
2388 (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
2389 if (number_of_indices != 1) {
2390 return InvalidArgument(
2391 "Dynamic slice should have exactly 1 index operand, has %d.",
2392 number_of_indices);
2393 }
2394
2395 const Shape& start_indices_shape = start_index_shapes[0];
2396 VLOG(2) << StrFormat(
2397 "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
2398 ShapeUtil::HumanString(operand_shape),
2399 ShapeUtil::HumanString(start_indices_shape),
2400 StrJoin(slice_sizes, ", "));
2401
2402 TF_RETURN_IF_ERROR(
2403 ExpectArray(start_indices_shape, "start indices of dynamic slice"));
2404
2405 if (start_indices_shape.rank() != 1) {
2406 return InvalidArgument(
2407           "Dynamic slice start indices of rank %d must be rank 1.",
2408 start_indices_shape.rank());
2409 }
2410
2411 if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2412 return InvalidArgument(
2413 "Dynamic slice start indices must be of integral type.");
2414 }
2415
2416 const int64 start_num_dims = start_indices_shape.dimensions(0);
2417 if (operand_shape.rank() != start_num_dims) {
2418 return InvalidArgument(
2419 "Dynamic slice start number of dimensions %d (%s) must match rank "
2420 "%d of slice input (%s).",
2421 start_num_dims, ShapeUtil::HumanString(start_indices_shape),
2422 operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
2423 }
2424 } else {
2425 VLOG(2) << StrFormat("slicing shape %s a with slice_sizes={%s}",
2426 ShapeUtil::HumanString(operand_shape),
2427 StrJoin(slice_sizes, ", "));
2428
2429 if (operand_shape.rank() != number_of_indices) {
2430 return InvalidArgument(
2431 "Dynamic slice start number of dimensions %d must match rank "
2432 "%d of slice input (%s).",
2433 number_of_indices, operand_shape.rank(),
2434 ShapeUtil::HumanString(operand_shape));
2435 }
2436
2437 if (number_of_indices > 0) {
2438 const Shape& first_index_shape = start_index_shapes[0];
2439 if (!ShapeUtil::IsScalar(first_index_shape)) {
2440 return InvalidArgument("Dynamic slice indices must be scalar, not %s.",
2441 ShapeUtil::HumanString(first_index_shape));
2442 }
2443 if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
2444 return InvalidArgument(
2445 "Dynamic slice start indices must be of integral type.");
2446 }
2447 for (const Shape& index_shape : start_index_shapes) {
2448 if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
2449 return InvalidArgument(
2450 "Dynamic slice start indices must all have the same shape, got "
2451 "mismatching indices with shapes %s and %s.",
2452 ShapeUtil::HumanString(first_index_shape),
2453 ShapeUtil::HumanString(index_shape));
2454 }
2455 }
2456 }
2457 }
2458
2459 if (slice_sizes.size() != operand_shape.rank()) {
2460 return InvalidArgument(
2461 "Dynamic slice index count does not match argument rank: %u vs %d.",
2462 slice_sizes.size(), operand_shape.rank());
2463 }
2464
2465 for (int64 dim = 0; dim < slice_sizes.size(); ++dim) {
2466 const int64 input_dim_size = operand_shape.dimensions(dim);
2467 const int64 slice_dim_size = slice_sizes[dim];
2468 if (slice_dim_size < 0) {
2469 return InvalidArgument("Negative size index to dynamic slice: %d.",
2470 slice_dim_size);
2471 }
2472 if (slice_dim_size > input_dim_size) {
2473 return InvalidArgument(
2474 "Slice dim size %d greater than dynamic slice dimension: %d.",
2475 slice_dim_size, input_dim_size);
2476 }
2477 VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size);
2478 }
2479
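  // For intuition only (hypothetical shapes, scalar indices): operand
  // f32[10,10] with start indices (s32[], s32[]) and slice_sizes={3, 4}
  // infers f32[3,4].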
2480 return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
2481 }
2482
2483 /* static */ StatusOr<Shape> ShapeInference::InferDynamicUpdateSliceShape(
2484 const Shape& operand_shape, const Shape& update_shape,
2485 absl::Span<const Shape> start_index_shapes, bool allow_scalar_indices) {
2486 TF_RETURN_IF_ERROR(
2487 ExpectArray(operand_shape, "operand of dynamic update slice"));
2488 TF_RETURN_IF_ERROR(
2489 ExpectArray(update_shape, "update of dynamic update slice"));
2490
2491 auto number_of_indices = start_index_shapes.size();
2492 // TODO(b/118437727): Remove this path.
2493 if (!allow_scalar_indices ||
2494 (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
2495 if (number_of_indices != 1) {
2496 return InvalidArgument(
2497 "Dynamic update slice should have exactly 1 index operand, has %d.",
2498 number_of_indices);
2499 }
2500 const Shape& start_indices_shape = start_index_shapes[0];
2501 TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape,
2502 "start indices of dynamic update slice"));
2503
2504 VLOG(2) << StrFormat(
2505 "updating slice of shape %s at dynamic start_indices %s with update "
2506 "shape %s",
2507 ShapeUtil::HumanString(operand_shape),
2508 ShapeUtil::HumanString(start_indices_shape),
2509 ShapeUtil::HumanString(update_shape));
2510
2511 if (start_indices_shape.rank() != 1) {
2512 return InvalidArgument(
2513           "Dynamic update slice start indices of rank %d must be rank 1.",
2514 start_indices_shape.rank());
2515 }
2516
2517 if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2518 return InvalidArgument(
2519 "Dynamic update slice start indices must be of integral type.");
2520 }
2521
2522 const int64 start_num_dims = start_indices_shape.dimensions(0);
2523 if (operand_shape.rank() != start_num_dims) {
2524 return InvalidArgument(
2525 "Dynamic update slice start number of dimensions %d (%s) must match "
2526 "rank %d of slice input (%s).",
2527 start_num_dims, ShapeUtil::HumanString(start_indices_shape),
2528 operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
2529 }
2530 } else {
2531 VLOG(2) << StrFormat("updating slice of shape %s with update shape %s",
2532 ShapeUtil::HumanString(operand_shape),
2533 ShapeUtil::HumanString(update_shape));
2534
2535 if (operand_shape.rank() != number_of_indices) {
2536 return InvalidArgument(
2537 "Dynamic update slice start number of dimensions %d must match "
2538 "rank %d of slice input (%s).",
2539 number_of_indices, operand_shape.rank(),
2540 ShapeUtil::HumanString(operand_shape));
2541 }
2542
2543 if (number_of_indices > 0) {
2544 const Shape& first_index_shape = start_index_shapes[0];
2545 if (!ShapeUtil::IsScalar(first_index_shape)) {
2546 return InvalidArgument(
2547 "Dynamic update slice indices must be scalar, not %s.",
2548 ShapeUtil::HumanString(first_index_shape));
2549 }
2550 if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
2551 return InvalidArgument(
2552 "Dynamic update slice start indices must be of integral type.");
2553 }
2554 for (const Shape& index_shape : start_index_shapes) {
2555 if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
2556 return InvalidArgument(
2557 "Dynamic update slice start indices must all have the same "
2558 "shape, got mismatching indices with shapes %s and %s.",
2559 ShapeUtil::HumanString(first_index_shape),
2560 ShapeUtil::HumanString(index_shape));
2561 }
2562 }
2563 }
2564 }
2565
2566 if (update_shape.rank() != operand_shape.rank()) {
2567 return InvalidArgument(
2568 "Dynamic update slice update rank does not match argument rank: "
2569 "%d vs %d.",
2570 update_shape.rank(), operand_shape.rank());
2571 }
2572
2573 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
2574 update_shape)) {
2575 return InvalidArgument(
2576 "Dynamic update slice update element type does not match argument. "
2577 "operand.element_type: %s vs update.element_type: %s.",
2578 PrimitiveType_Name(operand_shape.element_type()),
2579 PrimitiveType_Name(update_shape.element_type()));
2580 }
2581
2582 for (int64 dim = 0; dim < operand_shape.rank(); ++dim) {
2583 const int64 input_dim_size = operand_shape.dimensions(dim);
2584 const int64 update_dim_size = update_shape.dimensions(dim);
2585 if (update_dim_size < 0) {
2586 return InvalidArgument(
2587 "Size index %d to dynamic update slice must be >= 0.",
2588 update_dim_size);
2589 }
2590 if (update_dim_size > input_dim_size) {
2591 return InvalidArgument(
2592 "Update dim size %d greater than dynamic slice dimension: %d.",
2593 update_dim_size, input_dim_size);
2594 }
2595 VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size);
2596 }
2597
2598 return operand_shape;
2599 }
2600
2601 /*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
2602 const Shape& operand_shape, absl::Span<const int64> dimensions) {
2603 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
2604 if (!AllUnique(dimensions)) {
2605     return InvalidArgument("A dimension number is duplicated in reverse.");
2606 }
2607 for (int64 dimension : dimensions) {
2608 if (dimension >= operand_shape.rank() || dimension < 0) {
2609 return InvalidArgument(
2610 "One of the reverse dimensions (%d) is out-of-bounds in shape %s.",
2611 dimension, ShapeUtil::HumanString(operand_shape));
2612 }
2613 }
2614 return operand_shape;
2615 }
2616
2617 /* static */ StatusOr<Shape> ShapeInference::InferGetTupleElementShape(
2618 const Shape& arg, int64 index) {
2619 if (!arg.IsTuple()) {
2620 return InvalidArgument(
2621 "Cannot infer shape: attempting to index into non-tuple: %s.",
2622 ShapeUtil::HumanString(arg));
2623 }
2624
2625 if (index < 0 || index >= arg.tuple_shapes_size()) {
2626 return InvalidArgument(
2627 "Cannot infer shape: attempt to index out of tuple bounds: %d "
2628 ">= %d in shape %s.",
2629 index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg));
2630 }
2631
2632 return arg.tuple_shapes(index);
2633 }
2634
2635 /* static */ StatusOr<Shape> ShapeInference::InferWhileShape(
2636 const ProgramShape& condition, const ProgramShape& body,
2637 const Shape& init) {
2638 // Check the number of parameters for given computations.
2639 if (condition.parameters_size() != 1) {
2640     return InvalidArgument("Condition must take 1 argument; got %d.",
2641 condition.parameters_size());
2642 }
2643 if (body.parameters_size() != 1) {
2644     return InvalidArgument("Body must take 1 argument; got %d.",
2645 body.parameters_size());
2646 }
2647
2648 auto shape_string = [&]() {
2649 return StrFormat(
2650 "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition),
2651 ShapeUtil::HumanString(body), ShapeUtil::HumanString(init));
2652 };
2653
2654 // Check the shapes of computation parameters and return types.
2655 if (!ShapeUtil::Compatible(condition.result(),
2656 ShapeUtil::MakeShape(PRED, {}))) {
2657 return InvalidArgument("Condition must return a boolean; got %s.",
2658 shape_string());
2659 }
2660 if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) ||
2661 !ShapeUtil::Compatible(body.result(), body.parameters(0)) ||
2662 !ShapeUtil::Compatible(body.result(), init)) {
2663 return InvalidArgument(
2664 "The parameter of condition and body, the result of the body, and init "
2665 "must all have the same shape; got %s.",
2666 shape_string());
2667 }
2668
2669 return init;
2670 }
2671
2672 /* static */ StatusOr<Shape> ShapeInference::InferConditionalShape(
2673 const Shape& branch_index,
2674 absl::Span<const ProgramShape> branch_computations,
2675 absl::Span<const Shape> branch_operands) {
2676 if (!ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(PRED, {})) &&
2677 !ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(S32, {}))) {
2678 return InvalidArgument("branch_index must be bool or int32; got %s.",
2679 ShapeUtil::HumanString(branch_index));
2680 }
2681 if (branch_index.element_type() == PRED) {
2682 TF_RET_CHECK(2 == branch_computations.size());
2683 } else {
2684 TF_RET_CHECK(!branch_computations.empty());
2685 }
2686 TF_RET_CHECK(branch_computations.size() == branch_operands.size());
2687
2688 for (int j = 0; j < branch_computations.size(); ++j) {
2689 if (branch_computations[j].parameters_size() != 1) {
2690 return InvalidArgument(
2691 "branch computation %d must take 1 argument; got %d.", j,
2692 branch_computations[j].parameters_size());
2693 }
2694 if (!ShapeUtil::Compatible(branch_computations[j].parameters(0),
2695 branch_operands[j])) {
2696 auto shape_string = [&]() {
2697 return StrFormat("operand: %s; computation: %s",
2698 ShapeUtil::HumanString(branch_operands[j]),
2699 ShapeUtil::HumanString(branch_computations[j]));
2700 };
2701 return InvalidArgument(
2702 "branch operand %d must match the shape of the only parameter of "
2703 "branch computation %d: got %s.",
2704 j, j, shape_string());
2705 }
2706
2707 if (!ShapeUtil::Compatible(branch_computations[0].result(),
2708 branch_computations[j].result())) {
2709 auto shape_string = [&]() {
2710 return StrFormat(
2711 "branch 0 computation result: %s; branch %d computation result: %s",
2712 ShapeUtil::HumanString(branch_computations[0].result()), j,
2713 ShapeUtil::HumanString(branch_computations[j].result()));
2714 };
2715 return InvalidArgument(
2716 "the result of branch 0 computation and branch %d computation must "
2717 "have the same shape: got %s.",
2718 j, shape_string());
2719 }
2720 }
2721 return branch_computations[0].result();
2722 }
2723
2724 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2725 const Shape& operand, absl::Span<const int64> broadcast_sizes) {
2726 TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
2727 for (int64 size : broadcast_sizes) {
2728 if (size < 0) {
2729 return InvalidArgument("Broadcast with negative dimension size %d.",
2730 size);
2731 }
2732 }
2733
2734 std::vector<int64> dimensions(operand.dimensions_size() +
2735 broadcast_sizes.size());
2736 std::copy(broadcast_sizes.begin(), broadcast_sizes.end(), dimensions.begin());
2737 std::copy(operand.dimensions().begin(), operand.dimensions().end(),
2738 dimensions.begin() + broadcast_sizes.size());
2739
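  // For intuition only (hypothetical shapes): operand f32[3,4] with
  // broadcast_sizes={2, 5} infers f32[2,5,3,4]; the broadcast sizes are
  // prepended and the operand dimensions (with their dynamism) follow.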
2740 Shape result = ShapeUtil::MakeShape(operand.element_type(), dimensions);
2741 for (int64 i = 0; i < operand.dimensions_size(); ++i) {
2742 result.set_dynamic_dimension(broadcast_sizes.size() + i,
2743 operand.is_dynamic_dimension(i));
2744 }
2745 return result;
2746 }
2747
2748 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2749 const Shape& operand_shape, const Shape& output_shape,
2750 absl::Span<const int64> broadcast_dimensions) {
2751 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast"));
2752 TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast"));
2753 const int64 operand_rank = operand_shape.rank();
2754 const int64 output_rank = output_shape.rank();
2755 if (operand_rank > output_rank) {
2756 return InvalidArgument(
2757 "InDim style broadcast must be to an equal or higher ranked shape; "
2758 "operand rank: %lld; output rank: %lld",
2759 operand_rank, output_rank);
2760 }
2761 if (operand_rank != broadcast_dimensions.size()) {
2762 return InvalidArgument(
2763 "Size of broadcast_dimensions has to match operand's rank; operand "
2764 "rank: %lld, size of broadcast_dimensions %u.",
2765 operand_rank, broadcast_dimensions.size());
2766 }
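// A worked example (shapes are illustrative only): an f32[3] operand broadcast
// into an f32[2,3] output with broadcast_dimensions {1} maps the operand's
// only dimension onto output dimension 1, while output dimension 0 is the
// newly broadcast dimension.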
2767 for (int64 i = 0; i < operand_rank; i++) {
2768 if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] >= output_rank) {
2769 return InvalidArgument("Broadcast dimension %lld is out of bound",
2770 broadcast_dimensions[i]);
2771 }
2772 if (operand_shape.dimensions(i) !=
2773 output_shape.dimensions(broadcast_dimensions[i]) &&
2774 operand_shape.dimensions(i) != 1) {
2775 return InvalidArgument(
2776 "Input dimension should be either 1 or equal to the output dimension "
2777 "it is broadcasting into; the %lldth operand dimension is %lld, the "
2778 "%lldth output dimension is %lld.",
2779 i, operand_shape.dimensions(i), broadcast_dimensions[i],
2780 output_shape.dimensions(broadcast_dimensions[i]));
2781 }
2782 if (operand_shape.is_dynamic_dimension(i) !=
2783 output_shape.is_dynamic_dimension(broadcast_dimensions[i])) {
2784 return InvalidArgument(
2785 "Broadcast input and output dynamism mismatch: %s and %s",
2786 operand_shape.ToString(), output_shape.ToString());
2787 }
2788 // Make sure the broadcast dimensions are listed in a strictly increasing
2789 // order.
2790 if (i > 0 && broadcast_dimensions[i - 1] >= broadcast_dimensions[i]) {
2791 return InvalidArgument(
2792 "Broadcast dimensions order is wrong: %d comes after %d.",
2793 broadcast_dimensions[i], broadcast_dimensions.at(i - 1));
2794 }
2795 }
2796
2797 return output_shape;
2798 }
2799
2800 /* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
2801 const Shape& operand, absl::Span<const int64> dimensions,
2802 absl::Span<const int64> new_sizes, int64 inferred_dimension) {
2803 TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));
2804
2805 Shape inferred_shape =
2806 ShapeUtil::MakeShape(operand.element_type(), new_sizes);
2807 VLOG(3) << "Reshape inferred shape: "
2808 << ShapeUtil::HumanString(inferred_shape);
2809
2810 if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
2811 return InvalidArgument(
2812 "Reshape operation has mismatched element counts: from=%d (%s) "
2813 "to=%d (%s).",
2814 ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
2815 ShapeUtil::ElementsIn(inferred_shape),
2816 ShapeUtil::HumanString(inferred_shape));
2817 }
2818
2819 std::vector<int64> indices(operand.rank());
2820 std::iota(indices.begin(), indices.end(), 0);
2821 if (dimensions.size() != operand.rank() ||
2822 !std::is_permutation(dimensions.begin(), dimensions.end(),
2823 indices.begin())) {
2824 return InvalidArgument(
2825 "Reshape dimensions [%s] are not a permutation of the operand "
2826 "dimensions (operand shape is %s).",
2827 StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
2828 }
2829
2830 // Propagate dynamic dimension.
2831 auto common_factors = CommonFactors(operand.dimensions(), new_sizes);
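// Roughly speaking, CommonFactors returns the boundaries of maximal groups of
// input and output dimensions whose element-count products agree. For example
// (illustrative shapes only), reshaping [2,3,4] to [6,4] yields the boundary
// pairs (0,0), (2,1), (3,2): input dims [0,2) correspond to output dims [0,1),
// and input dims [2,3) correspond to output dims [1,2).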
2832 for (int64 input_dim = 0; input_dim < operand.rank(); ++input_dim) {
2833 if (!operand.is_dynamic_dimension(input_dim)) {
2834 continue;
2835 }
2836
2837 string reshape_debug_str = absl::StrFormat(
2838 "output: %s, input: %s, input_dim: "
2839 "%lld",
2840 ShapeUtil::HumanString(inferred_shape), ShapeUtil::HumanString(operand),
2841 input_dim);
2842
2843 int64 input_dim_start = -1;
2844 int64 input_dim_end = -1;
2845 int64 output_dim_start = -1;
2846 int64 output_dim_end = -1;
2847 // Find common_factors that the input_dim belongs to.
2848 for (int64 i = 0; i < common_factors.size() - 1; ++i) {
2849 auto start = common_factors[i];
2850 auto end = common_factors[i + 1];
2851 if (input_dim >= start.first && input_dim < end.first) {
2852 input_dim_start = start.first;
2853 input_dim_end = end.first;
2854 output_dim_start = start.second;
2855 output_dim_end = end.second;
2856 break;
2857 }
2858 }
2859 if ((input_dim_end - input_dim_start) > 1 &&
2860 (output_dim_end - output_dim_start) > 1) {
2861 // We don't support the case where a dynamic dimension is both combined
2862 // with and split into other dimensions:
2863 //
2864 // [x, yz]
2865 // | Reshape
2866 // [xy, z]
2867 return Unimplemented(
2868 "Dynamic input dimension to reshape that is both splitted and "
2869 "combined is not supported: %s",
2870 reshape_debug_str);
2871 }
2872
2873 for (auto common_factor : common_factors) {
2874 //
2875 // For reshapes like:
2876 // [<=5]
2877 // | Reshape
2878 // [1, 5]
2879 //
2880 // The input dynamic dimension can map to either of the output dimensions.
2881 // However, the return value of common factors only returns
2882 // input: 5
2883 // output: 5
2884 //
2885 // We need to expand the common factor to include degenerate output
2886 // dimensions:
2887 // input: 5
2888 // output: 1, 5
2889 //
2890 // so that the logic below can consider both dimensions as
2891 // candidates.
2892 if (common_factor.first == input_dim_start) {
2893 output_dim_start = std::min(output_dim_start, common_factor.second);
2894 }
2895 if (common_factor.first == input_dim_end) {
2896 output_dim_end = std::max(output_dim_end, common_factor.second);
2897 }
2898 }
2899
2900 // Calculate output dynamic reshape dimension.
2901 int64 output_dynamic_dimension = -1;
2902
2903 if (operand.dimensions(input_dim) == 1 && !new_sizes.empty()) {
2904 // If the dynamic dimension has size 1, it can only be the most-major
2905 // or most-minor dimension.
2906 if (input_dim == 0) {
2907 output_dynamic_dimension = 0;
2908 }
2909 if (input_dim == operand.rank() - 1) {
2910 output_dynamic_dimension = new_sizes.size() - 1;
2911 }
2912
2913 if (output_dynamic_dimension == -1) {
2914 return Unimplemented(
2915 "Dynamic degenerated dimension that's not most-minor nor "
2916 "most-major is not supported: %s",
2917 reshape_debug_str);
2918 }
2919 }
2920
2921 if (output_dynamic_dimension == -1 &&
2922 output_dim_end - output_dim_start == 1) {
2923 // Only one possible output dimension.
2924 output_dynamic_dimension = output_dim_start;
2925 }
2926 if (output_dynamic_dimension == -1 &&
2927 output_dim_end - output_dim_start > 1) {
2928 // Multiple output dimensions could be dynamic; "inferred_dimension" breaks the tie.
2929 output_dynamic_dimension = inferred_dimension;
2930 }
2931
2932 if (output_dynamic_dimension != -1) {
2933 // TODO(yunxing): Turn this into a CHECK.
2934 inferred_shape.set_dynamic_dimension(output_dynamic_dimension, true);
2935 } else {
2936 std::vector<int64> output_non_degenerated;
2937 for (int64 i = output_dim_start; i < output_dim_end; ++i) {
2938 if (new_sizes[i] != 1) {
2939 output_non_degenerated.push_back(i);
2940 }
2941 }
2942 if (output_non_degenerated.size() == 1) {
2943 inferred_shape.set_dynamic_dimension(output_non_degenerated[0], true);
2944 }
2945 }
2946 }
2947
2948 return inferred_shape;
2949 }
2950
2951 /* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
2952 const Shape& operand, absl::Span<const int64> dimensions) {
2953 TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
2954
2955 if (!IsPermutation(dimensions, operand.rank())) {
2956 return InvalidArgument(
2957 "Transpose dimensions [%s] are not a permutation of the operand "
2958 "dimensions (operand shape is %s).",
2959 StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
2960 }
2961
2962 // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
2963 // we need output[i]=input[dimensions[i]] which is
2964 // Permute(Inverse(dimensions),input).
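// For example (illustrative only): with dimensions={1,2,0} and an operand of
// shape [A,B,C], output[i] = input[dimensions[i]] yields an output shape of
// [B,C,A].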
2965 return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
2966 }
2967
2968 /* static */ StatusOr<Shape> ShapeInference::InferClampShape(
2969 const Shape& min, const Shape& operand, const Shape& max) {
2970 TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min"));
2971 TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand"));
2972 TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max"));
2973
2974 if (!ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
2975 !ShapeUtil::CompatibleIgnoringFpPrecision(max, operand)) {
2976 return InvalidArgument(
2977 "Clamp with different shapes: %s, %s, %s.", ShapeUtil::HumanString(min),
2978 ShapeUtil::HumanString(operand), ShapeUtil::HumanString(max));
2979 }
2980 return operand;
2981 }
2982
2983 /* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
2984 const Shape& pred, const Shape& on_true, const Shape& on_false) {
2985 TF_RETURN_IF_ERROR(ExpectArray(pred, "select pred"));
2986 TF_RETURN_IF_ERROR(ExpectArray(on_true, "select on-true"));
2987 TF_RETURN_IF_ERROR(ExpectArray(on_false, "select on-false"));
2988
2989 if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) {
2990 return InvalidArgument(
2991 "Operands to select must be the same shape; got %s and %s.",
2992 ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
2993 }
2994 if (pred.element_type() != PRED) {
2995 return InvalidArgument(
2996 "Select's pred operand must have PRED element type; got %s.",
2997 ShapeUtil::HumanString(pred));
2998 }
2999 if (!Shape::Equal()
3000 .IgnoreElementType()
3001 .IgnoreLayout()
3002 .IgnoreDynamicDimension()(pred, on_true)) {
3003 return InvalidArgument(
3004 "Operands to select and predicate must be the same shape; got %s and "
3005 "%s.",
3006 ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(pred));
3007 }
3008
3009 return ShapeUtil::ChangeElementType(
3010 pred, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
3011 }
3012
3013 /* static */ StatusOr<Shape> ShapeInference::InferTupleSelectShape(
3014 const Shape& pred, const Shape& on_true, const Shape& on_false) {
3015 // Select only defines the top-level buffer, so if it's a tuple, the two
3016 // inputs must match exactly.
3017 if (!ShapeUtil::Compatible(on_true, on_false)) {
3018 return InvalidArgument(
3019 "Operands to tuple-select must be the same shape; got %s and %s.",
3020 ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
3021 }
3022 if (pred.element_type() != PRED) {
3023 return InvalidArgument(
3024 "TupleSelect's pred operand must have PRED element type; got %s.",
3025 ShapeUtil::HumanString(pred));
3026 }
3027 if (!ShapeUtil::IsScalar(pred)) {
3028 return InvalidArgument(
3029 "TupleSelect operation with non-scalar predicate: %s.",
3030 ShapeUtil::HumanString(pred));
3031 }
3032 return on_true;
3033 }
3034
3035 /* static */ StatusOr<Shape> ShapeInference::InferCallShape(
3036 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply) {
3037 // The applied function's arity equals the number of arguments.
3038 if (arg_shapes.size() != to_apply.parameters_size()) {
3039 string computation_signature = ShapeUtil::HumanString(to_apply);
3040 string argument_shapes =
3041 StrJoin(arg_shapes, ", ", [](string* out, const Shape* shape) {
3042 absl::StrAppend(out, ShapeUtil::HumanString(*shape));
3043 });
3044 return InvalidArgument(
3045 "Call applied function arity must match number of arguments; got: "
3046 "arity: %d, arguments: %u; computation signature: %s; argument "
3047 "shapes: [%s].",
3048 to_apply.parameters_size(), arg_shapes.size(), computation_signature,
3049 argument_shapes);
3050 }
3051
3052 // All arguments must be compatible with the program shape.
3053 for (int i = 0; i < arg_shapes.size(); ++i) {
3054 const Shape& arg_shape = *arg_shapes[i];
3055 const Shape& param_shape = to_apply.parameters(i);
3056 if (!ShapeUtil::Compatible(arg_shape, param_shape)) {
3057 return InvalidArgument(
3058 "Call parameter must match argument; got parameter %d shape: %s, "
3059 "argument shape: %s.",
3060 i, ShapeUtil::HumanString(param_shape),
3061 ShapeUtil::HumanString(arg_shape));
3062 }
3063 }
3064
3065 return to_apply.result();
3066 }
3067
3068 static Status ValidateGatherDimensionNumbers(
3069 const Shape& input_shape, absl::Span<const int64> start_indices_shape,
3070 const GatherDimensionNumbers& dim_numbers) {
3071 if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
3072 return InvalidArgument(
3073 "Output window dimensions in gather op must be ascending; got: %s.",
3074 StrJoin(dim_numbers.offset_dims(), ", "));
3075 }
3076
3077 if (absl::c_adjacent_find(dim_numbers.offset_dims()) !=
3078 dim_numbers.offset_dims().end()) {
3079 return InvalidArgument(
3080 "Output window dimensions in gather op must not repeat; got: %s.",
3081 StrJoin(dim_numbers.offset_dims(), ", "));
3082 }
3083
3084 const int64 output_offset_dim_count = dim_numbers.offset_dims_size();
3085 const int64 output_shape_rank =
3086 output_offset_dim_count + start_indices_shape.size() - 1;
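// For example (illustrative only): with 2 offset dims and expanded start
// indices of rank 3, the gather output has rank 2 + 3 - 1 = 4, since the
// index_vector_dim of start_indices does not appear in the output.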
3087
3088 for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) {
3089 int64 offset_dim = dim_numbers.offset_dims(i);
3090 if (offset_dim < 0 || offset_dim >= output_shape_rank) {
3091 return InvalidArgument(
3092 "Offset dimension %d in gather op is out of bounds; got %d, but "
3093 "should "
3094 "have been in [0,%d).",
3095 i, offset_dim, output_shape_rank);
3096 }
3097 }
3098
3099 if (dim_numbers.start_index_map_size() !=
3100 start_indices_shape[dim_numbers.index_vector_dim()]) {
3101 return InvalidArgument(
3102 "Gather op has %d elements in start_index_map and the "
3103 "bound of dimension index_vector_dim=%d of start_indices is "
3104 "%d. These two numbers must be equal.",
3105 dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(),
3106 start_indices_shape[dim_numbers.index_vector_dim()]);
3107 }
3108
3109 for (int i = 0; i < dim_numbers.start_index_map_size(); i++) {
3110 int64 operand_dim_for_start_index_i = dim_numbers.start_index_map(i);
3111 if (operand_dim_for_start_index_i < 0 ||
3112 operand_dim_for_start_index_i >= input_shape.dimensions_size()) {
3113 return InvalidArgument(
3114 "Invalid start_index_map; domain is [0, %d), got: %d->%d.",
3115 input_shape.dimensions_size(), i, operand_dim_for_start_index_i);
3116 }
3117 }
3118
3119 std::vector<int64> sorted_start_index_map(
3120 dim_numbers.start_index_map().begin(),
3121 dim_numbers.start_index_map().end());
3122
3123 absl::c_sort(sorted_start_index_map);
3124
3125 if (absl::c_adjacent_find(sorted_start_index_map) !=
3126 sorted_start_index_map.end()) {
3127 return InvalidArgument(
3128 "Repeated dimensions are not allowed in start_index_map; "
3129 "got: %s.",
3130 StrJoin(dim_numbers.start_index_map(), ", "));
3131 }
3132
3133 for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) {
3134 if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
3135 return InvalidArgument(
3136 "Invalid collapsed_slice_dims set in gather op; valid range is [0, "
3137 "%d), got: %d.",
3138 input_shape.dimensions_size(), collapsed_dim);
3139 }
3140 }
3141
3142 if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) {
3143 return InvalidArgument(
3144 "collapsed_slice_dims in gather op must be sorted; got: %s",
3145 StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
3146 }
3147
3148 if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
3149 dim_numbers.collapsed_slice_dims().end()) {
3150 return InvalidArgument(
3151 "Repeated dimensions not allowed in collapsed_slice_dims in gather op; "
3152 "got: %s.",
3153 StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
3154 }
3155
3156 return Status::OK();
3157 }
3158
3159 /*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
3160 const Shape& input_shape, const Shape& start_indices_shape,
3161 const GatherDimensionNumbers& gather_dim_numbers,
3162 absl::Span<const int64> slice_sizes) {
3163 TF_RETURN_IF_ERROR(
3164 ExpectArray(input_shape, "input tensor operand of gather op"));
3165 TF_RETURN_IF_ERROR(
3166 ExpectArray(start_indices_shape, "gather indices operand of gather op"));
3167
3168 if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
3169 return InvalidArgument(
3170 "Gather indices parameter must be an integral tensor; got %s.",
3171 ShapeUtil::HumanString(start_indices_shape));
3172 }
3173
3174 // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
3175 // index_vector_dim is rank(P). The bounds of this expanded shape are
3176 // stored in expanded_start_indices_shape.
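// For example (illustrative only), start_indices of shape s32[10,9] with
// index_vector_dim = 2 are treated as if they had shape s32[10,9,1].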
3177
3178 if (start_indices_shape.dimensions_size() <
3179 gather_dim_numbers.index_vector_dim() ||
3180 gather_dim_numbers.index_vector_dim() < 0) {
3181 return InvalidArgument(
3182 "Gather index leaf dimension must be within [0, rank(start_indices) + "
3183 "1). rank(start_indices) is %d and gather index leaf dimension is "
3184 "%d.",
3185 start_indices_shape.dimensions_size(),
3186 gather_dim_numbers.index_vector_dim());
3187 }
3188
3189 std::vector<int64> expanded_start_indices_shape;
3190 // Also tracks if an output dimension is dynamic.
3191 std::vector<bool> expanded_start_indices_shape_dynamic_dimensions;
3192 expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size());
3193 expanded_start_indices_shape_dynamic_dimensions.reserve(
3194 start_indices_shape.dimensions_size());
3195 absl::c_copy(start_indices_shape.dimensions(),
3196 std::back_inserter(expanded_start_indices_shape));
3197 absl::c_copy(
3198 start_indices_shape.dynamic_dimensions(),
3199 std::back_inserter(expanded_start_indices_shape_dynamic_dimensions));
3200 if (expanded_start_indices_shape.size() ==
3201 gather_dim_numbers.index_vector_dim()) {
3202 expanded_start_indices_shape.push_back(1);
3203 expanded_start_indices_shape_dynamic_dimensions.push_back(false);
3204 }
3205
3206 TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
3207 input_shape, expanded_start_indices_shape, gather_dim_numbers));
3208
3209 if (slice_sizes.size() != input_shape.dimensions_size()) {
3210 return InvalidArgument(
3211 "Gather op must have one slice size for every input dimension; got: "
3212 "len(slice_sizes)=%lu, input_shape.rank=%d.",
3213 slice_sizes.size(), input_shape.dimensions_size());
3214 }
3215
3216 if (slice_sizes.size() !=
3217 gather_dim_numbers.offset_dims_size() +
3218 gather_dim_numbers.collapsed_slice_dims_size()) {
3219 return InvalidArgument(
3220 "All components of the offset index in a gather op must either be a "
3221 "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
3222 "output_slice_sizes=%s, collapsed_slice_dims=%s.",
3223 slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","),
3224 StrJoin(gather_dim_numbers.collapsed_slice_dims(), ","));
3225 }
3226
3227 for (int i = 0; i < slice_sizes.size(); i++) {
3228 int64 slice_size = slice_sizes[i];
3229 int64 corresponding_input_size = input_shape.dimensions(i);
3230 if (slice_size < 0 || slice_size > corresponding_input_size) {
3231 return InvalidArgument(
3232 "Slice size at index %d in gather op is out of range, must be "
3233 "within [0, %d), got %d.",
3234 i, corresponding_input_size + 1, slice_size);
3235 }
3236 }
3237
3238 for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) {
3239 if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] > 1) {
3240 return InvalidArgument(
3241 "Gather op can only collapse slice dims with bound 1 or 0, but bound "
3242 "is %d for index %d at position %d.",
3243 slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)],
3244 gather_dim_numbers.collapsed_slice_dims(i), i);
3245 }
3246 }
3247
3248 int64 result_rank = gather_dim_numbers.offset_dims_size() +
3249 (expanded_start_indices_shape.size() - 1);
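// A worked example (all dimension numbers are illustrative only): gathering
// rows of an f32[3,4] operand with start_indices s32[5,1], index_vector_dim=1,
// offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0} and
// slice_sizes={1,4} gives result_rank = 1 + (2 - 1) = 2 and an output shape of
// f32[5,4]: dimension 0 comes from the start_indices batch dimension and
// dimension 1 from the uncollapsed slice size of 4.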
3250 int64 offset_dims_seen = 0;
3251 int64 gather_dims_seen = 0;
3252 std::vector<int64> output_dim_bounds;
3253 output_dim_bounds.reserve(result_rank);
3254
3255 std::vector<bool> output_dim_is_dynamic;
3256 output_dim_is_dynamic.reserve(result_rank);
3257 for (int64 i = 0; i < result_rank; i++) {
3258 int64 current_bound;
3259 bool dim_dynamic = false;
3260 bool is_window_index =
3261 absl::c_binary_search(gather_dim_numbers.offset_dims(), i);
3262 if (is_window_index) {
3263 while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(),
3264 offset_dims_seen)) {
3265 offset_dims_seen++;
3266 }
3267 // Gathering an entire dynamic dimension creates a dynamic dimension.
3268 //
3269 // e.g.,:
3270 //
3271 // gather(input: [1,<=2,1], slice_sizes={1,2,1})
3272 //
3273 // creates
3274 //
3275 // [<=2, 1]
3276 if (slice_sizes[offset_dims_seen] ==
3277 input_shape.dimensions(offset_dims_seen)) {
3278 dim_dynamic = input_shape.is_dynamic_dimension(offset_dims_seen);
3279 }
3280 current_bound = slice_sizes[offset_dims_seen++];
3281 } else {
3282 if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) {
3283 gather_dims_seen++;
3284 }
3285 // Forward dynamic dimensions from indices.
3286 dim_dynamic =
3287 expanded_start_indices_shape_dynamic_dimensions[gather_dims_seen];
3288
3289 current_bound = expanded_start_indices_shape[gather_dims_seen++];
3290 }
3291 output_dim_is_dynamic.push_back(dim_dynamic);
3292 output_dim_bounds.push_back(current_bound);
3293 }
3294
3295 return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds,
3296 output_dim_is_dynamic);
3297 }
3298
3299 namespace {
3300
3301 Status ValidateScatterDimensionNumbers(
3302 const Shape& operand_shape, absl::Span<const int64> scatter_indices_shape,
3303 const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
3304 // Validate update_window_dims in ScatterDimensionNumbers.
3305 if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {
3306 return InvalidArgument(
3307 "update_window_dims in scatter op must be sorted; got: %s.",
3308 StrJoin(dim_numbers.update_window_dims(), ", "));
3309 }
3310 if (absl::c_adjacent_find(dim_numbers.update_window_dims()) !=
3311 dim_numbers.update_window_dims().end()) {
3312 return InvalidArgument(
3313 "update_window_dims in scatter op must not repeat; got: %s.",
3314 StrJoin(dim_numbers.update_window_dims(), ", "));
3315 }
3316 const int64 updates_rank = updates_shape.rank();
3317 for (int64 window_dim : dim_numbers.update_window_dims()) {
3318 if (window_dim < 0 || window_dim >= updates_rank) {
3319 return InvalidArgument(
3320 "Invalid update_window_dims set in scatter op; valid range is [0, "
3321 "%d). got: %d.",
3322 updates_rank, window_dim);
3323 }
3324 }
3325
3326 // Validate inserted_window_dims in ScatterDimensionNumbers.
3327 if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) {
3328 return InvalidArgument(
3329 "inserted_window_dims in scatter op must be sorted; got: %s.",
3330 StrJoin(dim_numbers.inserted_window_dims(), ", "));
3331 }
3332 if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) !=
3333 dim_numbers.inserted_window_dims().end()) {
3334 return InvalidArgument(
3335 "inserted_window_dims in scatter op must not repeat; got: %s.",
3336 StrJoin(dim_numbers.inserted_window_dims(), ", "));
3337 }
3338 for (int64 inserted_dim : dim_numbers.inserted_window_dims()) {
3339 if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) {
3340 return InvalidArgument(
3341 "Invalid inserted_window_dims set in scatter op; valid range is [0, "
3342 "%d), got: %d.",
3343 operand_shape.dimensions_size(), inserted_dim);
3344 }
3345 }
3346
3347 // Validate window size.
3348 auto window_size = dim_numbers.update_window_dims_size() +
3349 dim_numbers.inserted_window_dims_size();
3350 if (window_size != operand_shape.rank()) {
3351 return InvalidArgument(
3352 "Scatter op has window of size %d; doesn't match operand of rank %d.",
3353 window_size, operand_shape.rank());
3354 }
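// For example (illustrative only): an operand of rank 3 with
// inserted_window_dims={0} requires exactly two update_window_dims, since
// every operand dimension must be covered by either an update window dimension
// or an inserted window dimension.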
3355
3356 // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers.
3357 if (dim_numbers.scatter_dims_to_operand_dims_size() !=
3358 scatter_indices_shape[dim_numbers.index_vector_dim()]) {
3359 return InvalidArgument(
3360 "Scatter op has %d elements in scatter_dims_to_operand_dims and the "
3361 "bound of dimension index_vector_dim=%d of scatter_indices is %d. "
3362 "These two numbers must be equal.",
3363 dim_numbers.scatter_dims_to_operand_dims_size(),
3364 dim_numbers.index_vector_dim(),
3365 scatter_indices_shape[dim_numbers.index_vector_dim()]);
3366 }
3367 for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) {
3368 int64 scatter_dim_to_operand_dim =
3369 dim_numbers.scatter_dims_to_operand_dims(i);
3370 if (scatter_dim_to_operand_dim < 0 ||
3371 scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) {
3372 return InvalidArgument(
3373 "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), "
3374 "got: %d->%d.",
3375 operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim);
3376 }
3377 }
3378 std::vector<int64> sorted_scatter_dims_to_operand_dims(
3379 dim_numbers.scatter_dims_to_operand_dims().begin(),
3380 dim_numbers.scatter_dims_to_operand_dims().end());
3381 absl::c_sort(sorted_scatter_dims_to_operand_dims);
3382 if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
3383 sorted_scatter_dims_to_operand_dims.end()) {
3384 return InvalidArgument(
3385 "Repeated dimensions not allowed in scatter_dims_to_operand_dims; "
3386 "got: %s.",
3387 StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", "));
3388 }
3389
3390 return Status::OK();
3391 }
3392
3393 } // namespace
3394
3395 /*static*/ StatusOr<Shape> ShapeInference::InferScatterShape(
3396 const Shape& operand_shape, const Shape& scatter_indices_shape,
3397 const Shape& updates_shape, const ProgramShape& to_apply_shape,
3398 const ScatterDimensionNumbers& scatter_dim_numbers) {
3399 TF_RETURN_IF_ERROR(
3400 ExpectArray(operand_shape, "operand tensor of scatter op"));
3401 TF_RETURN_IF_ERROR(
3402 ExpectArray(scatter_indices_shape, "scatter indices of scatter op"));
3403 TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op"));
3404
3405 if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) {
3406 return InvalidArgument(
3407 "Scatter indices parameter must be an integral tensor; got %s.",
3408 ShapeUtil::HumanString(scatter_indices_shape));
3409 }
3410
3411 if (scatter_indices_shape.dimensions_size() <
3412 scatter_dim_numbers.index_vector_dim() ||
3413 scatter_dim_numbers.index_vector_dim() < 0) {
3414 return InvalidArgument(
3415 "Scatter index leaf dimension must be within [0, rank(scatter_indices)"
3416 " + 1). rank(scatter_indices) is %d and scatter index leaf dimension "
3417 "is %d.",
3418 scatter_indices_shape.dimensions_size(),
3419 scatter_dim_numbers.index_vector_dim());
3420 }
3421
3422 // Check if the update computation has a proper shape as a reduction.
3423 const Shape init_value_shape =
3424 ShapeUtil::MakeShape(operand_shape.element_type(), {});
3425 TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
3426 {updates_shape.element_type()},
3427 /*inputs=*/1));
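// Roughly speaking, to_apply must look like a reduction computation over the
// operand's element type: two parameters and a result compatible with a scalar
// of that type (e.g., a scalar add computation for an f32 operand).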
3428
3429 std::vector<int64> expanded_scatter_indices_shape =
3430 SpanToVector(scatter_indices_shape.dimensions());
3431 if (expanded_scatter_indices_shape.size() ==
3432 scatter_dim_numbers.index_vector_dim()) {
3433 expanded_scatter_indices_shape.push_back(1);
3434 }
3435
3436 int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 +
3437 scatter_dim_numbers.update_window_dims_size();
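// For example (illustrative only): scatter_indices of shape s32[7,1] with
// index_vector_dim=1 and update_window_dims={1} require updates of rank
// (2 - 1) + 1 = 2.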
3438 if (updates_shape.rank() != expected_updates_rank) {
3439 return InvalidArgument("Updates tensor must be of rank %d; got %d.",
3440 expected_updates_rank, updates_shape.rank());
3441 }
3442
3443 TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers(
3444 operand_shape, expanded_scatter_indices_shape, updates_shape,
3445 scatter_dim_numbers));
3446
3447 int64 inserted_dims_seen = 0;
3448 std::vector<int64> max_update_slice_sizes;
3449 for (int i = 0; i < operand_shape.dimensions_size(); ++i) {
3450 if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() &&
3451 scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) {
3452 ++inserted_dims_seen;
3453 } else {
3454 max_update_slice_sizes.push_back(operand_shape.dimensions(i));
3455 }
3456 }
3457 for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) {
3458 auto update_window_dim = scatter_dim_numbers.update_window_dims(i);
3459 if (updates_shape.dimensions(update_window_dim) >
3460 max_update_slice_sizes[i]) {
3461 return InvalidArgument(
3462 "Bounds of the window dimensions of updates must not exceed the "
3463 "bounds of the corresponding dimensions of operand. For dimension "
3464 "%d, updates bound is %d, operand bound is %d.",
3465 update_window_dim, updates_shape.dimensions(update_window_dim),
3466 max_update_slice_sizes[i]);
3467 }
3468 }
3469
3470 int64 scatter_dims_seen = 0;
3471 for (int64 i = 0; i < updates_shape.rank(); ++i) {
3472 bool is_update_window_dim =
3473 absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i);
3474 if (is_update_window_dim) {
3475 continue;
3476 }
3477 if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) {
3478 ++scatter_dims_seen;
3479 }
3480 if (updates_shape.dimensions(i) !=
3481 expanded_scatter_indices_shape[scatter_dims_seen]) {
3482 return InvalidArgument(
3483 "Bounds of the scatter dimensions of updates must be same as the "
3484 "bounds of the corresponding dimensions of scatter indices. For "
3485 "scatter dimension %d, updates bound is %d, scatter_indices "
3486 "bound is %d.",
3487 i, updates_shape.dimensions(i),
3488 expanded_scatter_indices_shape[scatter_dims_seen]);
3489 }
3490 ++scatter_dims_seen;
3491 }
3492
3493 return operand_shape;
3494 }
3495
3496 } // namespace xla
3497