1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <numeric>
21 #include <set>
22 #include <string>
23
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 #include "absl/strings/str_join.h"
29 #include "absl/strings/string_view.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/compiler/xla/window_util.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/lib/math/math_util.h"
38 #include "tensorflow/core/platform/logging.h"
39 #include "tensorflow/core/platform/protobuf.h"
40
41 namespace xla {
42 namespace {
43
44 using absl::StrFormat;
45 using absl::StrJoin;
46
47 // Returns true if no element is present in slice more than once.
AllUnique(absl::Span<const int64> slice)48 bool AllUnique(absl::Span<const int64> slice) {
49 return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
50 }
51
ExpectArray(const Shape & shape,absl::string_view op_type)52 Status ExpectArray(const Shape& shape, absl::string_view op_type) {
53 if (!shape.IsArray()) {
54 return InvalidArgument("Expected array argument for %s, but got %s.",
55 string(op_type), ShapeUtil::HumanString(shape));
56 }
57 return Status::OK();
58 }
59
VerifyReducerShape(const ProgramShape & reducer_shape,absl::Span<const Shape * const> init_value_shapes,absl::Span<const PrimitiveType> input_element_types,int64 inputs)60 Status VerifyReducerShape(const ProgramShape& reducer_shape,
61 absl::Span<const Shape* const> init_value_shapes,
62 absl::Span<const PrimitiveType> input_element_types,
63 int64 inputs) {
64 if (reducer_shape.parameters_size() != inputs * 2) {
65 return InvalidArgument(
66 "Reduction function must take %d parameters, but "
67 "takes %d parameter(s).",
68 inputs * 2, reducer_shape.parameters_size());
69 }
70
71 const Shape& accumulator_shape = reducer_shape.result();
72 std::vector<const Shape*> accumulator_subshapes;
73 if (accumulator_shape.IsArray()) {
74 if (inputs != 1) {
75 return InvalidArgument(
76 "Reduction function must produce a tuple with %d elements, but "
77 "produces a scalar",
78 inputs);
79 }
80 accumulator_subshapes.push_back(&accumulator_shape);
81 } else if (accumulator_shape.IsTuple()) {
82 if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) {
83 return InvalidArgument(
84 "Reduction function must produce a tuple with %d elements, but has "
85 "%d elements",
86 inputs, ShapeUtil::TupleElementCount(accumulator_shape));
87 }
88 for (const Shape& element_shape : accumulator_shape.tuple_shapes()) {
89 accumulator_subshapes.push_back(&element_shape);
90 }
91 } else {
92 return InvalidArgument(
93 "Reduction function must produce a scalar or tuple of scalars, but has "
94 "shape: %s",
95 ShapeUtil::HumanString(accumulator_shape));
96 }
97
98 for (const Shape* element_shape : accumulator_subshapes) {
99 if (element_shape->rank() != 0) {
100 return InvalidArgument(
101 "Reduction function must return a scalar or tuple of scalars but "
102 "returns shape: %s",
103 ShapeUtil::HumanString(accumulator_shape));
104 }
105 }
106
107 for (int64 i = 0; i < inputs; ++i) {
108 // Check that the accumulator can be passed in as the first argument.
109 // Note: comparing here and below with Compatible since we don't care about
110 // layout in scalars - see b/26668201 for a longer-term vision.
111 if (!ShapeUtil::Compatible(*accumulator_subshapes[i],
112 reducer_shape.parameters(i))) {
113 return InvalidArgument(
114 "Reduction function's %d-th parameter shape differs from the "
115 "result shape: %s vs %s",
116 i, ShapeUtil::HumanString(reducer_shape.parameters(i)),
117 ShapeUtil::HumanString(*accumulator_subshapes[i]));
118 }
119 // Check that init_value's shapes are suitable for reducer_shape.
120 if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i],
121 *init_value_shapes[i])) {
122 return InvalidArgument(
123 "Reduction function's accumulator shape at index %d differs from "
124 "the init_value shape: %s vs %s",
125 i, ShapeUtil::HumanString(*accumulator_subshapes[i]),
126 ShapeUtil::HumanString(*init_value_shapes[i]));
127 }
128 // Check that the inputs can be passed in as the non-accumulator arguments.
129 const Shape input_element_shape =
130 ShapeUtil::MakeShape(input_element_types[i], {});
131 if (!ShapeUtil::CompatibleIgnoringFpPrecision(
132 input_element_shape, reducer_shape.parameters(inputs + i))) {
133 return InvalidArgument(
134 "Reduction function's %d-th parameter shape differs from the "
135 "input type element type: %s vs %s",
136 inputs + i,
137 ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
138 ShapeUtil::HumanString(input_element_shape));
139 }
140 // Check that the accumulator and inputs to the reducer function match.
141 // If the accumulator is scalar, it must have the same type as the inputs
142 // (up to fp precision). If it is a tuple, then the k-th element of the
143 // tuple must have the same type as the K-th input (again, up to fp
144 // precision.)
145 if (!ShapeUtil::CompatibleIgnoringFpPrecision(
146 *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) {
147 return InvalidArgument(
148 "Reduction function's %d-th parameter shape must "
149 "match the result shape, but got %s vs %s.",
150 inputs + i,
151 ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
152 ShapeUtil::HumanString(*accumulator_subshapes[i]));
153 }
154 }
155
156 return Status::OK();
157 }
158
InferWindowOutputShape(const Shape & base_shape,const Window & window,PrimitiveType element_type,bool allow_negative_padding)159 StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
160 const Window& window,
161 PrimitiveType element_type,
162 bool allow_negative_padding) {
163 if (window.dimensions_size() != base_shape.rank()) {
164 return InvalidArgument(
165 "Window has dimension %d but base shape has dimension %d.",
166 window.dimensions_size(), base_shape.rank());
167 }
168
169 std::vector<int64> output_dimensions(window.dimensions_size());
170 std::vector<bool> output_is_dynamic(window.dimensions_size());
171 for (int64 i = 0; i < window.dimensions_size(); ++i) {
172 const auto& dim = window.dimensions(i);
173 if (dim.size() <= 0) {
174 return InvalidArgument("Window %s has a non-positive dimension.",
175 window.DebugString());
176 }
177 if (dim.stride() <= 0) {
178 return InvalidArgument("Window %s has a non-positive stride.",
179 window.DebugString());
180 }
181 if (!allow_negative_padding && dim.padding_low() < 0) {
182 return InvalidArgument("Window %s has a negative low padding.",
183 window.DebugString());
184 }
185 if (!allow_negative_padding && dim.padding_high() < 0) {
186 return InvalidArgument("Window %s has a negative high padding.",
187 window.DebugString());
188 }
189 if (dim.base_dilation() < 1) {
190 return InvalidArgument(
191 "Window %s has a non-positive base area dilation factor.",
192 window.DebugString());
193 }
194 if (dim.window_dilation() < 1) {
195 return InvalidArgument(
196 "Window %s has a non-positive window dilation factor.",
197 window.DebugString());
198 }
199
200 if (base_shape.is_dynamic_dimension(i) &&
201 !window_util::IsTrivialWindowDimension(dim)) {
202 return Unimplemented(
203 "Dynamic shape is not supported for non trivial window: %s",
204 window_util::ToString(window));
205 }
206
207 const int64 dilated_base = window_util::DilatedBound(
208 ShapeUtil::GetDimension(base_shape, i), dim.base_dilation());
209 const int64 padded_dilated_base =
210 dim.padding_low() + dilated_base + dim.padding_high();
211 const int64 dilated_window =
212 window_util::DilatedBound(dim.size(), dim.window_dilation());
213
214 output_dimensions[i] = window_util::StridedBound(
215 padded_dilated_base, dilated_window, dim.stride());
216 output_is_dynamic[i] = base_shape.is_dynamic_dimension(i);
217 }
218
219 return ShapeUtil::MakeValidatedShape(element_type, output_dimensions,
220 output_is_dynamic);
221 }
222
223 } // namespace
224
InferUnaryOpShape(HloOpcode opcode,const HloInstruction * operand)225 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
226 HloOpcode opcode, const HloInstruction* operand) {
227 return InferUnaryOpShape(opcode, operand->shape());
228 }
229
InferUnaryOpShape(HloOpcode opcode,const Shape & shape)230 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
231 HloOpcode opcode, const Shape& shape) {
232 // There is no copy operation at the proto level, so handle copy explicitly.
233 // A domain shape is the same as the input one.
234 if (opcode == HloOpcode::kCopy || opcode == HloOpcode::kDomain) {
235 return shape;
236 }
237
238 TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation"));
239
240 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
241 switch (opcode) {
242 case HloOpcode::kFloor:
243 case HloOpcode::kCeil:
244 case HloOpcode::kRoundNearestAfz:
245 if (!ShapeUtil::ElementIsFloating(shape)) {
246 return InvalidArgument(
247 "Expected element type in shape to be floating for %s operation; "
248 "got %s.",
249 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
250 }
251 return shape;
252 case HloOpcode::kCos:
253 case HloOpcode::kSin:
254 case HloOpcode::kExp:
255 case HloOpcode::kExpm1:
256 case HloOpcode::kLog:
257 case HloOpcode::kLog1p:
258 case HloOpcode::kRsqrt:
259 case HloOpcode::kSqrt:
260 case HloOpcode::kTanh:
261 if (!ShapeUtil::ElementIsFloating(shape) &&
262 !ShapeUtil::ElementIsComplex(shape)) {
263 return InvalidArgument(
264 "Expected element type in shape to be floating or complex for %s "
265 "operation; got %s.",
266 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
267 }
268 return shape;
269 case HloOpcode::kReal:
270 case HloOpcode::kImag:
271 if (ShapeUtil::ElementIsComplex(shape)) {
272 return ShapeUtil::ComplexComponentShape(shape);
273 } else if (ShapeUtil::ElementIsFloating(shape)) {
274 return shape;
275 } else {
276 return InvalidArgument(
277 "Expected element type in shape to be floating or complex for "
278 "%s operation; got %s.",
279 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
280 }
281 case HloOpcode::kAbs:
282 if (ShapeUtil::ElementIsComplex(shape)) {
283 return ShapeUtil::ChangeElementType(
284 shape, primitive_util::ComplexComponentType(shape.element_type()));
285 } else if (ShapeUtil::ElementIsSigned(shape)) {
286 return shape;
287 } else {
288 return InvalidArgument(
289 "Expected element type in shape to be floating or complex for "
290 "%s operation; got %s.",
291 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
292 }
293 case HloOpcode::kClz:
294 if (!ShapeUtil::ElementIsIntegral(shape)) {
295 return InvalidArgument(
296 "Expected an integral element type in argument to Clz "
297 "operation; got %s.",
298 PrimitiveType_Name(shape.element_type()));
299 }
300 return shape;
301 case HloOpcode::kNegate:
302 if (!ShapeUtil::ElementIsIntegral(shape) &&
303 !ShapeUtil::ElementIsFloating(shape) &&
304 !ShapeUtil::ElementIsComplex(shape)) {
305 return InvalidArgument(
306 "Expected element type in shape to be integral, floating or "
307 "complex for %s operation; got %s.",
308 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
309 }
310 return shape;
311 case HloOpcode::kSign:
312 if (!ShapeUtil::ElementIsSigned(shape) &&
313 !ShapeUtil::ElementIsComplex(shape)) {
314 return InvalidArgument(
315 "Expected element type in shape to be signed or complex for "
316 "%s operation; got %s.",
317 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
318 }
319 return shape;
320
321 case HloOpcode::kNot:
322 if (shape.element_type() != PRED &&
323 !primitive_util::IsIntegralType(shape.element_type())) {
324 return InvalidArgument(
325 "Expected pred or an integral element type in argument to Not "
326 "operation; got %s.",
327 PrimitiveType_Name(shape.element_type()));
328 }
329 return shape;
330
331 case HloOpcode::kIsFinite:
332 if (!ShapeUtil::ElementIsFloating(shape)) {
333 return InvalidArgument(
334 "Expected element type in shape to be floating "
335 "point for IsFinite "
336 "operation; got %s.",
337 PrimitiveType_Name(shape.element_type()));
338 }
339 return ShapeUtil::ChangeElementType(shape, PRED);
340
341 default:
342 return InvalidArgument(
343 "Unknown operation for unary shape inference: \"%s\".",
344 HloOpcodeString(opcode));
345 }
346 }
347
InferConcatOpShape(absl::Span<const Shape * const> arg_shapes,const int64 dimension)348 /* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
349 absl::Span<const Shape* const> arg_shapes, const int64 dimension) {
350 if (arg_shapes.empty()) {
351 return InvalidArgument("Concatenate expects at least one argument.");
352 }
353 if (dimension < 0 || dimension >= arg_shapes[0]->rank()) {
354 return InvalidArgument("Concatenate dimension out of bounds: %d.",
355 dimension);
356 }
357 const Shape* arg_shape = nullptr;
358 PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
359 for (const Shape* shape : arg_shapes) {
360 TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation"));
361 if (!arg_shape) {
362 arg_shape = shape;
363 element_type = arg_shape->element_type();
364 continue;
365 }
366 if (arg_shape->rank() != shape->rank()) {
367 return InvalidArgument(
368 "Cannot concatenate arrays with different ranks: %d (%s) vs %d "
369 "(%s).",
370 arg_shape->rank(), ShapeUtil::HumanString(*arg_shape), shape->rank(),
371 ShapeUtil::HumanString(*shape));
372 }
373 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
374 return InvalidArgument(
375 "Cannot concatenate arrays with different element types: %s vs %s.",
376 PrimitiveType_Name(arg_shape->element_type()),
377 PrimitiveType_Name(shape->element_type()));
378 }
379 for (int64 dimension_number = 0; dimension_number < arg_shape->rank();
380 ++dimension_number) {
381 if (arg_shape->dimensions(dimension_number) !=
382 shape->dimensions(dimension_number)) {
383 if (dimension_number == dimension) {
384 continue; // It's okay to differ in the dimension we're
385 // concatenating.
386 }
387 return InvalidArgument(
388 "Cannot concatenate arrays that differ in dimensions other than "
389 "the one being concatenated (the other array dimensions must be "
390 "the same): %s vs %s in dimension %d.",
391 ShapeUtil::HumanString(*arg_shape), ShapeUtil::HumanString(*shape),
392 dimension);
393 }
394 }
395 element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
396 }
397
398 std::vector<int64> new_dimensions(arg_shape->dimensions().begin(),
399 arg_shape->dimensions().end());
400 for (size_t i = 1; i < arg_shapes.size(); ++i) {
401 new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
402 }
403 return ShapeUtil::MakeShape(element_type, new_dimensions);
404 }
405
InferConvertShape(const Shape & operand_shape,PrimitiveType new_element_type)406 /* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
407 const Shape& operand_shape, PrimitiveType new_element_type) {
408 auto old_element_type = operand_shape.element_type();
409 if (primitive_util::IsComplexType(old_element_type) &&
410 !primitive_util::IsComplexType(new_element_type)) {
411 return Unimplemented(
412 "Conversion from complex to real type %s => %s is not implemented.",
413 ShapeUtil::HumanString(operand_shape),
414 PrimitiveType_Name(new_element_type));
415 }
416 if (!operand_shape.IsArray() ||
417 !primitive_util::IsArrayType(new_element_type)) {
418 // Note: we may want to support tuple conversions via this operation in the
419 // future, by recursing into the tuple elements to check all sub-conversions
420 // are valid. For now we just reject them, though.
421 return InvalidArgument(
422 "Convert does not allow non-arrays, so cannot convert from %s to %s.",
423 ShapeUtil::HumanString(operand_shape),
424 PrimitiveType_Name(new_element_type));
425 }
426
427 return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
428 }
429
InferBitcastConvertShape(const Shape & operand_shape,PrimitiveType new_element_type)430 /* static */ StatusOr<Shape> ShapeInference::InferBitcastConvertShape(
431 const Shape& operand_shape, PrimitiveType new_element_type) {
432 auto old_element_type = operand_shape.element_type();
433 if (primitive_util::IsComplexType(old_element_type) !=
434 primitive_util::IsComplexType(new_element_type)) {
435 return InvalidArgument("Conversion from complex to real type %s => %s.",
436 ShapeUtil::HumanString(operand_shape),
437 PrimitiveType_Name(new_element_type));
438 }
439 if (!operand_shape.IsArray() ||
440 !primitive_util::IsArrayType(new_element_type)) {
441 // Note: we may want to support tuple conversions via this operation in the
442 // future, by recursing into the tuple elements to check all sub-conversions
443 // are valid. For now we just reject them, though.
444 return InvalidArgument(
445 "Cannot convert from or to tuple type; requested conversion: %s => %s.",
446 ShapeUtil::HumanString(operand_shape),
447 PrimitiveType_Name(new_element_type));
448 }
449 if (primitive_util::BitWidth(old_element_type) !=
450 primitive_util::BitWidth(new_element_type)) {
451 return InvalidArgument(
452 "Cannot bitcast types with different bit-widths: %s => %s.",
453 PrimitiveType_Name(old_element_type),
454 PrimitiveType_Name(new_element_type));
455 }
456
457 return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
458 }
459
InferReducePrecisionShape(const Shape & operand_shape,const int exponent_bits,const int mantissa_bits)460 /* static */ StatusOr<Shape> ShapeInference::InferReducePrecisionShape(
461 const Shape& operand_shape, const int exponent_bits,
462 const int mantissa_bits) {
463 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
464 return InvalidArgument(
465 "Expected element type in shape to be floating point for "
466 "ReducePrecision operation; got %s.",
467 PrimitiveType_Name(operand_shape.element_type()));
468 }
469 if (exponent_bits < 1) {
470 // One exponent bit is necessary to distinguish 0 from infinity. Having
471 // no exponent bits doesn't produce a sensible number, so we require at
472 // least one.
473 return InvalidArgument("Expected exponent_bits >= 1; got %d.",
474 exponent_bits);
475 }
476 if (mantissa_bits < 0) {
477 // A number with no mantissa bits is still meaningful, however.
478 return InvalidArgument("Expected non-negative mantissa_bits; got %d.",
479 mantissa_bits);
480 }
481 return operand_shape;
482 }
483
InferPadShape(const Shape & operand_shape,const Shape & padding_value_shape,const PaddingConfig & padding_config)484 /* static */ StatusOr<Shape> ShapeInference::InferPadShape(
485 const Shape& operand_shape, const Shape& padding_value_shape,
486 const PaddingConfig& padding_config) {
487 if (!operand_shape.IsArray()) {
488 return InvalidArgument(
489 "Pad operation does not support tuple-shape operands.");
490 }
491 if (!ShapeUtil::IsScalar(padding_value_shape)) {
492 return InvalidArgument(
493 "Pad operation does not support non-scalar padding values.");
494 }
495 if (operand_shape.rank() != padding_config.dimensions_size()) {
496 return InvalidArgument(
497 "The rank of the operand and the padding configuration do not match: "
498 "%s vs %s.",
499 ShapeUtil::HumanString(operand_shape),
500 padding_config.ShortDebugString());
501 }
502 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
503 padding_value_shape)) {
504 return InvalidArgument(
505 "The element types of the operands to Pad do not match.");
506 }
507 if (absl::c_any_of(padding_config.dimensions(),
508 [](const PaddingConfig::PaddingConfigDimension& p) {
509 return p.interior_padding() < 0;
510 })) {
511 return InvalidArgument("Interior padding cannot be negative: %s",
512 padding_config.ShortDebugString());
513 }
514
515 if (!padding_value_shape.is_static()) {
516 return InvalidArgument("Dynamic padding value is not supported");
517 }
518
519 std::vector<int64> dimensions(operand_shape.rank());
520 std::vector<bool> is_dynamic(operand_shape.rank());
521 for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
522 const auto& p = padding_config.dimensions(i);
523 if (operand_shape.is_dynamic_dimension(i) && p.edge_padding_high() != 0 &&
524 p.edge_padding_low() != 0 && p.interior_padding() != 0) {
525 return InvalidArgument(
526 "Dynamic dimension on padding dimension is not supported.");
527 }
528 dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() +
529 p.edge_padding_high() +
530 std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
531 p.interior_padding();
532 if (dimensions[i] < 0) {
533 return InvalidArgument("Padding result in negative size for dimension %d",
534 i);
535 }
536 is_dynamic[i] = operand_shape.is_dynamic_dimension(i);
537 }
538
539 return ShapeUtil::MakeShape(
540 ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
541 dimensions, is_dynamic);
542 }
543
544 // Current DotDimensionNumbers Requirements:
545 //
546 // Contracting Dimensions:
547 // *) Same number of contracting dimensions on both lhs and rhs.
548 // *) Contracting dimension size must be the same on both lhs and rhs.
549 //
550 // Batch Dimensions:
551 // *) Same number of batch dimensions on both lhs and rhs.
552 // *) Same batch dimension sizes on both lhs and rhs.
553 //
554
555 namespace {
556
ValidateDotDimensionNumbers(const Shape & lhs,const Shape & rhs,const DotDimensionNumbers & dimension_numbers)557 Status ValidateDotDimensionNumbers(
558 const Shape& lhs, const Shape& rhs,
559 const DotDimensionNumbers& dimension_numbers) {
560 // Check that dimension numbers are in range.
561 auto dims_in_range = [](const int64 rank,
562 absl::Span<const int64> contracting_dims,
563 absl::Span<const int64> batch_dims) -> bool {
564 auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
565 return absl::c_all_of(contracting_dims, in_range) &&
566 absl::c_all_of(batch_dims, in_range);
567 };
568
569 absl::Span<const int64> lhs_contracting_dimensions =
570 AsInt64Slice(dimension_numbers.lhs_contracting_dimensions());
571 absl::Span<const int64> rhs_contracting_dimensions =
572 AsInt64Slice(dimension_numbers.rhs_contracting_dimensions());
573 absl::Span<const int64> lhs_batch_dimensions =
574 AsInt64Slice(dimension_numbers.lhs_batch_dimensions());
575 absl::Span<const int64> rhs_batch_dimensions =
576 AsInt64Slice(dimension_numbers.rhs_batch_dimensions());
577
578 if (!dims_in_range(lhs.rank(), lhs_contracting_dimensions,
579 lhs_batch_dimensions) ||
580 !dims_in_range(rhs.rank(), rhs_contracting_dimensions,
581 rhs_batch_dimensions)) {
582 return InvalidArgument("A dimension number is out of range in Dot: %s.",
583 dimension_numbers.DebugString());
584 }
585
586 // Check that dimension numbers are unique.
587 auto dims_unique = [](absl::Span<const int64> contracting_dims,
588 absl::Span<const int64> batch_dims) -> bool {
589 absl::flat_hash_set<int64> dim_set;
590 auto is_unique = [&dim_set](int64 i) -> bool {
591 return dim_set.insert(i).second;
592 };
593 return absl::c_all_of(contracting_dims, is_unique) &&
594 absl::c_all_of(batch_dims, is_unique);
595 };
596
597 if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
598 !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) {
599 return InvalidArgument("A dimension number is not unique in Dot: %s.",
600 dimension_numbers.DebugString());
601 }
602
603 return Status::OK();
604 }
605
606 } // namespace
607
InferDotOpShape(const Shape & lhs,const Shape & rhs,const DotDimensionNumbers & dimension_numbers)608 /* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
609 const Shape& lhs, const Shape& rhs,
610 const DotDimensionNumbers& dimension_numbers) {
611 TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot"));
612 TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
613
614 auto fail = [lhs, rhs](const string& addendum) -> Status {
615 string message =
616 StrFormat("Cannot infer shape for dot operation: %s <dot> %s.",
617 ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
618 if (!addendum.empty()) {
619 message += " " + addendum;
620 }
621 return InvalidArgument("%s", message);
622 };
623
624 // Check if both element types are the same.
625 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
626 return fail("Element types do not match.");
627 }
628
629 if ((lhs.rank() < 1) || (rhs.rank() < 1)) {
630 return fail("Dot only supports rank 1 or above.");
631 }
632
633 // Validate basic properties of dot dimension numbers.
634 TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers));
635
636 // Check that number of contracting dimensions match.
637 if (dimension_numbers.lhs_contracting_dimensions_size() !=
638 dimension_numbers.rhs_contracting_dimensions_size()) {
639 return fail(
640 "Must specify the same number of contracting dimensions for lhs and "
641 "rhs.");
642 }
643 // Check that contracting dimension sizes match.
644 for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
645 ++i) {
646 const int64 lhs_contracting_dimension =
647 dimension_numbers.lhs_contracting_dimensions(i);
648 const int64 rhs_contracting_dimension =
649 dimension_numbers.rhs_contracting_dimensions(i);
650 if (lhs.dimensions(lhs_contracting_dimension) !=
651 rhs.dimensions(rhs_contracting_dimension) ||
652 lhs.is_dynamic_dimension(lhs_contracting_dimension) !=
653 rhs.is_dynamic_dimension(rhs_contracting_dimension)) {
654 return fail("Contracting dimension sizes do not match.");
655 }
656 }
657
658 // Check that number of batch dimensions match.
659 if (dimension_numbers.lhs_batch_dimensions_size() !=
660 dimension_numbers.rhs_batch_dimensions_size()) {
661 return fail("Must the same number of batch dimensions for lhs and rhs.");
662 }
663
664 // Check that batch dimension numbers and sizes match.
665 for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
666 if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
667 rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i)) ||
668 lhs.is_dynamic_dimension(dimension_numbers.lhs_batch_dimensions(i)) !=
669 rhs.is_dynamic_dimension(
670 dimension_numbers.rhs_batch_dimensions(i))) {
671 return fail("Batch dimension sizes must match for lhs/rhs.");
672 }
673 }
674
675 // The ranks of lhs and rhs are decremented by 1 respectively due to the
676 // contraction, and added for the rank of the result. When an input tensor is
677 // a scalar, its contribution to the rank of the result is 0.
678 // Generate the result dimensions in order, rhs dimensions followed by lhs
679 // dimensions except the contracted and batch dimensions.
680 std::vector<int64> dimensions;
681 std::vector<bool> is_dynamic;
682 for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
683 dimensions.push_back(lhs.dimensions(lhs_dim));
684 is_dynamic.push_back(lhs.is_dynamic_dimension(lhs_dim));
685 }
686 for (int64 i = 0; i < lhs.rank(); i++) {
687 if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
688 i) &&
689 !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
690 dimensions.push_back(lhs.dimensions(i));
691 is_dynamic.push_back(lhs.is_dynamic_dimension(i));
692 }
693 }
694 for (int64 i = 0; i < rhs.rank(); i++) {
695 if (!absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
696 i) &&
697 !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
698 dimensions.push_back(rhs.dimensions(i));
699 is_dynamic.push_back(rhs.is_dynamic_dimension(i));
700 }
701 }
702 Shape result = ShapeUtil::MakeShape(
703 ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions, is_dynamic);
704
705 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
706 VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
707 return result;
708 }
709
710 /* static */ StatusOr<Shape>
InferDegenerateDimensionBroadcastShape(HloOpcode operation,const Shape & lhs,const Shape & rhs)711 ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
712 const Shape& lhs,
713 const Shape& rhs) {
714 TF_RET_CHECK(lhs.rank() == rhs.rank());
715
716 // The shapes have to be compatible. That is, if some dimension d has a
717 // different size in the two shapes, one of them has to be 1 (a "degenerate"
718 // dimension). In that case, the output shape has the non-1 dimension size
719 // from the lhs/rhs pair in every index.
720 std::vector<int64> output_dimensions(lhs.rank());
721 std::vector<bool> output_dimensions_is_dynamic(lhs.rank());
722 for (int64 i = 0; i < lhs.rank(); ++i) {
723 if (lhs.dimensions(i) == rhs.dimensions(i)) {
724 output_dimensions[i] = lhs.dimensions(i);
725 output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i);
726 } else if (lhs.dimensions(i) == 1) {
727 output_dimensions[i] = rhs.dimensions(i);
728 output_dimensions_is_dynamic[i] = rhs.is_dynamic_dimension(i);
729 } else if (rhs.dimensions(i) == 1) {
730 output_dimensions[i] = lhs.dimensions(i);
731 output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i);
732 } else {
733 return InvalidArgument(
734 "Binary op %s with incompatible shapes: %s and %s.",
735 HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
736 ShapeUtil::HumanString(rhs));
737 }
738 }
739 return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
740 output_dimensions, output_dimensions_is_dynamic);
741 }
742
InferInDimBroadcastShape(const Shape & smaller_shape,const Shape & larger_shape,absl::Span<const int64> broadcast_dimensions)743 /* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
744 const Shape& smaller_shape, const Shape& larger_shape,
745 absl::Span<const int64> broadcast_dimensions) {
746 if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
747 // Reject "magic" inference for binops on different shapes, requiring
748 // the user to provide an explicit broadcast dimension in this case.
749 // See b/25177275 for more details.
750 return InvalidArgument("Automatic shape inference not supported: %s and %s",
751 ShapeUtil::HumanString(smaller_shape),
752 ShapeUtil::HumanString(larger_shape));
753 } else if (broadcast_dimensions.size() != smaller_shape.rank()) {
754 return InvalidArgument(
755 "Size of broadcast_dimensions has to match lower-rank operand's "
756 "rank; "
757 " lower-rank operand's rank is %d, size of broadcast_dimensions is "
758 "%u.",
759 smaller_shape.rank(), broadcast_dimensions.size());
760 }
761
762 // broadcast_dimensions is a sequence of dimensions; its length is equal to
763 // the rank of the lower-rank operand. The lower-rank operand's dimensions
764 // have to be compatible with the higher-rank operand's dimensions at indices
765 // specified by broadcast_dimensions. Here compatible means the dimension
766 // sizes are equal or in one of the shapes the dimension size is
767 // one. Examples:
768 //
769 // smaller_shape larger_shape broadcast_dimensions output_shape
770 // [] [2, 3] {} [2, 3]
771 // [3] [4, 3] {1} [4, 3]
772 // [2, 3] [2, 3, 4] {0, 1} [2, 3, 4]
773 // [2, 1] [2, 3, 4] {0, 2} [2, 3, 1]
774 // [2, 3] [2, 1, 4] {0, 1} [2, 3, 4]
775 //
776 // The column output_shape may not be the final shape of the XLA
777 // operation. After the "InDim" broadcasting implemented in this function
778 // expands the rank, degenerate-dimension broadcasting (implemented in
779 // InferDegenerateDimensionBroadcastShape) broadcasts dimensions of size one
780 // up to match the dimension size of the other operand. For example, consider
781 // the row in the table above with a smaller_shape of [2, 1]. The shape
782 // returned by this function is [2, 3, 1] (output_shape) however, the result
783 // shape of the XLA operation is [2, 3, 4] after degenerate-dimension
784 // broadcasting.
785 //
786 // Invalid broadcasts:
787 //
788 // smaller_shape=[3], larger_shape=[4, 3], broadcast_dimensions={0}
789 // Reason: Dimension zero** of larger_shape (size 4) is not compatible with
790 // dimension zero of smaller_shape(size 3). **Zero here comes from the value
791 // in broadcast_dimensions.
792 //
793 // smaller_shape=[2, 1], larger_shape=[2, 3, 4], broadcast_dimensions={1, 2}
794 // Reason: Dimension one of larger_shape (size 3) is not compatible with
795 // dimension zero of smaller_shape(size 2)
796
797 // The output shape is initially the larger_shape. Sizes of dimensions
798 // specified in broadcast_dimensions are then changed to match the
799 // corresponding dimension size in smaller_shape.
800 Shape output_shape(larger_shape);
801 output_shape.set_element_type(
802 ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape));
803
804 for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
805 int64 dimension_to_match = broadcast_dimensions.at(i);
806 if (dimension_to_match < 0) {
807 return InvalidArgument(
808 "Broadcast dimension number (%d) cannot be negative.",
809 dimension_to_match);
810 }
811 if (dimension_to_match >= larger_shape.dimensions_size()) {
812 return InvalidArgument(
813 "Broadcast dimension number (%d) too large; higher-rank "
814 "operand has rank %d.",
815 dimension_to_match, larger_shape.dimensions_size());
816 }
817 int64 small_dimension_size = smaller_shape.dimensions(i);
818 int64 large_dimension_size = larger_shape.dimensions(dimension_to_match);
819 bool small_is_dynamic = smaller_shape.is_dynamic_dimension(i);
820 bool large_is_dynamic =
821 larger_shape.is_dynamic_dimension(dimension_to_match);
822 // Dimension sizes must be compatible: match or be degenerate (degenerate
823 // case is handled by degenerate dimension broadcasting which occurs after
824 // InDim broadcasting).
825 if (small_dimension_size != large_dimension_size &&
826 small_dimension_size != 1 && large_dimension_size != 1) {
827 return InvalidArgument(
828 "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i,
829 small_dimension_size, large_dimension_size,
830 ShapeUtil::HumanString(smaller_shape),
831 ShapeUtil::HumanString(larger_shape));
832 }
833 if (small_is_dynamic != large_is_dynamic) {
834 if (small_dimension_size == large_dimension_size ||
835 (small_dimension_size == 1 && !small_is_dynamic) ||
836 (large_dimension_size == 1 && !large_is_dynamic)) {
837 // Do nothing. It's OK when the size-1 dimension is not static.
838 } else {
839 return InvalidArgument(
840 "Broadcast dimension %d dynamism mismatch: %s and %s.", i,
841 ShapeUtil::HumanString(smaller_shape),
842 ShapeUtil::HumanString(larger_shape));
843 }
844 }
845 // Make sure the broadcast dimensions are listed in a strictly increasing
846 // order.
847 if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) {
848 return InvalidArgument(
849 "Broadcast dimensions order is wrong: %d comes after %d.",
850 dimension_to_match, broadcast_dimensions.at(i - 1));
851 }
852
853 output_shape.set_dimensions(dimension_to_match, small_dimension_size);
854 output_shape.set_dynamic_dimension(dimension_to_match, small_is_dynamic);
855 }
856
857 return output_shape;
858 }
859
InferElementwiseBinaryOpShape(HloOpcode operation,const Shape & lhs,const Shape & rhs,absl::Span<const int64> broadcast_dimensions)860 /* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
861 HloOpcode operation, const Shape& lhs, const Shape& rhs,
862 absl::Span<const int64> broadcast_dimensions) {
863 TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
864 TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
865
866 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
867 return InvalidArgument(
868 "Binary op %s with different element types: %s and %s.",
869 HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
870 ShapeUtil::HumanString(rhs));
871 }
872
873 if (lhs.rank() == rhs.rank()) {
874 std::vector<int64> identity_dims(lhs.rank());
875 std::iota(identity_dims.begin(), identity_dims.end(), 0);
876 if (!broadcast_dimensions.empty() &&
877 broadcast_dimensions != identity_dims) {
878 return InvalidArgument(
879 "Broadcast dimensions field must either be not set or be the "
880 "identity on binary operations with operands of the same rank.");
881 }
882 }
883
884 if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
885 // If the shapes are the same other than layout, the output shape is the
886 // same (elementwise op).
887 return ShapeUtil::ChangeElementType(
888 lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
889 }
890
891 if (lhs.rank() == rhs.rank()) {
892 return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
893 } else {
894 // Ranks do not match, so perform InDim broadcasting using
895 // broadcast_dimensions. Scalar broadcasting is a special case of this.
896 const Shape& larger_shape = lhs.rank() > rhs.rank() ? lhs : rhs;
897 const Shape& smaller_shape = lhs.rank() > rhs.rank() ? rhs : lhs;
898
899 // After InDim broadcasting, perform degenerate dimensions broadcasting.
900 TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape,
901 InferInDimBroadcastShape(smaller_shape, larger_shape,
902 broadcast_dimensions));
903
904 return InferDegenerateDimensionBroadcastShape(
905 operation, indim_broadcast_shape, larger_shape);
906 }
907 }
908
InferBinaryOpShape(HloOpcode opcode,const HloInstruction * lhs,const HloInstruction * rhs)909 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
910 HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) {
911 return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(),
912 /*broadcast_dimensions=*/{});
913 }
914
InferBinaryOpShape(HloOpcode opcode,const Shape & lhs,const Shape & rhs,absl::Span<const int64> broadcast_dimensions)915 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
916 HloOpcode opcode, const Shape& lhs, const Shape& rhs,
917 absl::Span<const int64> broadcast_dimensions) {
918 VLOG(2) << StrFormat(
919 "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
920 HloOpcodeString(opcode), ShapeUtil::HumanString(lhs),
921 ShapeUtil::HumanString(rhs), StrJoin(broadcast_dimensions, ", "));
922 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
923 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
924
925 TF_RETURN_IF_ERROR(ExpectArray(
926 lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode))));
927 TF_RETURN_IF_ERROR(ExpectArray(
928 rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode))));
929 switch (opcode) {
930 case HloOpcode::kMaximum:
931 case HloOpcode::kMinimum:
932 return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
933 broadcast_dimensions);
934
935 case HloOpcode::kSubtract:
936 case HloOpcode::kAdd:
937 case HloOpcode::kAtan2:
938 case HloOpcode::kPower:
939 case HloOpcode::kDivide:
940 case HloOpcode::kRemainder:
941 case HloOpcode::kMultiply:
942 case HloOpcode::kShiftLeft:
943 case HloOpcode::kShiftRightArithmetic:
944 case HloOpcode::kShiftRightLogical:
945 if (lhs.element_type() == PRED || rhs.element_type() == PRED) {
946 return InvalidArgument(
947 "Expected element type in shape to be arithmetic type for "
948 "operation %s; got PRED.",
949 HloOpcodeString(opcode));
950 }
951 return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
952 broadcast_dimensions);
953
954 case HloOpcode::kComplex: {
955 if (!ShapeUtil::ElementIsFloating(lhs)) {
956 return InvalidArgument(
957 "Expected element type in shape to be floating for complex compose "
958 "operation; got %s.",
959 PrimitiveType_Name(lhs.element_type()));
960 }
961 TF_ASSIGN_OR_RETURN(const Shape& shape,
962 InferElementwiseBinaryOpShape(opcode, lhs, rhs,
963 broadcast_dimensions));
964 if (lhs.element_type() == F32 && rhs.element_type() == F32) {
965 return ShapeUtil::ChangeElementType(shape, C64);
966 } else if (lhs.element_type() == F64 && rhs.element_type() == F64) {
967 return ShapeUtil::ChangeElementType(shape, C128);
968 } else {
969 return Unimplemented("Complex component type is not implemented.");
970 }
971 }
972 case HloOpcode::kAnd:
973 case HloOpcode::kOr:
974 case HloOpcode::kXor:
975 if (lhs.element_type() != PRED &&
976 !primitive_util::IsIntegralType(lhs.element_type())) {
977 return InvalidArgument(
978 "Expected pred or integral type in argument to and/or operation; "
979 "got %s.",
980 PrimitiveType_Name(lhs.element_type()));
981 }
982 return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
983 broadcast_dimensions);
984 case HloOpcode::kCompare: {
985 TF_ASSIGN_OR_RETURN(const Shape& shape,
986 InferElementwiseBinaryOpShape(opcode, lhs, rhs,
987 broadcast_dimensions));
988 return ShapeUtil::ChangeElementType(shape, PRED);
989 }
990 default:
991 return Unimplemented(
992 "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.",
993 HloOpcodeString(opcode), lhs.ShortDebugString(),
994 rhs.ShortDebugString());
995 }
996 }
997
InferTernaryOpShape(HloOpcode opcode,const HloInstruction * lhs,const HloInstruction * rhs,const HloInstruction * ehs)998 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
999 HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs,
1000 const HloInstruction* ehs) {
1001 return InferTernaryOpShape(opcode, lhs->shape(), rhs->shape(), ehs->shape());
1002 }
1003
InferTernaryOpShape(HloOpcode opcode,const Shape & lhs,const Shape & rhs,const Shape & ehs)1004 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1005 HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) {
1006 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1007 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1008 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs));
1009 switch (opcode) {
1010 case HloOpcode::kClamp:
1011 return InferClampShape(lhs, rhs, ehs);
1012 case HloOpcode::kSelect:
1013 return InferSelectShape(lhs, rhs, ehs);
1014 case HloOpcode::kTupleSelect:
1015 return InferTupleSelectShape(lhs, rhs, ehs);
1016 default:
1017 return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1018 }
1019 }
1020
InferVariadicOpShape(HloOpcode opcode,absl::Span<const HloInstruction * const> operands)1021 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1022 HloOpcode opcode, absl::Span<const HloInstruction* const> operands) {
1023 std::vector<const Shape*> operand_shapes;
1024 operand_shapes.reserve(operands.size());
1025 for (const HloInstruction* operand : operands) {
1026 operand_shapes.push_back(&operand->shape());
1027 }
1028 return InferVariadicOpShape(opcode, operand_shapes);
1029 }
1030
InferVariadicOpShape(HloOpcode opcode,absl::Span<const Shape * const> operand_shapes)1031 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1032 HloOpcode opcode, absl::Span<const Shape* const> operand_shapes) {
1033 for (const Shape* shape : operand_shapes) {
1034 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
1035 }
1036 switch (opcode) {
1037 case HloOpcode::kTuple: {
1038 Shape result = ShapeUtil::MakeTupleShape({});
1039 result.mutable_tuple_shapes()->reserve(operand_shapes.size());
1040 for (const Shape* shape : operand_shapes) {
1041 ShapeUtil::AppendShapeToTuple(*shape, &result);
1042 }
1043 return result;
1044 }
1045 case HloOpcode::kSort: {
1046 if (operand_shapes.size() == 1) {
1047 return *operand_shapes[0];
1048 } else {
1049 for (int64 operand = 1; operand < operand_shapes.size(); ++operand) {
1050 if (!ShapeUtil::SameDimensions(*operand_shapes[0],
1051 *operand_shapes[operand])) {
1052 return InvalidArgument(
1053 "Sort keys and values dimensions must match. "
1054 "Keys shape is: %s\n, Values shape (operand index %lld) is: %s",
1055 ShapeUtil::HumanString(*operand_shapes[0]), operand,
1056 ShapeUtil::HumanString(*operand_shapes[operand]));
1057 }
1058 }
1059 std::vector<Shape> operand_shape_values;
1060 for (const Shape* operand_shape : operand_shapes) {
1061 operand_shape_values.push_back(*operand_shape);
1062 }
1063 return ShapeUtil::MakeTupleShape(operand_shape_values);
1064 }
1065 return InvalidArgument("Unexpected number of operands for sort");
1066 }
1067 default:
1068 return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1069 }
1070 }
1071
InferMapShape(absl::Span<const Shape * const> arg_shapes,const ProgramShape & to_apply,absl::Span<const int64> dimensions)1072 /* static */ StatusOr<Shape> ShapeInference::InferMapShape(
1073 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
1074 absl::Span<const int64> dimensions) {
1075 if (arg_shapes.empty()) {
1076 return InvalidArgument("Map expects at least one argument.");
1077 }
1078
1079 // All arguments must have the same shape.
1080 const Shape* arg_shape = arg_shapes[0];
1081 for (size_t i = 1; i < arg_shapes.size(); ++i) {
1082 TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map"));
1083
1084 if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
1085 continue;
1086 }
1087 if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
1088 *arg_shape)) {
1089 if (ShapeUtil::IsScalar(*arg_shapes[i])) {
1090 continue;
1091 }
1092 if (ShapeUtil::IsScalar(*arg_shape)) {
1093 arg_shape = arg_shapes[i];
1094 continue;
1095 }
1096 }
1097
1098 std::vector<string> pieces;
1099 for (const Shape* shape : arg_shapes) {
1100 pieces.push_back(ShapeUtil::HumanString(*shape));
1101 }
1102 return InvalidArgument(
1103 "Map operation requires all operands to have the same shape; got: "
1104 "%s.",
1105 StrJoin(pieces, ", "));
1106 }
1107
1108 // Check that dimensions.size == arg_shape.dimensions_size() (we currently
1109 // only support mapping across all dimensions: i.e. scalar map functions).
1110 if (dimensions.size() != arg_shape->dimensions_size()) {
1111 return InvalidArgument(
1112 "Map applied to a subset of dimensions currently not supported: "
1113 "arg_dimension_size: %d, requested_map_dimensions_size: %u.",
1114 arg_shape->dimensions_size(), dimensions.size());
1115 }
1116
1117 // Check that requested map dimensions numbers are monotonically increasing.
1118 for (int i = 0; i < dimensions.size(); ++i) {
1119 if (dimensions[i] != i) {
1120 return InvalidArgument(
1121 "Map requires monotonically increasing dimension numbers; got: %s.",
1122 StrJoin(dimensions, ", "));
1123 }
1124 }
1125
1126 // The applied function's arity equals the number of arguments.
1127 if (arg_shapes.size() != to_apply.parameters_size()) {
1128 return InvalidArgument(
1129 "Map applied function arity must match number of arguments; got: "
1130 "arity: %d, arguments: %u.",
1131 to_apply.parameters_size(), arg_shapes.size());
1132 }
1133
1134 // The parameters should all be scalars, and the output too.
1135 const Shape& output_shape = to_apply.result();
1136 if (!ShapeUtil::IsScalar(output_shape)) {
1137 return InvalidArgument(
1138 "Mapped computation's result has to be a scalar; got: %s.",
1139 ShapeUtil::HumanString(output_shape));
1140 }
1141
1142 for (int i = 0; i < to_apply.parameters_size(); ++i) {
1143 const Shape& parameter_shape = to_apply.parameters(i);
1144
1145 if (!ShapeUtil::IsScalar(parameter_shape)) {
1146 return InvalidArgument(
1147 "Mapped computation's parameter has to be a scalar; "
1148 "got parameter %d shape: %s.",
1149 i, ShapeUtil::HumanString(parameter_shape));
1150 }
1151
1152 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
1153 *arg_shape)) {
1154 return InvalidArgument(
1155 "Mapped computation's parameter type has to match argument element "
1156 "type; got parameter %d shape: %s, argument shape: %s.",
1157 i, ShapeUtil::HumanString(parameter_shape),
1158 ShapeUtil::HumanString(*arg_shape));
1159 }
1160 }
1161
1162 return ShapeUtil::MakeShape(output_shape.element_type(),
1163 AsInt64Slice(arg_shape->dimensions()));
1164 }
1165
InferBatchNormTrainingShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & offset_shape,int64 feature_index)1166 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormTrainingShape(
1167 const Shape& operand_shape, const Shape& scale_shape,
1168 const Shape& offset_shape, int64 feature_index) {
1169 TF_RETURN_IF_ERROR(
1170 ExpectArray(operand_shape, "operand of batch norm training"));
1171 TF_RETURN_IF_ERROR(
1172 ExpectArray(offset_shape, "offset input of batch norm training"));
1173 TF_RETURN_IF_ERROR(
1174 ExpectArray(scale_shape, "scale input of batch norm training"));
1175
1176 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
1177 Status::OK());
1178 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) ==
1179 Status::OK());
1180 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
1181 Status::OK());
1182
1183 if (feature_index >= operand_shape.rank()) {
1184 return InvalidArgument(
1185 "Expected feature_index of batch-norm-training to be "
1186 "smaller than the rank of operand_shape; "
1187 "got feature_index %d, and rank %d.",
1188 feature_index, operand_shape.rank());
1189 }
1190
1191 if (feature_index < 0) {
1192 return InvalidArgument(
1193 "Expected feature_index of batch-norm-training to "
1194 "be a non-negative number, got %d.",
1195 feature_index);
1196 }
1197
1198 if (operand_shape.rank() < 1) {
1199 return InvalidArgument(
1200 "Expected the rank of operand to "
1201 "batch-norm-training to be at least 1; got %d.",
1202 operand_shape.rank());
1203 }
1204
1205 if (offset_shape.rank() != 1) {
1206 return InvalidArgument(
1207 "Offset input of batch-norm-training must have"
1208 " rank 1, but has rank %d.",
1209 offset_shape.rank());
1210 }
1211
1212 if (scale_shape.rank() != 1) {
1213 return InvalidArgument(
1214 "Scale input of batch-norm-training must have"
1215 " rank 1, but has rank %d.",
1216 scale_shape.rank());
1217 }
1218
1219 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1220 return InvalidArgument(
1221 "The operand to batch-norm-training must have a floating point "
1222 "element type, but the shape is %s.",
1223 PrimitiveType_Name(operand_shape.element_type()));
1224 }
1225
1226 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1227 operand_shape)) {
1228 return InvalidArgument(
1229 "The inputs should have the same element type for batch-norm-training, "
1230 "but the shape of offset factor is %s "
1231 "and the shape of operand is %s.",
1232 PrimitiveType_Name(offset_shape.element_type()),
1233 PrimitiveType_Name(operand_shape.element_type()));
1234 }
1235
1236 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1237 operand_shape)) {
1238 return InvalidArgument(
1239 "The inputs should have the same element type for batch-norm-training, "
1240 "but the shape of scale factor is %s "
1241 "and the shape of operand is %s.",
1242 PrimitiveType_Name(scale_shape.element_type()),
1243 PrimitiveType_Name(operand_shape.element_type()));
1244 }
1245
1246 const int64 feature_count = operand_shape.dimensions(feature_index);
1247 Shape output_shape_for_mean_and_var =
1248 ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1249
1250 if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1251 return InvalidArgument(
1252 "The size of offset factor should be the same as feature count,"
1253 "but the size of offset factor is %d "
1254 "and the feature count is %d.",
1255 ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1256 }
1257
1258 if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1259 return InvalidArgument(
1260 "The size of scale factor should be the same as feature count,"
1261 "but the size of scale factor is %d "
1262 "and the feature count is %d.",
1263 ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1264 }
1265
1266 return ShapeUtil::MakeTupleShape({operand_shape,
1267 output_shape_for_mean_and_var,
1268 output_shape_for_mean_and_var});
1269 }
1270
InferBatchNormInferenceShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & offset_shape,const Shape & mean_shape,const Shape & variance_shape,int64 feature_index)1271 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
1272 const Shape& operand_shape, const Shape& scale_shape,
1273 const Shape& offset_shape, const Shape& mean_shape,
1274 const Shape& variance_shape, int64 feature_index) {
1275 TF_RETURN_IF_ERROR(
1276 ExpectArray(operand_shape, "operand of batch norm inference"));
1277 TF_RETURN_IF_ERROR(
1278 ExpectArray(offset_shape, "offset input of batch norm inference"));
1279 TF_RETURN_IF_ERROR(
1280 ExpectArray(scale_shape, "scale input of batch norm inference"));
1281
1282 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1283 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape));
1284 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1285 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1286 TF_RETURN_IF_ERROR(
1287 ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape));
1288
1289 if (feature_index >= operand_shape.rank()) {
1290 return InvalidArgument(
1291 "Expected feature_index of batch-norm-inference to be "
1292 "smaller than the rank of operand_shape; "
1293 "got feature_index %d, and rank %d.",
1294 feature_index, operand_shape.rank());
1295 }
1296
1297 if (feature_index < 0) {
1298 return InvalidArgument(
1299 "Expected feature_index of batch-norm-inference to "
1300 "be a non-negative number, got %d.",
1301 feature_index);
1302 }
1303
1304 if (operand_shape.rank() < 1) {
1305 return InvalidArgument(
1306 "Expected the rank of operand to "
1307 "batch-norm-inference to be at least 1; got %d.",
1308 operand_shape.rank());
1309 }
1310
1311 if (offset_shape.rank() != 1) {
1312 return InvalidArgument(
1313 "Offset input of batch-norm-inference must have"
1314 " rank 1, but has rank %d.",
1315 offset_shape.rank());
1316 }
1317
1318 if (scale_shape.rank() != 1) {
1319 return InvalidArgument(
1320 "Scale input of batch-norm-inference must have"
1321 " rank 1, but has rank %d.",
1322 scale_shape.rank());
1323 }
1324
1325 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1326 return InvalidArgument(
1327 "The operand to batch-norm-inference must have a floating point "
1328 "element type, but the shape is %s.",
1329 PrimitiveType_Name(operand_shape.element_type()));
1330 }
1331
1332 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1333 operand_shape)) {
1334 return InvalidArgument(
1335 "The inputs should have the same element type for "
1336 "batch-norm-inference, "
1337 "but the shape of offset factor is %s "
1338 "and the shape of operand is %s.",
1339 PrimitiveType_Name(offset_shape.element_type()),
1340 PrimitiveType_Name(operand_shape.element_type()));
1341 }
1342
1343 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1344 operand_shape)) {
1345 return InvalidArgument(
1346 "The inputs should have the same element type for "
1347 "batch-norm-inference, "
1348 "but the shape of scale factor is %s "
1349 "and the shape of operand is %s.",
1350 PrimitiveType_Name(scale_shape.element_type()),
1351 PrimitiveType_Name(operand_shape.element_type()));
1352 }
1353
1354 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1355 operand_shape)) {
1356 return InvalidArgument(
1357 "The inputs should have the same element type for "
1358 "batch-norm-inference, "
1359 "but the shape of mean is %s "
1360 "and the shape of operand is %s.",
1361 PrimitiveType_Name(mean_shape.element_type()),
1362 PrimitiveType_Name(operand_shape.element_type()));
1363 }
1364
1365 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
1366 operand_shape)) {
1367 return InvalidArgument(
1368 "The inputs should have the same element type for "
1369 "batch-norm-inference, "
1370 "but the shape of variance is %s "
1371 "and the shape of operand is %s.",
1372 PrimitiveType_Name(mean_shape.element_type()),
1373 PrimitiveType_Name(variance_shape.element_type()));
1374 }
1375
1376 const int64 feature_count = operand_shape.dimensions(feature_index);
1377 Shape output_shape_for_mean_and_var =
1378 ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1379
1380 if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1381 return InvalidArgument(
1382 "The size of offset factor should be the same as feature count,"
1383 "but the size of offset factor is %d "
1384 "and the feature count is %d.",
1385 ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1386 }
1387
1388 if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1389 return InvalidArgument(
1390 "The size of scale factor should be the same as feature count,"
1391 "but the size of scale factor is %d "
1392 "and the feature count is %d.",
1393 ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1394 }
1395
1396 if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1397 return InvalidArgument(
1398 "The size of mean should be the same as feature count,"
1399 "but the size of mean is %d "
1400 "and the feature count is %d.",
1401 ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1402 }
1403
1404 if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
1405 return InvalidArgument(
1406 "The size of variance should be the same as feature count,"
1407 "but the size of variance is %d "
1408 "and the feature count is %d.",
1409 ShapeUtil::GetDimension(variance_shape, 0), feature_count);
1410 }
1411
1412 return operand_shape;
1413 }
1414
InferBatchNormGradShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & mean_shape,const Shape & var_shape,const Shape & output_grad_shape,int64 feature_index)1415 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
1416 const Shape& operand_shape, const Shape& scale_shape,
1417 const Shape& mean_shape, const Shape& var_shape,
1418 const Shape& output_grad_shape, int64 feature_index) {
1419 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad"));
1420 TF_RETURN_IF_ERROR(
1421 ExpectArray(scale_shape, "scale input of batch norm grad"));
1422 TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad"));
1423 TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad"));
1424 TF_RETURN_IF_ERROR(
1425 ExpectArray(output_grad_shape, "output_grad input of batch norm grad"));
1426
1427 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1428 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1429 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1430 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(var_shape));
1431 TF_RETURN_IF_ERROR(
1432 ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape));
1433
1434 if (feature_index >= operand_shape.rank()) {
1435 return InvalidArgument(
1436 "Expected feature_index of batch-norm-grad to be "
1437 "smaller than the rank of operand_shape; "
1438 "got feature_index %d, and rank %d.",
1439 feature_index, operand_shape.rank());
1440 }
1441
1442 if (operand_shape.rank() != output_grad_shape.rank()) {
1443 return InvalidArgument(
1444 "Expected operand_shape of batch-norm-grad to have the same rank as"
1445 " output_grad_shape; got rank(oprand_shape) %d, and"
1446 " rank(output_grad_shape) %d.",
1447 operand_shape.rank(), output_grad_shape.rank());
1448 }
1449
1450 if (mean_shape.rank() != 1) {
1451 return InvalidArgument(
1452 "Mean input of batch-norm-grad must have"
1453 " rank 1, but has rank %d.",
1454 mean_shape.rank());
1455 }
1456
1457 if (scale_shape.rank() != 1) {
1458 return InvalidArgument(
1459 "Scale input of batch-norm-grad must have"
1460 " rank 1, but has rank %d.",
1461 scale_shape.rank());
1462 }
1463
1464 if (var_shape.rank() != 1) {
1465 return InvalidArgument(
1466 "Var input of batch-norm-grad must have"
1467 " rank 1, but has rank %d.",
1468 var_shape.rank());
1469 }
1470
1471 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1472 return InvalidArgument(
1473 "The operand to batch-norm-grad must have a floating point "
1474 "element type, but the shape is %s.",
1475 PrimitiveType_Name(operand_shape.element_type()));
1476 }
1477
1478 if (!ShapeUtil::ElementIsFloating(output_grad_shape)) {
1479 return InvalidArgument(
1480 "The output_grad to batch-norm-grad must have a floating point "
1481 "element type, but the shape is %s.",
1482 PrimitiveType_Name(output_grad_shape.element_type()));
1483 }
1484
1485 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
1486 operand_shape)) {
1487 return InvalidArgument(
1488 "The inputs should have the same element type for batch-norm-grad, "
1489 "but the element type of output_grad is %s "
1490 "and the element type of operand is %s.",
1491 PrimitiveType_Name(output_grad_shape.element_type()),
1492 PrimitiveType_Name(operand_shape.element_type()));
1493 }
1494
1495 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1496 operand_shape)) {
1497 return InvalidArgument(
1498 "The inputs should have the same element type for batch-norm-grad, "
1499 "but the element type of scale factor is %s "
1500 "and the element type of operand is %s.",
1501 PrimitiveType_Name(scale_shape.element_type()),
1502 PrimitiveType_Name(operand_shape.element_type()));
1503 }
1504
1505 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1506 operand_shape)) {
1507 return InvalidArgument(
1508 "The inputs should have the same element type for batch-norm-grad, "
1509 "but the element type of mean is %s "
1510 "and the element type of operand is %s.",
1511 PrimitiveType_Name(mean_shape.element_type()),
1512 PrimitiveType_Name(operand_shape.element_type()));
1513 }
1514
1515 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
1516 operand_shape)) {
1517 return InvalidArgument(
1518 "The inputs should have the same element type for batch-norm-grad, "
1519 "but the element type of mean is %s "
1520 "and the element type of operand is %s.",
1521 PrimitiveType_Name(mean_shape.element_type()),
1522 PrimitiveType_Name(operand_shape.element_type()));
1523 }
1524
1525 const int64 feature_count = operand_shape.dimensions(feature_index);
1526
1527 Shape feature_shape =
1528 ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1529
1530 if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1531 return InvalidArgument(
1532 "The size of mean should be the same as feature count,"
1533 "but the size of offset factor is %d "
1534 "and the feature count is %d.",
1535 ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1536 }
1537
1538 if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1539 return InvalidArgument(
1540 "The size of scale factor should be the same as feature count,"
1541 "but the size of scale factor is %d "
1542 "and the feature count is %d.",
1543 ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1544 }
1545
1546 if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) {
1547 return InvalidArgument(
1548 "The size of variance should be the same as feature count,"
1549 "but the size of variance is %d "
1550 "and the feature count is %d.",
1551 ShapeUtil::GetDimension(var_shape, 0), feature_count);
1552 }
1553
1554 // Verify operand_shape and output_grad_shape have same bounds.
1555 for (int64 i = 0; i < operand_shape.rank(); ++i) {
1556 if (ShapeUtil::GetDimension(operand_shape, i) !=
1557 ShapeUtil::GetDimension(output_grad_shape, i)) {
1558 return InvalidArgument(
1559 "The bounds of operand shape should be the same as output_grad's,"
1560 "but the bound of operand_shape at dimension %d is %d "
1561 "and the bound of output_grad_shape is %d.",
1562 i, ShapeUtil::GetDimension(operand_shape, i),
1563 ShapeUtil::GetDimension(output_grad_shape, i));
1564 }
1565 }
1566
1567 return ShapeUtil::MakeTupleShape(
1568 {operand_shape, feature_shape, feature_shape});
1569 }
1570
// Infers the result shape of a convolution of `lhs` (the input) with `rhs`
// (the kernel).
//
// Arguments:
//   lhs, rhs:            operand shapes; both must be arrays.
//   feature_group_count: must be positive.
//   batch_group_count:   must be positive; it and feature_group_count may not
//                        both be greater than 1.
//   window:              one window dimension per spatial dimension.
//   dnums:               maps the batch / feature / spatial roles onto the
//                        dimensions of the input, the kernel and the output.
//
// Returns InvalidArgument as soon as one of the checks below fails; the order
// of the checks therefore determines which error a caller sees. On success the
// result has rank `num_spatial_dims + 2` and the element type chosen by
// ShapeUtil::HigherPrecisionElementType(lhs, rhs). Its batch dimension is
// `input_batch / batch_group_count`, its feature dimension is the kernel's
// output feature count, and its spatial dimensions come from
// InferWindowOutputShape applied to the input's spatial dimensions.
/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
    const Shape& lhs, const Shape& rhs, int64 feature_group_count,
    int64 batch_group_count, const Window& window,
    const ConvolutionDimensionNumbers& dnums) {
  TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
  TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));

  if (feature_group_count <= 0) {
    return InvalidArgument(
        "feature_group_count must be a positive number, got %d",
        feature_group_count);
  }

  if (batch_group_count <= 0) {
    return InvalidArgument(
        "batch_group_count must be a positive number, got %d",
        batch_group_count);
  }

  // Only one of the two grouping modes may be active at a time.
  if (batch_group_count > 1 && feature_group_count > 1) {
    return InvalidArgument(
        "both batch_group_count %d and feature_group_count %d cannot be "
        "greater than 1",
        batch_group_count, feature_group_count);
  }

  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
    return InvalidArgument(
        "Convolution with different element types: %s and %s.",
        ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
  }
  // The input, kernel and output must all declare the same number of spatial
  // dimensions. These checks also make the std::copy calls further down safe,
  // since each destination vector has exactly `num_spatial_dims + 2` slots.
  if (dnums.input_spatial_dimensions_size() !=
      dnums.kernel_spatial_dimensions_size()) {
    return InvalidArgument(
        "Both arguments to convolution must have same number of dimensions.\n"
        "Numbers: %s",
        dnums.DebugString());
  }

  if (dnums.input_spatial_dimensions_size() !=
      dnums.output_spatial_dimensions_size()) {
    return InvalidArgument(
        "Both input and output of convolution must have same number of "
        "dimensions.\nNumbers: %s",
        dnums.DebugString());
  }

  const int num_spatial_dims = dnums.input_spatial_dimensions_size();
  if (window.dimensions_size() != num_spatial_dims) {
    return InvalidArgument(
        "Window must have same number of dimensions as dimension numbers.\n"
        "Window: %s\nDimension numbers: %s.",
        window.DebugString(), dnums.DebugString());
  }

  // Two non-spatial dimensions (batch and feature for the input, input-feature
  // and output-feature for the kernel) are added to the spatial ones.
  const int num_dims = num_spatial_dims + 2;
  if (lhs.rank() != num_dims) {
    return InvalidArgument(
        "The LHS argument to a convolution should have rank %d; lhs: %s.",
        num_dims, ShapeUtil::HumanString(lhs));
  }
  if (rhs.rank() != num_dims) {
    return InvalidArgument(
        "The RHS argument to a convolution should have rank %d; rhs: %s.",
        num_dims, ShapeUtil::HumanString(rhs));
  }
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));

  // Verifies that the input and window dimensions are a permutation of
  // the dimension numbers.
  //
  // Each list of dimension numbers is gathered and sorted so that it can be
  // compared against the identity sequence 0..num_dims-1 below.
  std::vector<int64> input_dnums(num_dims);
  input_dnums[0] = dnums.input_batch_dimension();
  input_dnums[1] = dnums.input_feature_dimension();
  std::copy(dnums.input_spatial_dimensions().begin(),
            dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
  absl::c_sort(input_dnums);

  std::vector<int64> window_dnums(num_dims);
  window_dnums[0] = dnums.kernel_input_feature_dimension();
  window_dnums[1] = dnums.kernel_output_feature_dimension();
  std::copy(dnums.kernel_spatial_dimensions().begin(),
            dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
  absl::c_sort(window_dnums);

  std::vector<int64> output_dnums(num_dims);
  output_dnums[0] = dnums.output_batch_dimension();
  output_dnums[1] = dnums.output_feature_dimension();
  std::copy(dnums.output_spatial_dimensions().begin(),
            dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
  absl::c_sort(output_dnums);

  std::vector<int64> expected_dnums(num_dims);
  std::iota(expected_dnums.begin(), expected_dnums.end(), 0);

  // Out-of-range values are reported separately from duplicated / missing ones
  // so that the error message is more specific.
  const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
  if (!absl::c_all_of(input_dnums, in_range) ||
      !absl::c_all_of(window_dnums, in_range) ||
      !absl::c_all_of(output_dnums, in_range)) {
    return InvalidArgument(
        "A dimension number is out of range in convolution: %s.",
        dnums.DebugString());
  }

  if (input_dnums != expected_dnums) {
    return InvalidArgument(
        "Input dimensions of convolution must contain each dimension exactly "
        "once: %s.",
        dnums.DebugString());
  }
  if (window_dnums != expected_dnums) {
    return InvalidArgument(
        "Window dimensions of convolution must contain each dimension exactly "
        "once: %s.",
        dnums.DebugString());
  }
  if (output_dnums != expected_dnums) {
    return InvalidArgument(
        "Output dimensions of convolution must contain each dimension exactly "
        "once: %s.",
        dnums.DebugString());
  }

  // From here on every dimension number is known to be a valid index, so the
  // sizes of the individual lhs / rhs dimensions can be read.
  std::vector<int64> input_spatial_dims(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    input_spatial_dims[i] = lhs.dimensions(dnums.input_spatial_dimensions(i));
  }
  const int64 input_features = lhs.dimensions(dnums.input_feature_dimension());
  const int64 input_batch = lhs.dimensions(dnums.input_batch_dimension());

  std::vector<int64> kernel_spatial_dims(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    kernel_spatial_dims[i] = rhs.dimensions(dnums.kernel_spatial_dimensions(i));
  }
  const int64 kernel_input_features =
      rhs.dimensions(dnums.kernel_input_feature_dimension());
  const int64 kernel_output_features =
      rhs.dimensions(dnums.kernel_output_feature_dimension());

  // NOTE(review): kernel_output_features is a dimension size of rhs; if it can
  // be 0 when batch_group_count > 1, the modulo below divides by zero --
  // confirm that callers cannot reach this with an empty output feature
  // dimension.
  if (batch_group_count > 1 && input_batch % kernel_output_features != 0) {
    return InvalidArgument(
        "Expected input batch (value %d) to be divisible by output feature "
        "dimension size (value %d) for batch group count %d; "
        "got <conv>(%s, %s)\n"
        "Dimension numbers: {%s}.",
        input_batch, kernel_output_features, batch_group_count,
        ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
        dnums.DebugString());
  }

  // The input features are split evenly into feature_group_count groups, and
  // the kernel's input feature dimension must equal the size of one group.
  if (input_features % feature_group_count != 0 ||
      input_features / feature_group_count != kernel_input_features) {
    return InvalidArgument(
        "Expected LHS feature dimension (value %d) to be a multiple of "
        "feature_group_count (value %d), and LHS feature dimension / "
        "feature_group_count = RHS feature dimension (value %d); "
        "got <conv>(%s, %s)\n"
        "Dimension numbers: {%s}.",
        input_features, feature_group_count, kernel_input_features,
        ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
        dnums.DebugString());
  }

  if (kernel_output_features % feature_group_count > 0) {
    // A depthwise/grouped filter has the shape
    // [space0, .. spaceN, GROUP_SIZE, NUM_OUTPUT_FEATURES]. When
    // [space0, .. spaceN, GROUP_SIZE] is convolved with the input, a shape
    // [space0, .. spaceN, feature_group_count] is formed. Therefore, the output
    // feature count (which is equal to kernel output features) has to be a
    // multiple of feature_group_count.
    return InvalidArgument(
        "Expected output feature dimension (value %d) to be divisible by "
        "feature_group_count (value %d); "
        "got <conv>(%s, %s)\n"
        "Dimension numbers: {%s}.",
        kernel_output_features, feature_group_count,
        ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
        dnums.DebugString());
  }

  // The output batch size computed below is input_batch / batch_group_count,
  // so the division has to be exact.
  if (input_batch % batch_group_count > 0) {
    return InvalidArgument(
        "Expected input batch dimension (value %d) to be divisible by "
        "batch_group_count (value %d); "
        "got <conv>(%s, %s)\n"
        "Dimension numbers: {%s}.",
        input_batch, batch_group_count, ShapeUtil::HumanString(lhs),
        ShapeUtil::HumanString(rhs), dnums.DebugString());
  }

  // The window sizes must agree with the kernel's spatial extents.
  std::vector<int64> window_dims(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    window_dims[i] = window.dimensions(i).size();
  }
  if (kernel_spatial_dims != window_dims) {
    return InvalidArgument(
        "Window dimensions do not match RHS shape:\n\t"
        "RHS shape: %s\n\t"
        "Window: {%s}\n\t"
        "Dimension numbers: {%s}.",
        ShapeUtil::HumanString(rhs), window.ShortDebugString(),
        dnums.ShortDebugString());
  }

  // Compute the spatial output sizes by applying the window to a shape that
  // holds only the input's spatial dimensions. Negative padding is allowed
  // for convolutions.
  Shape base_shape =
      ShapeUtil::MakeShape(lhs.element_type(), input_spatial_dims);
  TF_ASSIGN_OR_RETURN(
      Shape window_output_shape,
      InferWindowOutputShape(base_shape, window, lhs.element_type(),
                             /*allow_negative_padding=*/true));

  // Scatter the batch, feature and spatial sizes into the positions that the
  // dimension numbers assign to them in the output.
  std::vector<int64> dimensions(num_dims);
  dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count;
  dimensions[dnums.output_feature_dimension()] = kernel_output_features;
  for (int i = 0; i < num_spatial_dims; ++i) {
    dimensions[dnums.output_spatial_dimensions(i)] =
        window_output_shape.dimensions(i);
  }
  // Propagate dynamic dimensions. Only a dynamic lhs batch dimension carries
  // over to the output; a dynamic lhs feature or rhs input-feature dimension
  // is ignored, and any other dynamic dimension is rejected.
  std::vector<bool> is_dynamic(num_dims);
  for (int i = 0; i < num_dims; i++) {
    if (lhs.is_dynamic_dimension(i)) {
      if (i == dnums.input_batch_dimension()) {
        is_dynamic[dnums.output_batch_dimension()] = true;
      } else if (i == dnums.input_feature_dimension()) {
        // Input feature dimension is a contracting dimension, which does not
        // affect the output dimension size. So we need to do nothing.
      } else {
        return InvalidArgument(
            "Dynamic Spatial Convolution is not supported: lhs shape is %s ",
            lhs.ToString());
      }
    }
    if (rhs.is_dynamic_dimension(i)) {
      if (i == dnums.kernel_input_feature_dimension()) {
        // Kernel feature dimension does not affect the output dimension size.
        // So we need to do nothing.
      } else {
        return InvalidArgument(
            "Dynamic Spatial Convolution is not supported: rhs shape is %s ",
            rhs.ToString());
      }
    }
  }
  return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
                              dimensions, is_dynamic);
}
1817
InferFftShape(const Shape & in,const FftType fft_type,const absl::Span<const int64> fft_length)1818 /* static */ StatusOr<Shape> ShapeInference::InferFftShape(
1819 const Shape& in, const FftType fft_type,
1820 const absl::Span<const int64> fft_length) {
1821 const int64 fft_rank = fft_length.size();
1822 if (fft_rank < 1 || fft_rank > 3) {
1823 return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank);
1824 }
1825 #define RET_CHECK_RANK(x) \
1826 if (x.dimensions_size() < fft_rank) { \
1827 return InvalidArgument( \
1828 "FFT of rank %d requires input of at least " \
1829 "same rank; got input of rank %d", \
1830 fft_rank, x.dimensions_size()); \
1831 }
1832 switch (fft_type) {
1833 case FFT:
1834 case IFFT:
1835 if (in.element_type() != C64) {
1836 return InvalidArgument("%s requires complex input type, found %s.",
1837 FftType_Name(fft_type),
1838 PrimitiveType_Name(in.element_type()));
1839 }
1840 RET_CHECK_RANK(in);
1841 return in;
1842 case RFFT: {
1843 if (in.element_type() != F32) {
1844 return InvalidArgument("RFFT requires F32 input type, found %s.",
1845 PrimitiveType_Name(in.element_type()));
1846 }
1847 RET_CHECK_RANK(in);
1848 for (int i = 0; i < fft_rank; i++) {
1849 if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1850 fft_length[i]) {
1851 return InvalidArgument(
1852 "RFFT requires innermost dimensions match fft_length but "
1853 "dimension %d is %d and should be %d.",
1854 in.dimensions_size() - fft_rank + i,
1855 in.dimensions(in.dimensions_size() - fft_rank + i),
1856 fft_length[i]);
1857 }
1858 }
1859 if (ShapeUtil::IsZeroElementArray(in)) {
1860 return in;
1861 }
1862 Shape result = ShapeUtil::ChangeElementType(in, C64);
1863 result.set_dimensions(result.dimensions_size() - 1,
1864 fft_length[fft_rank - 1] / 2 + 1);
1865 return result;
1866 }
1867 case IRFFT: {
1868 if (in.element_type() != C64) {
1869 return InvalidArgument("IRFFT requires C64 input type, found %s.",
1870 PrimitiveType_Name(in.element_type()));
1871 }
1872 RET_CHECK_RANK(in);
1873 Shape result = ShapeUtil::ComplexComponentShape(in);
1874 for (int i = 0; i < fft_rank - 1; i++) {
1875 if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
1876 fft_length[i]) {
1877 return InvalidArgument(
1878 "IRFFT requires all but one innermost dimensions match "
1879 "fft_length, but dimension %d is %d and should be %d.",
1880 in.dimensions_size() - fft_rank + i,
1881 in.dimensions(in.dimensions_size() - fft_rank + i),
1882 fft_length[i]);
1883 }
1884 }
1885 if (in.dimensions(in.dimensions_size() - 1) !=
1886 fft_length[fft_rank - 1] / 2 + 1) {
1887 return InvalidArgument(
1888 "IRFFT requires innermost dimension matches fft_length/2+1, but "
1889 "dimension %d is %d and should be %d.",
1890 in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1),
1891 fft_length[fft_rank - 1] / 2 + 1);
1892 }
1893 result.set_dimensions(result.dimensions_size() - 1,
1894 fft_length[fft_rank - 1]);
1895 return result;
1896 }
1897 default:
1898 LOG(FATAL) << "Unexpected fft_type: " << fft_type;
1899 }
1900 #undef RET_CHECK_RANK
1901 }
1902
InferTriangularSolveShape(const Shape & a,const Shape & b,const TriangularSolveOptions & options)1903 /* static */ StatusOr<Shape> ShapeInference::InferTriangularSolveShape(
1904 const Shape& a, const Shape& b, const TriangularSolveOptions& options) {
1905 if ((!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) ||
1906 a.element_type() != b.element_type()) {
1907 return InvalidArgument(
1908 "Expected element types in shape to be floating or complex and "
1909 "identical for TriangularSolve; got %s and %s.",
1910 PrimitiveType_Name(a.element_type()),
1911 PrimitiveType_Name(b.element_type()));
1912 }
1913 if (a.rank() < 2) {
1914 return InvalidArgument(
1915 "The 'a' argument to TriangularSolve must have rank >= 2, got shape %s",
1916 a.ToString());
1917 }
1918 if (b.rank() != a.rank()) {
1919 return InvalidArgument(
1920 "Arguments to triangular solve must have equal rank; got %s and %s.",
1921 b.ToString(), a.ToString());
1922 }
1923 if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
1924 return InvalidArgument(
1925 "The two minor dimensions of 'a' must have equal size, got %s.",
1926 a.ToString());
1927 }
1928 if (a.dimensions(a.rank() - 1) !=
1929 b.dimensions(b.rank() - (options.left_side() ? 2 : 1))) {
1930 return InvalidArgument(
1931 "The shared dimension of 'a' and 'b' does not match, got shapes %s and "
1932 "%s",
1933 a.ToString(), b.ToString());
1934 }
1935 absl::Span<const int64> a_batch_dims(a.dimensions());
1936 absl::Span<const int64> b_batch_dims(b.dimensions());
1937 a_batch_dims.remove_suffix(2);
1938 b_batch_dims.remove_suffix(2);
1939 if (a_batch_dims != b_batch_dims) {
1940 return InvalidArgument(
1941 "The leading batch dimensions of the arguments to triangular solve "
1942 "must be equal; got %s and %s.",
1943 b.ToString(), a.ToString());
1944 }
1945 if (!TriangularSolveOptions_Transpose_IsValid(options.transpose_a()) ||
1946 options.transpose_a() == TriangularSolveOptions::TRANSPOSE_INVALID) {
1947 return InvalidArgument(
1948 "Invalid transpose option value for triangular solve (%d).\n",
1949 options.transpose_a());
1950 }
1951 return b;
1952 }
1953
InferCholeskyShape(const Shape & a)1954 /* static */ StatusOr<Shape> ShapeInference::InferCholeskyShape(
1955 const Shape& a) {
1956 if (!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) {
1957 return InvalidArgument(
1958 "Expected element type in shape to be floating or complex for "
1959 "Cholesky; got %s.",
1960 PrimitiveType_Name(a.element_type()));
1961 }
1962 if (a.rank() < 2) {
1963 return InvalidArgument(
1964 "The 'a' argument to Cholesky must have rank >= 2, got shape %s",
1965 a.ToString());
1966 }
1967 if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
1968 return InvalidArgument(
1969 "The two minor dimensions of 'a' must have equal size, got %s.",
1970 a.ToString());
1971 }
1972 return a;
1973 }
1974
InferAllReduceShape(absl::Span<const Shape * const> operand_shapes)1975 /* static */ StatusOr<Shape> ShapeInference::InferAllReduceShape(
1976 absl::Span<const Shape* const> operand_shapes) {
1977 for (const Shape* operand_shape : operand_shapes) {
1978 TF_RETURN_IF_ERROR(
1979 ExpectArray(*operand_shape, "operand of cross replica sum"));
1980 }
1981 if (operand_shapes.size() == 1) {
1982 return *operand_shapes[0];
1983 }
1984 std::vector<Shape> operand_shape_values;
1985 for (const Shape* operand_shape : operand_shapes) {
1986 operand_shape_values.push_back(*operand_shape);
1987 }
1988 return ShapeUtil::MakeTupleShape(operand_shape_values);
1989 }
1990
InferAllToAllShape(const Shape & shape,int64 split_dimension,int64 concat_dimension,int64 split_count)1991 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllShape(
1992 const Shape& shape, int64 split_dimension, int64 concat_dimension,
1993 int64 split_count) {
1994 TF_RET_CHECK(split_count > 0);
1995 if (split_dimension >= shape.rank() || split_dimension < 0) {
1996 return InvalidArgument(
1997 "AllToAll split_dimension %d is out-of-bounds in shape %s.",
1998 split_dimension, ShapeUtil::HumanString(shape));
1999 }
2000 if (concat_dimension >= shape.rank() || concat_dimension < 0) {
2001 return InvalidArgument(
2002 "AllToAll concat_dimension %d is out-of-bounds in shape %s.",
2003 concat_dimension, ShapeUtil::HumanString(shape));
2004 }
2005 if (shape.dimensions(split_dimension) % split_count != 0) {
2006 return InvalidArgument(
2007 "AllToAll split dimension size %d must be dividable by split_count "
2008 "%d.",
2009 shape.dimensions(split_dimension), split_count);
2010 }
2011 std::vector<int64> new_dimensions(shape.dimensions().begin(),
2012 shape.dimensions().end());
2013 new_dimensions[split_dimension] /= split_count;
2014 new_dimensions[concat_dimension] *= split_count;
2015 return ShapeUtil::MakeShape(shape.element_type(), new_dimensions);
2016 }
2017
InferAllToAllTupleShape(absl::Span<const Shape * const> operand_shapes)2018 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
2019 absl::Span<const Shape* const> operand_shapes) {
2020 // An Alltoall HLO instruction receives N operands (with the same shape) and
2021 // returns a tuple that contains N array shapes.
2022 TF_RET_CHECK(!operand_shapes.empty());
2023 for (int i = 0; i < operand_shapes.size(); i++) {
2024 if (!ShapeUtil::Equal(*operand_shapes[0], *operand_shapes[i])) {
2025 return InvalidArgument(
2026 "HLO all-to-all has operands with different shapes: the 0th "
2027 "operand shape %s, but the %dth operand has shape %s.",
2028 ShapeUtil::HumanString(*operand_shapes[0]), i,
2029 ShapeUtil::HumanString(*operand_shapes[i]));
2030 }
2031 }
2032
2033 return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes);
2034 }
2035
// Infers the result shape of a collective-permute, which is the operand shape
// itself. `shape` must be an array; otherwise the TF_RET_CHECK below fails and
// its error is returned instead of a shape.
/* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteShape(
    const Shape& shape) {
  TF_RET_CHECK(shape.IsArray());
  return shape;
}
2041
InferReduceShape(absl::Span<const Shape * const> arg_shapes,absl::Span<const int64> dimensions_to_reduce,const ProgramShape & to_apply)2042 /* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
2043 absl::Span<const Shape* const> arg_shapes,
2044 absl::Span<const int64> dimensions_to_reduce,
2045 const ProgramShape& to_apply) {
2046 if (arg_shapes.empty()) {
2047 return InvalidArgument("Reduce must have at least 2 arguments, has 0");
2048 }
2049 if (arg_shapes.size() % 2) {
2050 return InvalidArgument(
2051 "Reduce must have an even number of arguments, has %lu",
2052 arg_shapes.size());
2053 }
2054 int64 num_reduced_args = arg_shapes.size() / 2;
2055
2056 auto reduced_args = arg_shapes.subspan(0, num_reduced_args);
2057 // Check that all of the reduced tensors have the same dimensions. The element
2058 // types may be different.
2059 for (int64 i = 1; i < num_reduced_args; ++i) {
2060 if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) {
2061 return InvalidArgument(
2062 "All reduced tensors must have the same dimension. Tensor 0 has "
2063 "shape %s, Tensor %d has shape %s",
2064 ShapeUtil::HumanString(*reduced_args[0]), i,
2065 ShapeUtil::HumanString(*reduced_args[i]));
2066 }
2067 }
2068
2069 // Check that the dimensions to reduce are in-bounds for the given shape.
2070 // We've already verified all reduced tensors have the same dimensions, so it
2071 // doesn't matter which one we choose.
2072 const Shape& arg = *reduced_args[0];
2073 for (int64 dimension : dimensions_to_reduce) {
2074 if (dimension >= arg.rank() || dimension < 0) {
2075 return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.",
2076 dimension, ShapeUtil::HumanString(arg));
2077 }
2078 }
2079
2080 auto init_values = arg_shapes.subspan(num_reduced_args, arg_shapes.size());
2081 std::vector<PrimitiveType> element_types;
2082 for (const Shape* arg : reduced_args) {
2083 element_types.push_back(arg->element_type());
2084 }
2085 TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types,
2086 num_reduced_args));
2087
2088 std::set<int64> dimensions_to_reduce_set(dimensions_to_reduce.begin(),
2089 dimensions_to_reduce.end());
2090 std::vector<int64> new_dimensions;
2091 std::vector<bool> new_is_dynamic;
2092 for (int i = 0; i < arg.rank(); ++i) {
2093 if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) {
2094 new_dimensions.push_back(arg.dimensions(i));
2095 new_is_dynamic.push_back(arg.is_dynamic_dimension(i));
2096 }
2097 }
2098
2099 if (ShapeUtil::IsScalar(to_apply.result())) {
2100 return ShapeUtil::MakeShape(to_apply.result().element_type(),
2101 new_dimensions, new_is_dynamic);
2102 } else {
2103 std::vector<Shape> result_subshapes;
2104 for (const Shape& subshape : to_apply.result().tuple_shapes()) {
2105 result_subshapes.push_back(ShapeUtil::MakeShape(
2106 subshape.element_type(), new_dimensions, new_is_dynamic));
2107 }
2108 return ShapeUtil::MakeTupleShape(result_subshapes);
2109 }
2110 }
2111
// Infers the result shape of a reduce-window over `operand_shape`.
//
// `operand_shape` must be an array, and `to_apply_shape` must pass
// VerifyReducerShape for a single input with `init_value_shape` as the init
// value and the operand's element type as the input element type. The result
// is whatever InferWindowOutputShape yields for the operand and `window`,
// using the init value's element type and with negative padding disallowed.
/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
    const Shape& operand_shape, const Shape& init_value_shape,
    const Window& window, const ProgramShape& to_apply_shape) {
  TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window"));
  TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
                                        {operand_shape.element_type()},
                                        /*inputs=*/1));
  return InferWindowOutputShape(operand_shape, window,
                                init_value_shape.element_type(),
                                /*allow_negative_padding=*/false);
}
2123
InferSelectAndScatterShape(const Shape & operand_shape,const ProgramShape & select_shape,const Window & window,const Shape & source_shape,const Shape & init_value_shape,const ProgramShape & scatter_shape)2124 /* static */ StatusOr<Shape> ShapeInference::InferSelectAndScatterShape(
2125 const Shape& operand_shape, const ProgramShape& select_shape,
2126 const Window& window, const Shape& source_shape,
2127 const Shape& init_value_shape, const ProgramShape& scatter_shape) {
2128 TF_RETURN_IF_ERROR(
2129 ExpectArray(operand_shape, "operand of select-and-scatter"));
2130
2131 // Check if the select function has a proper shape of (T,T) -> PRED.
2132 if (select_shape.parameters_size() != 2) {
2133 return InvalidArgument(
2134 "Select function must take 2 parameters, but "
2135 "takes %d parameter(s).",
2136 select_shape.parameters_size());
2137 }
2138 const Shape& select_result_shape = select_shape.result();
2139 if (!ShapeUtil::Compatible(select_result_shape,
2140 ShapeUtil::MakeShape(PRED, {}))) {
2141 return InvalidArgument("Select function must have rank-0 PRED result.");
2142 }
2143 const Shape& operand_element_shape =
2144 ShapeUtil::MakeShape(operand_shape.element_type(), {});
2145 if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2146 select_shape.parameters(0))) {
2147 return InvalidArgument(
2148 "Select function's first parameter shape currently must "
2149 "match the operand element shape, but got %s vs %s.",
2150 ShapeUtil::HumanString(select_shape.parameters(0)),
2151 ShapeUtil::HumanString(operand_element_shape));
2152 }
2153 if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
2154 select_shape.parameters(1))) {
2155 return InvalidArgument(
2156 "Select function's second parameter shape currently must "
2157 "match the operand element shape, but got %s vs %s.",
2158 ShapeUtil::HumanString(select_shape.parameters(1)),
2159 ShapeUtil::HumanString(operand_element_shape));
2160 }
2161
2162 // Check if the scatter function has a proper shape as a reduction.
2163 TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape},
2164 {source_shape.element_type()},
2165 /*inputs=*/1));
2166
2167 // Check if the result shape of window operation matches the source shape.
2168 TF_ASSIGN_OR_RETURN(const Shape& window_result_shape,
2169 InferWindowOutputShape(operand_shape, window,
2170 operand_shape.element_type(),
2171 /*allow_negative_padding=*/false));
2172 if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape,
2173 window_result_shape)) {
2174 return InvalidArgument(
2175 "Source shape does not match the shape of window-reduced operand: "
2176 "source(%s), window-reduced operand(%s).",
2177 ShapeUtil::HumanString(source_shape),
2178 ShapeUtil::HumanString(window_result_shape));
2179 }
2180
2181 return operand_shape;
2182 }
2183
InferGetDimensionSizeShape(const Shape & shape,int64 dimension)2184 /* static */ StatusOr<Shape> ShapeInference::InferGetDimensionSizeShape(
2185 const Shape& shape, int64 dimension) {
2186 if (dimension < 0 || dimension >= shape.rank()) {
2187 return InvalidArgument("GetDimensionSize dimension out of bounds: %d.",
2188 dimension);
2189 }
2190
2191 // TODO(b/119580730): Remove this restriction when very large dimension size
2192 // is needed.
2193 if (shape.dimensions(dimension) > std::numeric_limits<uint32>::max()) {
2194 return InvalidArgument(
2195 "GetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2196 "UINT_MAX limit.",
2197 ShapeUtil::HumanString(shape), dimension);
2198 }
2199
2200 return ShapeUtil::MakeShape(U32, {});
2201 }
2202
InferSliceShape(const Shape & arg,absl::Span<const int64> starts,absl::Span<const int64> limits,absl::Span<const int64> strides)2203 /* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
2204 const Shape& arg, absl::Span<const int64> starts,
2205 absl::Span<const int64> limits, absl::Span<const int64> strides) {
2206 auto error = [&](const string& message) {
2207 return InvalidArgument(
2208 "%s in slice operation; argument shape: %s; starts: {%s}; limits: "
2209 "{%s}; strides: {%s}.",
2210 message, ShapeUtil::HumanString(arg), StrJoin(starts, ","),
2211 StrJoin(limits, ","), StrJoin(strides, ","));
2212 };
2213 TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice"));
2214 VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}",
2215 ShapeUtil::HumanString(arg), StrJoin(starts, ", "),
2216 StrJoin(limits, ", "));
2217
2218 if (starts.size() != limits.size()) {
2219 return error(StrFormat("slice start and limit sizes differ: %u vs %u",
2220 starts.size(), limits.size()));
2221 }
2222
2223 if (starts.size() != strides.size()) {
2224 return error(StrFormat("slice start and strides sizes differ: %u vs %u",
2225 starts.size(), strides.size()));
2226 }
2227
2228 if (starts.size() != arg.rank()) {
2229 return InvalidArgument(
2230 "Slice index count does not match argument rank: %u vs %d.",
2231 starts.size(), arg.rank());
2232 }
2233
2234 std::vector<int64> sizes;
2235 for (int64 dimension = 0; dimension < starts.size(); ++dimension) {
2236 int64 start_index = starts[dimension];
2237 int64 limit_index = limits[dimension];
2238 int64 stride = strides[dimension];
2239 if (start_index < 0) {
2240 return InvalidArgument("Negative start index to slice: %d.", start_index);
2241 }
2242 if (limit_index > arg.dimensions(dimension)) {
2243 return error(
2244 StrFormat("limit index (%d) must be less than or equal to dimension "
2245 "size (%d)",
2246 limit_index, arg.dimensions(dimension)));
2247 }
2248 VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index);
2249 VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index);
2250 if (start_index > limit_index) {
2251 return error(
2252 StrFormat("limit index (%d) must be greater or equal to "
2253 "start index (%d) in slice with positive stride",
2254 limit_index, start_index));
2255 }
2256 if (stride <= 0) {
2257 return InvalidArgument("Stride (%d) must be positive.", stride);
2258 }
2259 sizes.push_back((limit_index - start_index + stride - 1) / stride);
2260 }
2261
2262 return ShapeUtil::MakeShape(arg.element_type(), sizes);
2263 }
2264
InferDynamicSliceShape(const Shape & operand_shape,absl::Span<const Shape> start_index_shapes,absl::Span<const int64> slice_sizes,bool allow_scalar_indices)2265 /* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
2266 const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
2267 absl::Span<const int64> slice_sizes, bool allow_scalar_indices) {
2268 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
2269 auto number_of_indices = start_index_shapes.size();
2270 // TODO(b/118437727): Remove this path.
2271 if (!allow_scalar_indices ||
2272 (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
2273 if (number_of_indices != 1) {
2274 return InvalidArgument(
2275 "Dynamic slice should have exactly 1 index operand, has %d.",
2276 number_of_indices);
2277 }
2278
2279 const Shape& start_indices_shape = start_index_shapes[0];
2280 VLOG(2) << StrFormat(
2281 "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
2282 ShapeUtil::HumanString(operand_shape),
2283 ShapeUtil::HumanString(start_indices_shape),
2284 StrJoin(slice_sizes, ", "));
2285
2286 TF_RETURN_IF_ERROR(
2287 ExpectArray(start_indices_shape, "start indices of dynamic slice"));
2288
2289 if (start_indices_shape.rank() != 1) {
2290 return InvalidArgument(
2291 "Dynamic slice start indices of rank %d must be rank1.",
2292 start_indices_shape.rank());
2293 }
2294
2295 if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2296 return InvalidArgument(
2297 "Dynamic slice start indices must be of integral type.");
2298 }
2299
2300 const int64 start_num_dims = start_indices_shape.dimensions(0);
2301 if (operand_shape.rank() != start_num_dims) {
2302 return InvalidArgument(
2303 "Dynamic slice start number of dimensions %d (%s) must match rank "
2304 "%d of slice input (%s).",
2305 start_num_dims, ShapeUtil::HumanString(start_indices_shape),
2306 operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
2307 }
2308 } else {
2309 VLOG(2) << StrFormat("slicing shape %s a with slice_sizes={%s}",
2310 ShapeUtil::HumanString(operand_shape),
2311 StrJoin(slice_sizes, ", "));
2312
2313 if (operand_shape.rank() != number_of_indices) {
2314 return InvalidArgument(
2315 "Dynamic slice start number of dimensions %d must match rank "
2316 "%d of slice input (%s).",
2317 number_of_indices, operand_shape.rank(),
2318 ShapeUtil::HumanString(operand_shape));
2319 }
2320
2321 if (number_of_indices > 0) {
2322 const Shape& first_index_shape = start_index_shapes[0];
2323 if (!ShapeUtil::IsScalar(first_index_shape)) {
2324 return InvalidArgument("Dynamic slice indices must be scalar, not %s.",
2325 ShapeUtil::HumanString(first_index_shape));
2326 }
2327 if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
2328 return InvalidArgument(
2329 "Dynamic slice start indices must be of integral type.");
2330 }
2331 for (const Shape& index_shape : start_index_shapes) {
2332 if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
2333 return InvalidArgument(
2334 "Dynamic slice start indices must all have the same shape, got "
2335 "mismatching indices with shapes %s and %s.",
2336 ShapeUtil::HumanString(first_index_shape),
2337 ShapeUtil::HumanString(index_shape));
2338 }
2339 }
2340 }
2341 }
2342
2343 if (slice_sizes.size() != operand_shape.rank()) {
2344 return InvalidArgument(
2345 "Dynamic slice index count does not match argument rank: %u vs %d.",
2346 slice_sizes.size(), operand_shape.rank());
2347 }
2348
2349 for (int64 dim = 0; dim < slice_sizes.size(); ++dim) {
2350 const int64 input_dim_size = operand_shape.dimensions(dim);
2351 const int64 slice_dim_size = slice_sizes[dim];
2352 if (slice_dim_size < 0) {
2353 return InvalidArgument("Negative size index to dynamic slice: %d.",
2354 slice_dim_size);
2355 }
2356 if (slice_dim_size > input_dim_size) {
2357 return InvalidArgument(
2358 "Slice dim size %d greater than dynamic slice dimension: %d.",
2359 slice_dim_size, input_dim_size);
2360 }
2361 VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size);
2362 }
2363
2364 return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
2365 }
2366
InferDynamicUpdateSliceShape(const Shape & operand_shape,const Shape & update_shape,absl::Span<const Shape> start_index_shapes,bool allow_scalar_indices)2367 /* static */ StatusOr<Shape> ShapeInference::InferDynamicUpdateSliceShape(
2368 const Shape& operand_shape, const Shape& update_shape,
2369 absl::Span<const Shape> start_index_shapes, bool allow_scalar_indices) {
2370 TF_RETURN_IF_ERROR(
2371 ExpectArray(operand_shape, "operand of dynamic update slice"));
2372 TF_RETURN_IF_ERROR(
2373 ExpectArray(update_shape, "update of dynamic update slice"));
2374
2375 auto number_of_indices = start_index_shapes.size();
2376 // TODO(b/118437727): Remove this path.
2377 if (!allow_scalar_indices ||
2378 (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
2379 if (number_of_indices != 1) {
2380 return InvalidArgument(
2381 "Dynamic update slice should have exactly 1 index operand, has %d.",
2382 number_of_indices);
2383 }
2384 const Shape& start_indices_shape = start_index_shapes[0];
2385 TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape,
2386 "start indices of dynamic update slice"));
2387
2388 VLOG(2) << StrFormat(
2389 "updating slice of shape %s at dynamic start_indices %s with update "
2390 "shape %s",
2391 ShapeUtil::HumanString(operand_shape),
2392 ShapeUtil::HumanString(start_indices_shape),
2393 ShapeUtil::HumanString(update_shape));
2394
2395 if (start_indices_shape.rank() != 1) {
2396 return InvalidArgument(
2397 "Dynamic update slice start indices of rank %d must be rank1.",
2398 start_indices_shape.rank());
2399 }
2400
2401 if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2402 return InvalidArgument(
2403 "Dynamic update slice start indices must be of integral type.");
2404 }
2405
2406 const int64 start_num_dims = start_indices_shape.dimensions(0);
2407 if (operand_shape.rank() != start_num_dims) {
2408 return InvalidArgument(
2409 "Dynamic update slice start number of dimensions %d (%s) must match "
2410 "rank %d of slice input (%s).",
2411 start_num_dims, ShapeUtil::HumanString(start_indices_shape),
2412 operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
2413 }
2414 } else {
2415 VLOG(2) << StrFormat("updating slice of shape %s with update shape %s",
2416 ShapeUtil::HumanString(operand_shape),
2417 ShapeUtil::HumanString(update_shape));
2418
2419 if (operand_shape.rank() != number_of_indices) {
2420 return InvalidArgument(
2421 "Dynamic update slice start number of dimensions %d must match "
2422 "rank %d of slice input (%s).",
2423 number_of_indices, operand_shape.rank(),
2424 ShapeUtil::HumanString(operand_shape));
2425 }
2426
2427 if (number_of_indices > 0) {
2428 const Shape& first_index_shape = start_index_shapes[0];
2429 if (!ShapeUtil::IsScalar(first_index_shape)) {
2430 return InvalidArgument(
2431 "Dynamic update slice indices must be scalar, not %s.",
2432 ShapeUtil::HumanString(first_index_shape));
2433 }
2434 if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
2435 return InvalidArgument(
2436 "Dynamic update slice start indices must be of integral type.");
2437 }
2438 for (const Shape& index_shape : start_index_shapes) {
2439 if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
2440 return InvalidArgument(
2441 "Dynamic update slice start indices must all have the same "
2442 "shape, got mismatching indices with shapes %s and %s.",
2443 ShapeUtil::HumanString(first_index_shape),
2444 ShapeUtil::HumanString(index_shape));
2445 }
2446 }
2447 }
2448 }
2449
2450 if (update_shape.rank() != operand_shape.rank()) {
2451 return InvalidArgument(
2452 "Dynamic update slice update rank does not match argument rank: "
2453 "%d vs %d.",
2454 update_shape.rank(), operand_shape.rank());
2455 }
2456
2457 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
2458 update_shape)) {
2459 return InvalidArgument(
2460 "Dynamic update slice update element type does not match argument. "
2461 "operand.element_type: %s vs update.element_type: %s.",
2462 PrimitiveType_Name(operand_shape.element_type()),
2463 PrimitiveType_Name(update_shape.element_type()));
2464 }
2465
2466 for (int64 dim = 0; dim < operand_shape.rank(); ++dim) {
2467 const int64 input_dim_size = operand_shape.dimensions(dim);
2468 const int64 update_dim_size = update_shape.dimensions(dim);
2469 if (update_dim_size < 0) {
2470 return InvalidArgument(
2471 "Size index %d to dynamic update slice must be >= 0.",
2472 update_dim_size);
2473 }
2474 if (update_dim_size > input_dim_size) {
2475 return InvalidArgument(
2476 "Update dim size %d greater than dynamic slice dimension: %d.",
2477 update_dim_size, input_dim_size);
2478 }
2479 VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size);
2480 }
2481
2482 return operand_shape;
2483 }
2484
InferReverseShape(const Shape & operand_shape,absl::Span<const int64> dimensions)2485 /*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
2486 const Shape& operand_shape, absl::Span<const int64> dimensions) {
2487 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
2488 if (!AllUnique(dimensions)) {
2489 return InvalidArgument("a dimension number is duplicated in reverse");
2490 }
2491 for (int64 dimension : dimensions) {
2492 if (dimension >= operand_shape.rank() || dimension < 0) {
2493 return InvalidArgument(
2494 "One of the reverse dimensions (%d) is out-of-bounds in shape %s.",
2495 dimension, ShapeUtil::HumanString(operand_shape));
2496 }
2497 }
2498 return operand_shape;
2499 }
2500
InferGetTupleElementShape(const Shape & arg,int64 index)2501 /* static */ StatusOr<Shape> ShapeInference::InferGetTupleElementShape(
2502 const Shape& arg, int64 index) {
2503 if (!arg.IsTuple()) {
2504 return InvalidArgument(
2505 "Cannot infer shape: attempting to index into non-tuple: %s.",
2506 ShapeUtil::HumanString(arg));
2507 }
2508
2509 if (index < 0 || index >= arg.tuple_shapes_size()) {
2510 return InvalidArgument(
2511 "Cannot infer shape: attempt to index out of tuple bounds: %d "
2512 ">= %d in shape %s.",
2513 index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg));
2514 }
2515
2516 return arg.tuple_shapes(index);
2517 }
2518
InferWhileShape(const ProgramShape & condition,const ProgramShape & body,const Shape & init)2519 /* static */ StatusOr<Shape> ShapeInference::InferWhileShape(
2520 const ProgramShape& condition, const ProgramShape& body,
2521 const Shape& init) {
2522 // Check the number of parameters for given computations.
2523 if (condition.parameters_size() != 1) {
2524 return InvalidArgument("Condition must take 1 arguments; got %d.",
2525 condition.parameters_size());
2526 }
2527 if (body.parameters_size() != 1) {
2528 return InvalidArgument("Body must take 1 arguments; got %d.",
2529 body.parameters_size());
2530 }
2531
2532 auto shape_string = [&]() {
2533 return StrFormat(
2534 "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition),
2535 ShapeUtil::HumanString(body), ShapeUtil::HumanString(init));
2536 };
2537
2538 // Check the shapes of computation parameters and return types.
2539 if (!ShapeUtil::Equal(condition.result(), ShapeUtil::MakeShape(PRED, {}))) {
2540 return InvalidArgument("Condition must return a boolean; got %s.",
2541 shape_string());
2542 }
2543 if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) ||
2544 !ShapeUtil::Compatible(body.result(), body.parameters(0)) ||
2545 !ShapeUtil::Compatible(body.result(), init)) {
2546 return InvalidArgument(
2547 "The parameter of condition and body, the result of the body, and init "
2548 "must all have the same shape; got %s.",
2549 shape_string());
2550 }
2551
2552 return init;
2553 }
2554
InferConditionalShape(const Shape & branch_index,absl::Span<const ProgramShape> branch_computations,absl::Span<const Shape> branch_operands)2555 /* static */ StatusOr<Shape> ShapeInference::InferConditionalShape(
2556 const Shape& branch_index,
2557 absl::Span<const ProgramShape> branch_computations,
2558 absl::Span<const Shape> branch_operands) {
2559 if (!ShapeUtil::Equal(branch_index, ShapeUtil::MakeShape(PRED, {})) &&
2560 !ShapeUtil::Equal(branch_index, ShapeUtil::MakeShape(S32, {}))) {
2561 return InvalidArgument("branch_index must be bool or int32; got %s.",
2562 ShapeUtil::HumanString(branch_index));
2563 }
2564 if (branch_index.element_type() == PRED) {
2565 TF_RET_CHECK(2 == branch_computations.size());
2566 } else {
2567 TF_RET_CHECK(!branch_computations.empty());
2568 }
2569 TF_RET_CHECK(branch_computations.size() == branch_operands.size());
2570
2571 for (int j = 0; j < branch_computations.size(); ++j) {
2572 if (branch_computations[j].parameters_size() != 1) {
2573 return InvalidArgument(
2574 "branch computation %d must take 1 argument; got %d.", j,
2575 branch_computations[j].parameters_size());
2576 }
2577 if (!ShapeUtil::Compatible(branch_computations[j].parameters(0),
2578 branch_operands[j])) {
2579 auto shape_string = [&]() {
2580 return StrFormat("operand: %s; computation: %s",
2581 ShapeUtil::HumanString(branch_operands[j]),
2582 ShapeUtil::HumanString(branch_computations[j]));
2583 };
2584 return InvalidArgument(
2585 "branch operand %d must match the shape of the only parameter of "
2586 "branch computation %d: got %s.",
2587 j, j, shape_string());
2588 }
2589
2590 if (!ShapeUtil::Compatible(branch_computations[0].result(),
2591 branch_computations[j].result())) {
2592 auto shape_string = [&]() {
2593 return StrFormat(
2594 "branch 0 computation result: %s; branch %d computation result: %s",
2595 ShapeUtil::HumanString(branch_computations[0].result()), j,
2596 ShapeUtil::HumanString(branch_computations[j].result()));
2597 };
2598 return InvalidArgument(
2599 "the result of branch 0 computation and branch %d computation must "
2600 "have the same shape: got %s.",
2601 j, shape_string());
2602 }
2603 }
2604 return branch_computations[0].result();
2605 }
2606
InferBroadcastShape(const Shape & operand,absl::Span<const int64> broadcast_sizes)2607 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2608 const Shape& operand, absl::Span<const int64> broadcast_sizes) {
2609 TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
2610 for (int64 size : broadcast_sizes) {
2611 if (size < 0) {
2612 return InvalidArgument("Broadcast with negative dimension size %d.",
2613 size);
2614 }
2615 }
2616
2617 std::vector<int64> dimensions(operand.dimensions_size() +
2618 broadcast_sizes.size());
2619 std::copy(broadcast_sizes.begin(), broadcast_sizes.end(), dimensions.begin());
2620 std::copy(operand.dimensions().begin(), operand.dimensions().end(),
2621 dimensions.begin() + broadcast_sizes.size());
2622 return ShapeUtil::MakeShape(operand.element_type(), dimensions);
2623 }
2624
InferBroadcastShape(const Shape & operand_shape,const Shape & output_shape,absl::Span<const int64> broadcast_dimensions)2625 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2626 const Shape& operand_shape, const Shape& output_shape,
2627 absl::Span<const int64> broadcast_dimensions) {
2628 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast"));
2629 TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast"));
2630 const int64 operand_rank = operand_shape.rank();
2631 const int64 output_rank = output_shape.rank();
2632 if (operand_rank > output_rank) {
2633 return InvalidArgument(
2634 "InDim style broadcast must be to an equal or higher ranked shape; "
2635 "operand rank: %lld; output rank: %lld",
2636 operand_rank, output_rank);
2637 }
2638 if (operand_rank != broadcast_dimensions.size()) {
2639 return InvalidArgument(
2640 "Size of broadcast_dimensions has to match operand's rank; operand "
2641 "rank: %lld, size of broadcast_dimensions %u.",
2642 operand_rank, broadcast_dimensions.size());
2643 }
2644 for (int64 i = 0; i < operand_rank; i++) {
2645 if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] >= output_rank) {
2646 return InvalidArgument("Broadcast dimension %lld is out of bound",
2647 broadcast_dimensions[i]);
2648 }
2649 if (operand_shape.dimensions(i) !=
2650 output_shape.dimensions(broadcast_dimensions[i]) &&
2651 operand_shape.dimensions(i) != 1) {
2652 return InvalidArgument(
2653 "Input dimension should be either 1 or equal to the output dimension "
2654 "it is broadcasting into; the %lldth operand dimension is %lld, the "
2655 "%lldth output dimension is %lld.",
2656 i, operand_shape.dimensions(i), broadcast_dimensions[i],
2657 output_shape.dimensions(broadcast_dimensions[i]));
2658 }
2659 if (operand_shape.is_dynamic_dimension(i) !=
2660 output_shape.is_dynamic_dimension(broadcast_dimensions[i])) {
2661 return InvalidArgument(
2662 "Broadcast input and output dynamism mismatch: %s and %s",
2663 operand_shape.ToString(), output_shape.ToString());
2664 }
2665 // Make sure the broadcast dimensions are listed in a strictly increasing
2666 // order.
2667 if (i > 0 && broadcast_dimensions[i - 1] >= broadcast_dimensions[i]) {
2668 return InvalidArgument(
2669 "Broadcast dimensions order is wrong: %d comes after %d.",
2670 broadcast_dimensions[i], broadcast_dimensions.at(i - 1));
2671 }
2672 }
2673
2674 return output_shape;
2675 }
2676
InferReshapeShape(const Shape & operand,absl::Span<const int64> dimensions,absl::Span<const int64> new_sizes)2677 /* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
2678 const Shape& operand, absl::Span<const int64> dimensions,
2679 absl::Span<const int64> new_sizes) {
2680 TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));
2681
2682 Shape inferred_shape =
2683 ShapeUtil::MakeShape(operand.element_type(), new_sizes);
2684 VLOG(3) << "Reshape inferred shape: "
2685 << ShapeUtil::HumanString(inferred_shape);
2686
2687 if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
2688 return InvalidArgument(
2689 "Reshape operation has mismatched element counts: from=%d (%s) "
2690 "to=%d (%s).",
2691 ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
2692 ShapeUtil::ElementsIn(inferred_shape),
2693 ShapeUtil::HumanString(inferred_shape));
2694 }
2695
2696 std::vector<int64> indices(operand.rank());
2697 std::iota(indices.begin(), indices.end(), 0);
2698 if (dimensions.size() != operand.rank() ||
2699 !std::is_permutation(dimensions.begin(), dimensions.end(),
2700 indices.begin())) {
2701 return InvalidArgument(
2702 "Reshape dimensions [%s] are not a permutation of the operand "
2703 "dimensions (operand shape is %s).",
2704 StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
2705 }
2706
2707 std::vector<std::pair<int64, int64>> unmodified_dims =
2708 ShapeUtil::DimensionsUnmodifiedByReshape(operand, inferred_shape);
2709 for (auto& unmodified : unmodified_dims) {
2710 if (operand.is_dynamic_dimension(unmodified.first)) {
2711 inferred_shape.set_dynamic_dimension(unmodified.second, true);
2712 }
2713 }
2714
2715 return inferred_shape;
2716 }
2717
InferTransposeShape(const Shape & operand,absl::Span<const int64> dimensions)2718 /* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
2719 const Shape& operand, absl::Span<const int64> dimensions) {
2720 TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
2721
2722 if (!IsPermutation(dimensions, operand.rank())) {
2723 return InvalidArgument(
2724 "Transpose dimensions [%s] are not a permutation of the operand "
2725 "dimensions (operand shape is %s).",
2726 StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
2727 }
2728
2729 // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
2730 // we need output[i]=input[dimensions[i]] which is
2731 // Permute(Inverse(dimensions),input).
2732 return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
2733 }
2734
2735 // TODO(b/36794510): Make broadcast semantics more consistent, by supporting
2736 // "degenerate" cases, as with binary elementwise ops.
InferClampShape(const Shape & min,const Shape & operand,const Shape & max)2737 /* static */ StatusOr<Shape> ShapeInference::InferClampShape(
2738 const Shape& min, const Shape& operand, const Shape& max) {
2739 TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min"));
2740 TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand"));
2741 TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max"));
2742 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) ||
2743 !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) {
2744 return InvalidArgument("Clamp with different operand types: %s, %s, %s.",
2745 ShapeUtil::HumanString(min),
2746 ShapeUtil::HumanString(operand),
2747 ShapeUtil::HumanString(max));
2748 }
2749 if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
2750 ShapeUtil::IsScalar(min)) &&
2751 (ShapeUtil::CompatibleIgnoringFpPrecision(max, operand) ||
2752 ShapeUtil::IsScalar(max)))) {
2753 return operand;
2754 }
2755 if (ShapeUtil::IsScalar(operand)) {
2756 if (ShapeUtil::CompatibleIgnoringFpPrecision(min, max)) {
2757 return ShapeUtil::ChangeElementType(min, operand.element_type());
2758 } else if (ShapeUtil::IsScalar(min)) {
2759 return ShapeUtil::ChangeElementType(max, operand.element_type());
2760 } else if (ShapeUtil::IsScalar(max)) {
2761 return ShapeUtil::ChangeElementType(min, operand.element_type());
2762 }
2763 }
2764 return Unimplemented("%s, %s <clamp> %s is not implemented.",
2765 min.ShortDebugString(), max.ShortDebugString(),
2766 operand.ShortDebugString());
2767 }
2768
2769 // TODO(b/36794510): Make broadcast semantics more consistent, by supporting
2770 // "degenerate" cases, as with binary elementwise ops, as well as scalar
2771 // broadcast from all operands, not just the predicate.
InferSelectShape(const Shape & pred,const Shape & on_true,const Shape & on_false)2772 /* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
2773 const Shape& pred, const Shape& on_true, const Shape& on_false) {
2774 if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) {
2775 return InvalidArgument(
2776 "Operands to select must be the same shape; got %s and %s.",
2777 ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
2778 }
2779 if (pred.element_type() != PRED) {
2780 return InvalidArgument(
2781 "Select's pred operand must have PRED element type; got %s.",
2782 ShapeUtil::HumanString(pred));
2783 }
2784 if (Shape::Equal()
2785 .IgnoreElementType()
2786 .IgnoreLayout()
2787 .IgnoreDynamicDimension()(pred, on_true) ||
2788 ShapeUtil::IsScalar(pred)) {
2789 // By this stage we know that pred's element type is PRED. Therefore, this
2790 // check restricts pred to be a PRED scalar, or a PRED array with the same
2791 // dimensions as on_true and on_false.
2792 Shape inferred_shape = ShapeUtil::ChangeElementType(
2793 on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
2794
2795 // Propagate dynamic dimensions if pred is not a scalar.
2796 if (!ShapeUtil::IsScalar(pred)) {
2797 for (int i = 0; i < inferred_shape.rank(); i++) {
2798 if (pred.is_dynamic_dimension(i)) {
2799 inferred_shape.set_dynamic_dimension(i, true);
2800 }
2801 }
2802 }
2803 return inferred_shape;
2804 }
2805 return InvalidArgument(
2806 "Select operation with non-scalar predicate with dimensionality "
2807 "different from the other operands: %s.",
2808 ShapeUtil::HumanString(pred));
2809 }
2810
InferTupleSelectShape(const Shape & pred,const Shape & on_true,const Shape & on_false)2811 /* static */ StatusOr<Shape> ShapeInference::InferTupleSelectShape(
2812 const Shape& pred, const Shape& on_true, const Shape& on_false) {
2813 // Select only defines the top-level buffer, so if it's a tuple, the two
2814 // input must match exactly.
2815 if (!ShapeUtil::Compatible(on_true, on_false)) {
2816 return InvalidArgument(
2817 "Operands to tuple-select must be the same shape; got %s and %s.",
2818 ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
2819 }
2820 if (pred.element_type() != PRED) {
2821 return InvalidArgument(
2822 "TupleSelect's pred operand must have PRED element type; got %s.",
2823 ShapeUtil::HumanString(pred));
2824 }
2825 if (!ShapeUtil::IsScalar(pred)) {
2826 return InvalidArgument(
2827 "TupleSelect operation with non-scalar predicate: %s.",
2828 ShapeUtil::HumanString(pred));
2829 }
2830 return on_true;
2831 }
2832
InferCallShape(absl::Span<const Shape * const> arg_shapes,const ProgramShape & to_apply)2833 /* static */ StatusOr<Shape> ShapeInference::InferCallShape(
2834 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply) {
2835 // The applied function's arity equals the number of arguments.
2836 if (arg_shapes.size() != to_apply.parameters_size()) {
2837 string computation_signature = ShapeUtil::HumanString(to_apply);
2838 string argument_shapes =
2839 StrJoin(arg_shapes, ", ", [](string* out, const Shape* shape) {
2840 absl::StrAppend(out, ShapeUtil::HumanString(*shape));
2841 });
2842 return InvalidArgument(
2843 "Call applied function arity must match number of arguments; got: "
2844 "arity: %d, arguments: %u; computation signature: %s; argument "
2845 "shapes: [%s].",
2846 to_apply.parameters_size(), arg_shapes.size(), computation_signature,
2847 argument_shapes);
2848 }
2849
2850 // All arguments must be compatible with the program shape.
2851 for (int i = 0; i < arg_shapes.size(); ++i) {
2852 const Shape& arg_shape = *arg_shapes[i];
2853 const Shape& param_shape = to_apply.parameters(i);
2854 if (!ShapeUtil::Compatible(arg_shape, param_shape)) {
2855 return InvalidArgument(
2856 "Call parameter must match argument; got parameter %d shape: %s, "
2857 "argument shape: %s.",
2858 i, ShapeUtil::HumanString(param_shape),
2859 ShapeUtil::HumanString(arg_shape));
2860 }
2861 }
2862
2863 return to_apply.result();
2864 }
2865
ValidateGatherDimensionNumbers(const Shape & input_shape,absl::Span<const int64> start_indices_shape,const GatherDimensionNumbers & dim_numbers)2866 static Status ValidateGatherDimensionNumbers(
2867 const Shape& input_shape, absl::Span<const int64> start_indices_shape,
2868 const GatherDimensionNumbers& dim_numbers) {
2869 if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
2870 return InvalidArgument(
2871 "Output window dimensions in gather op must be ascending; got: %s.",
2872 StrJoin(dim_numbers.offset_dims(), ", "));
2873 }
2874
2875 if (absl::c_adjacent_find(dim_numbers.offset_dims()) !=
2876 dim_numbers.offset_dims().end()) {
2877 return InvalidArgument(
2878 "Output window dimensions in gather op must not repeat; got: %s.",
2879 StrJoin(dim_numbers.offset_dims(), ", "));
2880 }
2881
2882 const int64 output_offset_dim_count = dim_numbers.offset_dims_size();
2883 const int64 output_shape_rank =
2884 output_offset_dim_count + start_indices_shape.size() - 1;
2885
2886 for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) {
2887 int64 offset_dim = dim_numbers.offset_dims(i);
2888 if (offset_dim < 0 || offset_dim >= output_shape_rank) {
2889 return InvalidArgument(
2890 "Offset dimension %d in gather op is out of bounds; got %d, but "
2891 "should "
2892 "have been in [0,%d).",
2893 i, offset_dim, output_shape_rank);
2894 }
2895 }
2896
2897 if (dim_numbers.start_index_map_size() !=
2898 start_indices_shape[dim_numbers.index_vector_dim()]) {
2899 return InvalidArgument(
2900 "Gather op has %d elements in start_index_map and the "
2901 "bound of dimension index_vector_dim=%d of start_indices is "
2902 "%d. These two numbers must be equal.",
2903 dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(),
2904 start_indices_shape[dim_numbers.index_vector_dim()]);
2905 }
2906
2907 for (int i = 0; i < dim_numbers.start_index_map_size(); i++) {
2908 int64 operand_dim_for_start_index_i = dim_numbers.start_index_map(i);
2909 if (operand_dim_for_start_index_i < 0 ||
2910 operand_dim_for_start_index_i >= input_shape.dimensions_size()) {
2911 return InvalidArgument(
2912 "Invalid start_index_map; domain is [0, %d), got: %d->%d.",
2913 input_shape.dimensions_size(), i, operand_dim_for_start_index_i);
2914 }
2915 }
2916
2917 std::vector<int64> sorted_start_index_map(
2918 dim_numbers.start_index_map().begin(),
2919 dim_numbers.start_index_map().end());
2920
2921 absl::c_sort(sorted_start_index_map);
2922
2923 if (absl::c_adjacent_find(sorted_start_index_map) !=
2924 sorted_start_index_map.end()) {
2925 return InvalidArgument(
2926 "Repeated dimensions are not allowed in start_index_map; "
2927 "got: %s.",
2928 StrJoin(dim_numbers.start_index_map(), ", "));
2929 }
2930
2931 for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) {
2932 if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
2933 return InvalidArgument(
2934 "Invalid collapsed_slice_dims set in gather op; valid range is [0, "
2935 "%d), got: %d.",
2936 input_shape.dimensions_size(), collapsed_dim);
2937 }
2938 }
2939
2940 if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) {
2941 return InvalidArgument(
2942 "collapsed_slice_dims in gather op must be sorted; got: %s",
2943 StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
2944 }
2945
2946 if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
2947 dim_numbers.collapsed_slice_dims().end()) {
2948 return InvalidArgument(
2949 "Repeated dimensions not allowed in collapsed_slice_dims in gather op; "
2950 "got: %s.",
2951 StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
2952 }
2953
2954 return Status::OK();
2955 }
2956
InferGatherShape(const Shape & input_shape,const Shape & start_indices_shape,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64> slice_sizes)2957 /*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
2958 const Shape& input_shape, const Shape& start_indices_shape,
2959 const GatherDimensionNumbers& gather_dim_numbers,
2960 absl::Span<const int64> slice_sizes) {
2961 TF_RETURN_IF_ERROR(
2962 ExpectArray(input_shape, "input tensor operand gather op"));
2963 TF_RETURN_IF_ERROR(
2964 ExpectArray(start_indices_shape, "gather indices operand of gather op"));
2965
2966 if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
2967 return InvalidArgument(
2968 "Gather indices parameter must be an integral tensor; got %s.",
2969 ShapeUtil::HumanString(start_indices_shape));
2970 }
2971
2972 // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
2973 // index_vector_dim is rank(P). The bounds of this expanded shape is
2974 // stored in expanded_start_indices_shape.
2975
2976 if (start_indices_shape.dimensions_size() <
2977 gather_dim_numbers.index_vector_dim() ||
2978 gather_dim_numbers.index_vector_dim() < 0) {
2979 return InvalidArgument(
2980 "Gather index leaf dimension must be within [0, rank(start_indices) + "
2981 "1). rank(start_indices) is %d and gather index leaf dimension is "
2982 "%d.",
2983 start_indices_shape.dimensions_size(),
2984 gather_dim_numbers.index_vector_dim());
2985 }
2986
2987 std::vector<int64> expanded_start_indices_shape;
2988 expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size());
2989 absl::c_copy(start_indices_shape.dimensions(),
2990 std::back_inserter(expanded_start_indices_shape));
2991 if (expanded_start_indices_shape.size() ==
2992 gather_dim_numbers.index_vector_dim()) {
2993 expanded_start_indices_shape.push_back(1);
2994 }
2995
2996 TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
2997 input_shape, expanded_start_indices_shape, gather_dim_numbers));
2998
2999 if (slice_sizes.size() != input_shape.dimensions_size()) {
3000 return InvalidArgument(
3001 "Gather op must have one slice size for every input dimension; got: "
3002 "len(slice_sizes)=%lu, input_shape.rank=%d.",
3003 slice_sizes.size(), input_shape.dimensions_size());
3004 }
3005
3006 if (slice_sizes.size() !=
3007 gather_dim_numbers.offset_dims_size() +
3008 gather_dim_numbers.collapsed_slice_dims_size()) {
3009 return InvalidArgument(
3010 "All components of the offset index in a gather op must either be a "
3011 "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
3012 "output_slice_sizes=%s, collapsed_slice_dims=%s.",
3013 slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","),
3014 StrJoin(gather_dim_numbers.collapsed_slice_dims(), ","));
3015 }
3016
3017 for (int i = 0; i < slice_sizes.size(); i++) {
3018 int64 slice_size = slice_sizes[i];
3019 int64 corresponding_input_size = input_shape.dimensions(i);
3020 if (slice_size < 0 || slice_size > corresponding_input_size) {
3021 return InvalidArgument(
3022 "Slice size at index %d in gather op is out of range, must be "
3023 "within [0, %d), got %d.",
3024 i, corresponding_input_size + 1, slice_size);
3025 }
3026 }
3027
3028 for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) {
3029 if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] != 1) {
3030 return InvalidArgument(
3031 "Gather op can only collapse slice dims with bound 1, but bound is "
3032 "%d for index %d at position %d.",
3033 slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)],
3034 gather_dim_numbers.collapsed_slice_dims(i), i);
3035 }
3036 }
3037
3038 int64 result_rank = gather_dim_numbers.offset_dims_size() +
3039 (expanded_start_indices_shape.size() - 1);
3040 int64 offset_dims_seen = 0;
3041 int64 gather_dims_seen = 0;
3042 std::vector<int64> output_dim_bounds;
3043 output_dim_bounds.reserve(result_rank);
3044 for (int64 i = 0; i < result_rank; i++) {
3045 int64 current_bound;
3046 bool is_window_index =
3047 absl::c_binary_search(gather_dim_numbers.offset_dims(), i);
3048 if (is_window_index) {
3049 while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(),
3050 offset_dims_seen)) {
3051 offset_dims_seen++;
3052 }
3053 current_bound = slice_sizes[offset_dims_seen++];
3054 } else {
3055 if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) {
3056 gather_dims_seen++;
3057 }
3058 current_bound = expanded_start_indices_shape[gather_dims_seen++];
3059 }
3060
3061 output_dim_bounds.push_back(current_bound);
3062 }
3063
3064 return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds);
3065 }
3066
3067 namespace {
3068
ValidateScatterDimensionNumbers(const Shape & operand_shape,absl::Span<const int64> scatter_indices_shape,const Shape & updates_shape,const ScatterDimensionNumbers & dim_numbers)3069 Status ValidateScatterDimensionNumbers(
3070 const Shape& operand_shape, absl::Span<const int64> scatter_indices_shape,
3071 const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
3072 // Validate update_window_dims in ScatterDimensionNumbers.
3073 if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {
3074 return InvalidArgument(
3075 "update_window_dims in scatter op must be sorted; got: %s.",
3076 StrJoin(dim_numbers.update_window_dims(), ", "));
3077 }
3078 if (absl::c_adjacent_find(dim_numbers.update_window_dims()) !=
3079 dim_numbers.update_window_dims().end()) {
3080 return InvalidArgument(
3081 "update_window_dims in scatter op must not repeat; got: %s.",
3082 StrJoin(dim_numbers.update_window_dims(), ", "));
3083 }
3084 const int64 updates_rank = updates_shape.rank();
3085 for (int64 window_dim : dim_numbers.update_window_dims()) {
3086 if (window_dim < 0 || window_dim >= updates_rank) {
3087 return InvalidArgument(
3088 "Invalid update_window_dims set in scatter op; valid range is [0, "
3089 "%d). got: %d.",
3090 updates_rank, window_dim);
3091 }
3092 }
3093
3094 // Validate inserted_window_dims in ScatterDimensionNumbers.
3095 if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) {
3096 return InvalidArgument(
3097 "inserted_window_dims in scatter op must be sorted; got: %s.",
3098 StrJoin(dim_numbers.inserted_window_dims(), ", "));
3099 }
3100 if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) !=
3101 dim_numbers.inserted_window_dims().end()) {
3102 return InvalidArgument(
3103 "inserted_window_dims in scatter op must not repeat; got: %s.",
3104 StrJoin(dim_numbers.inserted_window_dims(), ", "));
3105 }
3106 for (int64 inserted_dim : dim_numbers.inserted_window_dims()) {
3107 if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) {
3108 return InvalidArgument(
3109 "Invalid inserted_window_dims set in scatter op; valid range is [0, "
3110 "%d), got: %d.",
3111 operand_shape.dimensions_size(), inserted_dim);
3112 }
3113 }
3114
3115 // Validate window size.
3116 auto window_size = dim_numbers.update_window_dims_size() +
3117 dim_numbers.inserted_window_dims_size();
3118 if (window_size != operand_shape.rank()) {
3119 return InvalidArgument(
3120 "Scatter op has window of size %d; doesn't match operand of rank %d.",
3121 window_size, operand_shape.rank());
3122 }
3123
3124 // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers.
3125 if (dim_numbers.scatter_dims_to_operand_dims_size() !=
3126 scatter_indices_shape[dim_numbers.index_vector_dim()]) {
3127 return InvalidArgument(
3128 "Scatter op has %d elements in scatter_dims_to_operand_dims and the "
3129 "bound of dimension index_vector_dim=%d of scatter_indices is %d. "
3130 "These two numbers must be equal.",
3131 dim_numbers.scatter_dims_to_operand_dims_size(),
3132 dim_numbers.index_vector_dim(),
3133 scatter_indices_shape[dim_numbers.index_vector_dim()]);
3134 }
3135 for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) {
3136 int64 scatter_dim_to_operand_dim =
3137 dim_numbers.scatter_dims_to_operand_dims(i);
3138 if (scatter_dim_to_operand_dim < 0 ||
3139 scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) {
3140 return InvalidArgument(
3141 "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), "
3142 "got: %d->%d.",
3143 operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim);
3144 }
3145 }
3146 std::vector<int64> sorted_scatter_dims_to_operand_dims(
3147 dim_numbers.scatter_dims_to_operand_dims().begin(),
3148 dim_numbers.scatter_dims_to_operand_dims().end());
3149 absl::c_sort(sorted_scatter_dims_to_operand_dims);
3150 if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
3151 sorted_scatter_dims_to_operand_dims.end()) {
3152 return InvalidArgument(
3153 "Repeated dimensions not allowed in scatter_dims_to_operand_dims; "
3154 "got: %s.",
3155 StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", "));
3156 }
3157
3158 return Status::OK();
3159 }
3160
3161 } // namespace
3162
InferScatterShape(const Shape & operand_shape,const Shape & scatter_indices_shape,const Shape & updates_shape,const ProgramShape & to_apply_shape,const ScatterDimensionNumbers & scatter_dim_numbers)3163 /*static*/ StatusOr<Shape> ShapeInference::InferScatterShape(
3164 const Shape& operand_shape, const Shape& scatter_indices_shape,
3165 const Shape& updates_shape, const ProgramShape& to_apply_shape,
3166 const ScatterDimensionNumbers& scatter_dim_numbers) {
3167 TF_RETURN_IF_ERROR(
3168 ExpectArray(operand_shape, "operand tensor of scatter op"));
3169 TF_RETURN_IF_ERROR(
3170 ExpectArray(scatter_indices_shape, "scatter indices of scatter op"));
3171 TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op"));
3172
3173 if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) {
3174 return InvalidArgument(
3175 "Scatter indices parameter must be an integral tensor; got %s.",
3176 ShapeUtil::HumanString(scatter_indices_shape));
3177 }
3178
3179 if (scatter_indices_shape.dimensions_size() <
3180 scatter_dim_numbers.index_vector_dim() ||
3181 scatter_dim_numbers.index_vector_dim() < 0) {
3182 return InvalidArgument(
3183 "Scatter index leaf dimension must be within [0, rank(scatter_indices)"
3184 " + 1). rank(scatter_indices) is %d and scatter index leaf dimension "
3185 "is %d.",
3186 scatter_indices_shape.dimensions_size(),
3187 scatter_dim_numbers.index_vector_dim());
3188 }
3189
3190 // Check if the update computation has a proper shape as a reduction.
3191 const Shape init_value_shape =
3192 ShapeUtil::MakeShape(operand_shape.element_type(), {});
3193 TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
3194 {updates_shape.element_type()},
3195 /*inputs=*/1));
3196
3197 std::vector<int64> expanded_scatter_indices_shape =
3198 ArraySliceToVector(AsInt64Slice(scatter_indices_shape.dimensions()));
3199 if (expanded_scatter_indices_shape.size() ==
3200 scatter_dim_numbers.index_vector_dim()) {
3201 expanded_scatter_indices_shape.push_back(1);
3202 }
3203
3204 int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 +
3205 scatter_dim_numbers.update_window_dims_size();
3206 if (updates_shape.rank() != expected_updates_rank) {
3207 return InvalidArgument("Updates tensor must be of rank %d; got %d.",
3208 expected_updates_rank, updates_shape.rank());
3209 }
3210
3211 TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers(
3212 operand_shape, expanded_scatter_indices_shape, updates_shape,
3213 scatter_dim_numbers));
3214
3215 int64 inserted_dims_seen = 0;
3216 std::vector<int64> max_update_slice_sizes;
3217 for (int i = 0; i < operand_shape.dimensions_size(); ++i) {
3218 if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() &&
3219 scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) {
3220 ++inserted_dims_seen;
3221 } else {
3222 max_update_slice_sizes.push_back(operand_shape.dimensions(i));
3223 }
3224 }
3225 for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) {
3226 auto update_window_dim = scatter_dim_numbers.update_window_dims(i);
3227 if (updates_shape.dimensions(update_window_dim) >
3228 max_update_slice_sizes[i]) {
3229 return InvalidArgument(
3230 "Bounds of the window dimensions of updates must not exceed the "
3231 "bounds of the corresponding dimensions of operand. For dimension "
3232 "%d, updates bound is %d, operand bound is %d.",
3233 update_window_dim, updates_shape.dimensions(update_window_dim),
3234 max_update_slice_sizes[i]);
3235 }
3236 }
3237
3238 int64 scatter_dims_seen = 0;
3239 for (int64 i = 0; i < updates_shape.rank(); ++i) {
3240 bool is_update_window_dim =
3241 absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i);
3242 if (is_update_window_dim) {
3243 continue;
3244 }
3245 if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) {
3246 ++scatter_dims_seen;
3247 }
3248 if (updates_shape.dimensions(i) !=
3249 expanded_scatter_indices_shape[scatter_dims_seen]) {
3250 return InvalidArgument(
3251 "Bounds of the scatter dimensions of updates must be same as the "
3252 "bounds of the corresponding dimensions of scatter indices. For "
3253 "scatter dimension %d, updates bound is %d, scatter_indices "
3254 "bound is %d.",
3255 i, updates_shape.dimensions(i),
3256 expanded_scatter_indices_shape[scatter_dims_seen]);
3257 }
3258 ++scatter_dims_seen;
3259 }
3260
3261 return operand_shape;
3262 }
3263
3264 } // namespace xla
3265