1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <numeric>
21 #include <set>
22 #include <string>
23
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 #include "absl/strings/str_join.h"
29 #include "absl/strings/string_view.h"
30 #include "tensorflow/compiler/xla/permutation_util.h"
31 #include "tensorflow/compiler/xla/primitive_util.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/compiler/xla/window_util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/math/math_util.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/protobuf.h"
42 #include "tensorflow/core/platform/statusor.h"
43
44 namespace xla {
45 namespace {
46
47 using absl::StrFormat;
48 using absl::StrJoin;
49
50 // Returns true if no element is present in slice more than once.
AllUnique(absl::Span<const int64_t> slice)51 bool AllUnique(absl::Span<const int64_t> slice) {
52 return std::set<int64_t>(slice.begin(), slice.end()).size() == slice.size();
53 }
54
ExpectArray(const Shape & shape,absl::string_view op_type)55 Status ExpectArray(const Shape& shape, absl::string_view op_type) {
56 if (!shape.IsArray()) {
57 return InvalidArgument("Expected array argument for %s, but got %s.",
58 std::string(op_type), ShapeUtil::HumanString(shape));
59 }
60 return OkStatus();
61 }
62
// Validates the to-apply computation of a reduce-like operation over
// `inputs` input arrays:
//   *) `reducer_shape` must take 2 * inputs parameters (accumulators first,
//      then one element per input).
//   *) Its result must be a scalar (when inputs == 1) or a tuple of `inputs`
//      scalars.
//   *) Accumulator, init-value and input element types must be compatible,
//      ignoring floating-point precision.
// Returns OkStatus() on success, InvalidArgument describing the first
// violation otherwise.
Status VerifyReducerShape(const ProgramShape& reducer_shape,
                          absl::Span<const Shape* const> init_value_shapes,
                          absl::Span<const PrimitiveType> input_element_types,
                          int64_t inputs) {
  if (reducer_shape.parameters_size() != inputs * 2) {
    return InvalidArgument(
        "Reduction function must take %d parameters, but "
        "takes %d parameter(s).",
        inputs * 2, reducer_shape.parameters_size());
  }

  // Collect the per-input accumulator sub-shapes: either the (array) result
  // itself for a single input, or each tuple element for multiple inputs.
  const Shape& accumulator_shape = reducer_shape.result();
  std::vector<const Shape*> accumulator_subshapes;
  if (accumulator_shape.IsArray()) {
    if (inputs != 1) {
      return InvalidArgument(
          "Reduction function must produce a tuple with %d elements, but "
          "produces a scalar",
          inputs);
    }
    accumulator_subshapes.push_back(&accumulator_shape);
  } else if (accumulator_shape.IsTuple()) {
    if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) {
      return InvalidArgument(
          "Reduction function must produce a tuple with %d elements, but has "
          "%d elements",
          inputs, ShapeUtil::TupleElementCount(accumulator_shape));
    }
    for (const Shape& element_shape : accumulator_shape.tuple_shapes()) {
      accumulator_subshapes.push_back(&element_shape);
    }
  } else {
    return InvalidArgument(
        "Reduction function must produce a scalar or tuple of scalars, but has "
        "shape: %s",
        ShapeUtil::HumanString(accumulator_shape));
  }

  // Every accumulator component must be a scalar (rank 0).
  for (const Shape* element_shape : accumulator_subshapes) {
    if (element_shape->rank() != 0) {
      return InvalidArgument(
          "Reduction function must return a scalar or tuple of scalars but "
          "returns shape: %s",
          ShapeUtil::HumanString(accumulator_shape));
    }
  }

  for (int64_t i = 0; i < inputs; ++i) {
    // Check that the accumulator can be passed in as the first argument.
    // Note: comparing here and below with Compatible since we don't care about
    // layout in scalars - see b/26668201 for a longer-term vision.
    if (!ShapeUtil::Compatible(*accumulator_subshapes[i],
                               reducer_shape.parameters(i))) {
      return InvalidArgument(
          "Reduction function's %d-th parameter shape differs from the "
          "result shape: %s vs %s",
          i, ShapeUtil::HumanString(reducer_shape.parameters(i)),
          ShapeUtil::HumanString(*accumulator_subshapes[i]));
    }
    // Check that init_value's shapes are suitable for reducer_shape.
    if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i],
                                                  *init_value_shapes[i])) {
      return InvalidArgument(
          "Reduction function's accumulator shape at index %d differs from "
          "the init_value shape: %s vs %s",
          i, ShapeUtil::HumanString(*accumulator_subshapes[i]),
          ShapeUtil::HumanString(*init_value_shapes[i]));
    }
    // Check that the inputs can be passed in as the non-accumulator arguments.
    const Shape input_element_shape =
        ShapeUtil::MakeShape(input_element_types[i], {});
    if (!ShapeUtil::CompatibleIgnoringFpPrecision(
            input_element_shape, reducer_shape.parameters(inputs + i))) {
      return InvalidArgument(
          "Reduction function's %d-th parameter shape differs from the "
          "input type element type: %s vs %s",
          inputs + i,
          ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
          ShapeUtil::HumanString(input_element_shape));
    }
    // Check that the accumulator and inputs to the reducer function match.
    // If the accumulator is scalar, it must have the same type as the inputs
    // (up to fp precision). If it is a tuple, then the k-th element of the
    // tuple must have the same type as the K-th input (again, up to fp
    // precision.)
    if (!ShapeUtil::CompatibleIgnoringFpPrecision(
            *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) {
      return InvalidArgument(
          "Reduction function's %d-th parameter shape must "
          "match the result shape, but got %s vs %s.",
          inputs + i,
          ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)),
          ShapeUtil::HumanString(*accumulator_subshapes[i]));
    }
  }

  return OkStatus();
}
161
// Computes the result shape of applying `window` to an operand of shape
// `base_shape`, producing elements of `element_type`. Each output dimension
// is the strided bound of the (padded, base-dilated) base extent over the
// (window-dilated) window extent. Fails if the window rank does not match
// the base rank or any window parameter is out of its legal range.
StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
                                       const Window& window,
                                       PrimitiveType element_type) {
  if (window.dimensions_size() != base_shape.rank()) {
    return InvalidArgument(
        "Window has dimension %d but base shape has dimension %d.",
        window.dimensions_size(), base_shape.rank());
  }

  std::vector<int64_t> output_dimensions(window.dimensions_size());
  std::vector<bool> output_is_dynamic(window.dimensions_size());
  for (int64_t i = 0; i < window.dimensions_size(); ++i) {
    const auto& dim = window.dimensions(i);
    // Window size and stride must be strictly positive.
    if (dim.size() <= 0) {
      return InvalidArgument("Window %s has a non-positive dimension.",
                             window.DebugString());
    }
    if (dim.stride() <= 0) {
      return InvalidArgument("Window %s has a non-positive stride.",
                             window.DebugString());
    }
    // A dilation factor of 1 means "no dilation"; smaller values are invalid.
    if (dim.base_dilation() < 1) {
      return InvalidArgument(
          "Window %s has a non-positive base area dilation factor.",
          window.DebugString());
    }
    if (dim.window_dilation() < 1) {
      return InvalidArgument(
          "Window %s has a non-positive window dilation factor.",
          window.DebugString());
    }

    // Effective base extent after base dilation, then edge padding.
    const int64_t dilated_base = window_util::DilatedBound(
        ShapeUtil::GetDimension(base_shape, i), dim.base_dilation());
    const int64_t padded_dilated_base =
        dim.padding_low() + dilated_base + dim.padding_high();
    // Effective window extent after window dilation.
    const int64_t dilated_window =
        window_util::DilatedBound(dim.size(), dim.window_dilation());

    // Number of valid window placements under the given stride.
    output_dimensions[i] = window_util::StridedBound(
        padded_dilated_base, dilated_window, dim.stride());
    // A dynamic base dimension yields a dynamic output dimension.
    output_is_dynamic[i] = base_shape.is_dynamic_dimension(i);
  }

  return ShapeUtil::MakeValidatedShape(element_type, output_dimensions,
                                       output_is_dynamic);
}
209
MaybeUpcast(PrimitiveType from_type,std::optional<PrimitiveType> preferred_element_type)210 StatusOr<PrimitiveType> MaybeUpcast(
211 PrimitiveType from_type,
212 std::optional<PrimitiveType> preferred_element_type) {
213 if (!preferred_element_type.has_value() ||
214 *preferred_element_type == from_type) {
215 return from_type;
216 }
217 if (!primitive_util::IsFloatingPointType(from_type) &&
218 primitive_util::BitWidth(*preferred_element_type) <
219 primitive_util::BitWidth(from_type)) {
220 return InvalidArgument(
221 "`preferred_element_type` must not be narrower than the original "
222 "type.");
223 }
224 return *preferred_element_type;
225 }
226
227 } // namespace
228
// Convenience overload: infers the unary-op result shape from the operand
// instruction's shape.
/* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
    HloOpcode opcode, const HloInstruction* operand) {
  return InferUnaryOpShape(opcode, operand->shape());
}
233
InferUnaryOpShape(HloOpcode opcode,const Shape & shape)234 /* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
235 HloOpcode opcode, const Shape& shape) {
236 // There is no copy operation at the proto level, so handle copy explicitly.
237 // A domain shape is the same as the input one.
238 if (opcode == HloOpcode::kCopy || opcode == HloOpcode::kDomain) {
239 return shape;
240 }
241
242 TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation"));
243
244 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
245 switch (opcode) {
246 case HloOpcode::kFloor:
247 case HloOpcode::kCeil:
248 case HloOpcode::kRoundNearestAfz:
249 case HloOpcode::kRoundNearestEven:
250 if (!ShapeUtil::ElementIsFloating(shape)) {
251 return InvalidArgument(
252 "Expected element type in shape to be floating for %s operation; "
253 "got %s.",
254 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
255 }
256 return shape;
257 case HloOpcode::kCos:
258 case HloOpcode::kSin:
259 case HloOpcode::kExp:
260 case HloOpcode::kExpm1:
261 case HloOpcode::kLog:
262 case HloOpcode::kLog1p:
263 case HloOpcode::kLogistic:
264 case HloOpcode::kRsqrt:
265 case HloOpcode::kSqrt:
266 case HloOpcode::kCbrt:
267 case HloOpcode::kTanh:
268 if (!ShapeUtil::ElementIsFloating(shape) &&
269 !ShapeUtil::ElementIsComplex(shape)) {
270 return InvalidArgument(
271 "Expected element type in shape to be floating or complex for %s "
272 "operation; got %s.",
273 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
274 }
275 return shape;
276 case HloOpcode::kReal:
277 case HloOpcode::kImag:
278 if (ShapeUtil::ElementIsComplex(shape)) {
279 return ShapeUtil::ComplexComponentShape(shape);
280 } else if (ShapeUtil::ElementIsFloating(shape)) {
281 return shape;
282 } else {
283 return InvalidArgument(
284 "Expected element type in shape to be floating or complex for "
285 "%s operation; got %s.",
286 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
287 }
288 case HloOpcode::kAbs:
289 if (ShapeUtil::ElementIsComplex(shape)) {
290 return ShapeUtil::ChangeElementType(
291 shape, primitive_util::ComplexComponentType(shape.element_type()));
292 } else if (ShapeUtil::ElementIsSigned(shape)) {
293 return shape;
294 } else {
295 return InvalidArgument(
296 "Expected element type in shape to be floating or complex for "
297 "%s operation; got %s.",
298 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
299 }
300 case HloOpcode::kClz:
301 if (!ShapeUtil::ElementIsIntegral(shape)) {
302 return InvalidArgument(
303 "Expected an integral element type in argument to Clz "
304 "operation; got %s.",
305 PrimitiveType_Name(shape.element_type()));
306 }
307 return shape;
308 case HloOpcode::kNegate:
309 if (!ShapeUtil::ElementIsIntegral(shape) &&
310 !ShapeUtil::ElementIsFloating(shape) &&
311 !ShapeUtil::ElementIsComplex(shape)) {
312 return InvalidArgument(
313 "Expected element type in shape to be integral, floating or "
314 "complex for %s operation; got %s.",
315 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
316 }
317 return shape;
318 case HloOpcode::kPopulationCount:
319 if (!ShapeUtil::ElementIsIntegral(shape)) {
320 return InvalidArgument(
321 "Expected an integral element type in argument to PopulationCount "
322 "operation; got %s.",
323 PrimitiveType_Name(shape.element_type()));
324 }
325 return shape;
326 case HloOpcode::kSign:
327 if (!ShapeUtil::ElementIsSigned(shape) &&
328 !ShapeUtil::ElementIsComplex(shape)) {
329 return InvalidArgument(
330 "Expected element type in shape to be signed or complex for "
331 "%s operation; got %s.",
332 HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
333 }
334 return shape;
335
336 case HloOpcode::kNot:
337 if (shape.element_type() != PRED &&
338 !primitive_util::IsIntegralType(shape.element_type())) {
339 return InvalidArgument(
340 "Expected pred or an integral element type in argument to Not "
341 "operation; got %s.",
342 PrimitiveType_Name(shape.element_type()));
343 }
344 return shape;
345
346 case HloOpcode::kIsFinite:
347 if (!ShapeUtil::ElementIsFloating(shape)) {
348 return InvalidArgument(
349 "Expected element type in shape to be floating "
350 "point for IsFinite "
351 "operation; got %s.",
352 PrimitiveType_Name(shape.element_type()));
353 }
354 return ShapeUtil::ChangeElementType(shape, PRED);
355
356 default:
357 return InvalidArgument(
358 "Unknown operation for unary shape inference: \"%s\".",
359 HloOpcodeString(opcode));
360 }
361 }
362
// Infers the result shape of concatenating `arg_shapes` along `dimension`.
// All operands must be arrays of the same rank whose element types agree up
// to floating-point precision and whose sizes match in every dimension other
// than `dimension`. The result element type is the highest-precision element
// type among the operands.
/* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
    absl::Span<const Shape* const> arg_shapes, const int64_t dimension) {
  if (arg_shapes.empty()) {
    return InvalidArgument("Concatenate expects at least one argument.");
  }
  if (dimension < 0 || dimension >= arg_shapes[0]->rank()) {
    return InvalidArgument("Concatenate dimension out of bounds: %d.",
                           dimension);
  }
  // The first operand acts as the reference for rank and dimension checks.
  const Shape* arg_shape = nullptr;
  PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
  for (const Shape* shape : arg_shapes) {
    TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation"));
    if (!arg_shape) {
      arg_shape = shape;
      element_type = arg_shape->element_type();
      continue;
    }
    if (arg_shape->rank() != shape->rank()) {
      return InvalidArgument(
          "Cannot concatenate arrays with different ranks: %d (%s) vs %d "
          "(%s).",
          arg_shape->rank(), ShapeUtil::HumanString(*arg_shape), shape->rank(),
          ShapeUtil::HumanString(*shape));
    }
    if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
      return InvalidArgument(
          "Cannot concatenate arrays with different element types: %s vs %s.",
          PrimitiveType_Name(arg_shape->element_type()),
          PrimitiveType_Name(shape->element_type()));
    }
    for (int64_t dimension_number = 0; dimension_number < arg_shape->rank();
         ++dimension_number) {
      if (arg_shape->dimensions(dimension_number) !=
          shape->dimensions(dimension_number)) {
        if (dimension_number == dimension) {
          continue;  // It's okay to differ in the dimension we're
                     // concatenating.
        }
        return InvalidArgument(
            "Cannot concatenate arrays that differ in dimensions other than "
            "the one being concatenated. Dimension %d in both shapes must be "
            "equal: %s vs %s.",
            dimension_number, ShapeUtil::HumanString(*arg_shape),
            ShapeUtil::HumanString(*shape));
      }
    }
    // Track the widest element type seen so far (fp precision promotion).
    element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
  }

  // Result dimensions: the reference shape with the concat dimension summed
  // over all operands.
  std::vector<int64_t> new_dimensions(arg_shape->dimensions().begin(),
                                      arg_shape->dimensions().end());
  for (size_t i = 1; i < arg_shapes.size(); ++i) {
    new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
  }

  Shape result = ShapeUtil::MakeShape(element_type, new_dimensions);

  // Set dynamic dimensions if any input has dynamic dimension.
  // Note: a dimension (including the concat dimension) is marked dynamic in
  // the result if it is dynamic in any operand.
  for (const Shape* shape : arg_shapes) {
    for (int64_t i = 0; i < shape->dimensions_size(); ++i) {
      if (shape->is_dynamic_dimension(i)) {
        result.set_dynamic_dimension(i, true);
      }
    }
  }
  return result;
}
431
InferConvertShape(const Shape & operand_shape,PrimitiveType new_element_type)432 /* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
433 const Shape& operand_shape, PrimitiveType new_element_type) {
434 if (!operand_shape.IsArray() ||
435 !primitive_util::IsArrayType(new_element_type)) {
436 // Note: we may want to support tuple conversions via this operation in the
437 // future, by recursing into the tuple elements to check all sub-conversions
438 // are valid. For now we just reject them, though.
439 return InvalidArgument(
440 "Convert does not allow non-arrays, so cannot convert from %s to %s.",
441 ShapeUtil::HumanString(operand_shape),
442 PrimitiveType_Name(new_element_type));
443 }
444
445 return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
446 }
447
InferBitcastConvertShape(const Shape & operand_shape,PrimitiveType new_element_type)448 /* static */ StatusOr<Shape> ShapeInference::InferBitcastConvertShape(
449 const Shape& operand_shape, PrimitiveType new_element_type) {
450 auto old_element_type = operand_shape.element_type();
451 if (primitive_util::IsComplexType(old_element_type) !=
452 primitive_util::IsComplexType(new_element_type)) {
453 return InvalidArgument("Conversion between complex and real type %s => %s.",
454 ShapeUtil::HumanString(operand_shape),
455 PrimitiveType_Name(new_element_type));
456 }
457 if (!operand_shape.IsArray() ||
458 !primitive_util::IsArrayType(new_element_type)) {
459 // Note: we may want to support tuple conversions via this operation in the
460 // future, by recursing into the tuple elements to check all sub-conversions
461 // are valid. For now we just reject them, though.
462 return InvalidArgument(
463 "Cannot convert from or to tuple type; requested conversion: %s => %s.",
464 ShapeUtil::HumanString(operand_shape),
465 PrimitiveType_Name(new_element_type));
466 }
467
468 int input_bitwidth = primitive_util::BitWidth(old_element_type);
469 int output_bitwidth = primitive_util::BitWidth(new_element_type);
470 if (std::max(input_bitwidth, output_bitwidth) %
471 std::min(input_bitwidth, output_bitwidth) !=
472 0) {
473 return InvalidArgument(
474 "Cannot bitcast types with undivisible bit-widths: %s => %s.",
475 PrimitiveType_Name(old_element_type),
476 PrimitiveType_Name(new_element_type));
477 }
478 int ratio = std::max(output_bitwidth, input_bitwidth) /
479 std::min(output_bitwidth, input_bitwidth);
480
481 Shape new_shape = operand_shape;
482 new_shape.set_element_type(new_element_type);
483 if (input_bitwidth > output_bitwidth) {
484 ShapeUtil::AppendMinorDimension(ratio, &new_shape);
485 } else if (input_bitwidth < output_bitwidth) {
486 int last_dimension_idx = operand_shape.dimensions_size() - 1;
487 if (operand_shape.dimensions_size() < 1 ||
488 operand_shape.dimensions(last_dimension_idx) != ratio) {
489 return InvalidArgument(
490 "Last dimension of input shape=%d is not equal to ratio of "
491 "bit-widths=%d "
492 "for bitcast-convert from %s to %s",
493 operand_shape.dimensions(last_dimension_idx), ratio,
494 ShapeUtil::HumanString(operand_shape),
495 PrimitiveType_Name(new_element_type));
496 }
497 new_shape.DeleteDimension(last_dimension_idx);
498 }
499 return new_shape;
500 }
501
InferReducePrecisionShape(const Shape & operand_shape,const int exponent_bits,const int mantissa_bits)502 /* static */ StatusOr<Shape> ShapeInference::InferReducePrecisionShape(
503 const Shape& operand_shape, const int exponent_bits,
504 const int mantissa_bits) {
505 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
506 return InvalidArgument(
507 "Expected element type in shape to be floating point for "
508 "ReducePrecision operation; got %s.",
509 PrimitiveType_Name(operand_shape.element_type()));
510 }
511 if (exponent_bits < 1) {
512 // One exponent bit is necessary to distinguish 0 from infinity. Having
513 // no exponent bits doesn't produce a sensible number, so we require at
514 // least one.
515 return InvalidArgument("Expected exponent_bits >= 1; got %d.",
516 exponent_bits);
517 }
518 if (mantissa_bits < 0) {
519 // A number with no mantissa bits is still meaningful, however.
520 return InvalidArgument("Expected non-negative mantissa_bits; got %d.",
521 mantissa_bits);
522 }
523 return operand_shape;
524 }
525
// Infers the result shape of padding `operand_shape` with the scalar
// `padding_value_shape` according to `padding_config`. Each output dimension
// is computed as
//   input + edge_padding_low + edge_padding_high
//         + max(input - 1, 0) * interior_padding,
// i.e. interior padding is inserted only between existing elements. Edge
// padding may be negative (trimming) as long as the resulting dimension size
// stays non-negative; interior padding must be non-negative.
/* static */ StatusOr<Shape> ShapeInference::InferPadShape(
    const Shape& operand_shape, const Shape& padding_value_shape,
    const PaddingConfig& padding_config) {
  if (!operand_shape.IsArray()) {
    return InvalidArgument(
        "Pad operation does not support tuple-shape operands.");
  }
  if (!ShapeUtil::IsScalar(padding_value_shape)) {
    return InvalidArgument(
        "Pad operation does not support non-scalar padding values.");
  }
  if (operand_shape.rank() != padding_config.dimensions_size()) {
    return InvalidArgument(
        "The rank of the operand and the padding configuration do not match: "
        "%s vs %s.",
        ShapeUtil::HumanString(operand_shape),
        padding_config.ShortDebugString());
  }
  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
                                                     padding_value_shape)) {
    return InvalidArgument(
        "The element types of the operands to Pad do not match.");
  }
  if (absl::c_any_of(padding_config.dimensions(),
                     [](const PaddingConfig::PaddingConfigDimension& p) {
                       return p.interior_padding() < 0;
                     })) {
    return InvalidArgument("Interior padding cannot be negative: %s",
                           padding_config.ShortDebugString());
  }

  if (!padding_value_shape.is_static()) {
    return InvalidArgument("Dynamic padding value is not supported");
  }

  std::vector<int64_t> dimensions(operand_shape.rank());
  std::vector<bool> is_dynamic(operand_shape.rank());
  for (int64_t i = 0; i < operand_shape.dimensions_size(); ++i) {
    const auto& p = padding_config.dimensions(i);
    // max(input - 1, 0) interior gaps each receive `interior_padding`
    // elements; edge padding is added on both sides.
    dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() +
                    p.edge_padding_high() +
                    std::max<int64_t>(operand_shape.dimensions(i) - 1, 0LL) *
                        p.interior_padding();
    if (dimensions[i] < 0) {
      return InvalidArgument("Padding result in negative size for dimension %d",
                             i);
    }
    // Dynamic-ness of the operand dimension carries over to the result.
    is_dynamic[i] = operand_shape.is_dynamic_dimension(i);
  }

  return ShapeUtil::MakeShape(
      ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
      dimensions, is_dynamic);
}
580
581 // Current DotDimensionNumbers Requirements:
582 //
583 // Contracting Dimensions:
584 // *) Same number of contracting dimensions on both lhs and rhs.
585 // *) Contracting dimension size must be the same on both lhs and rhs.
586 //
587 // Batch Dimensions:
588 // *) Same number of batch dimensions on both lhs and rhs.
589 // *) Same batch dimension sizes on both lhs and rhs.
590 //
591
592 namespace {
593
ValidateDotDimensionNumbers(const Shape & lhs,const Shape & rhs,const DotDimensionNumbers & dimension_numbers)594 Status ValidateDotDimensionNumbers(
595 const Shape& lhs, const Shape& rhs,
596 const DotDimensionNumbers& dimension_numbers) {
597 // Check that dimension numbers are in range.
598 auto dims_in_range = [](const int64_t rank,
599 absl::Span<const int64_t> contracting_dims,
600 absl::Span<const int64_t> batch_dims) -> bool {
601 auto in_range = [&rank](int64_t i) -> bool { return 0 <= i && i < rank; };
602 return absl::c_all_of(contracting_dims, in_range) &&
603 absl::c_all_of(batch_dims, in_range);
604 };
605
606 absl::Span<const int64_t> lhs_contracting_dimensions =
607 dimension_numbers.lhs_contracting_dimensions();
608 absl::Span<const int64_t> rhs_contracting_dimensions =
609 dimension_numbers.rhs_contracting_dimensions();
610 absl::Span<const int64_t> lhs_batch_dimensions =
611 dimension_numbers.lhs_batch_dimensions();
612 absl::Span<const int64_t> rhs_batch_dimensions =
613 dimension_numbers.rhs_batch_dimensions();
614
615 if (!dims_in_range(lhs.rank(), lhs_contracting_dimensions,
616 lhs_batch_dimensions) ||
617 !dims_in_range(rhs.rank(), rhs_contracting_dimensions,
618 rhs_batch_dimensions)) {
619 return InvalidArgument("A dimension number is out of range in Dot: %s.",
620 dimension_numbers.DebugString());
621 }
622
623 // Check that dimension numbers are unique.
624 auto dims_unique = [](absl::Span<const int64_t> contracting_dims,
625 absl::Span<const int64_t> batch_dims) -> bool {
626 absl::flat_hash_set<int64_t> dim_set;
627 auto is_unique = [&dim_set](int64_t i) -> bool {
628 return dim_set.insert(i).second;
629 };
630 return absl::c_all_of(contracting_dims, is_unique) &&
631 absl::c_all_of(batch_dims, is_unique);
632 };
633
634 if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
635 !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) {
636 return InvalidArgument("A dimension number is not unique in Dot: %s.",
637 dimension_numbers.DebugString());
638 }
639
640 return OkStatus();
641 }
642
643 } // namespace
644
InferDotOpShape(const Shape & lhs,const Shape & rhs,const DotDimensionNumbers & dimension_numbers,std::optional<PrimitiveType> preferred_element_type)645 /* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
646 const Shape& lhs, const Shape& rhs,
647 const DotDimensionNumbers& dimension_numbers,
648 std::optional<PrimitiveType> preferred_element_type) {
649 TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot"));
650 TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
651
652 auto fail = [lhs, rhs](const std::string& addendum) -> Status {
653 std::string message =
654 StrFormat("Cannot infer shape for dot operation: %s <dot> %s.",
655 ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs));
656 if (!addendum.empty()) {
657 message += " " + addendum;
658 }
659 return InvalidArgument("%s", message);
660 };
661
662 // Validate basic properties of dot dimension numbers.
663 TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers));
664
665 // Check that number of contracting dimensions match.
666 if (dimension_numbers.lhs_contracting_dimensions_size() !=
667 dimension_numbers.rhs_contracting_dimensions_size()) {
668 return fail(
669 "Must specify the same number of contracting dimensions for lhs and "
670 "rhs.");
671 }
672 // Check that contracting dimension sizes match.
673 for (int64_t i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
674 ++i) {
675 const int64_t lhs_contracting_dimension =
676 dimension_numbers.lhs_contracting_dimensions(i);
677 const int64_t rhs_contracting_dimension =
678 dimension_numbers.rhs_contracting_dimensions(i);
679 if (lhs.dimensions(lhs_contracting_dimension) !=
680 rhs.dimensions(rhs_contracting_dimension)) {
681 return fail("Contracting dimension sizes do not match.");
682 }
683 }
684
685 // Check that number of batch dimensions match.
686 if (dimension_numbers.lhs_batch_dimensions_size() !=
687 dimension_numbers.rhs_batch_dimensions_size()) {
688 return fail("Must the same number of batch dimensions for lhs and rhs.");
689 }
690
691 // Check that batch dimension numbers and sizes match.
692 for (int64_t i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
693 if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
694 rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
695 return fail("Batch dimension sizes must match for lhs/rhs.");
696 }
697 }
698
699 // The ranks of lhs and rhs are decremented by 1 respectively due to the
700 // contraction, and added for the rank of the result. When an input tensor is
701 // a scalar, its contribution to the rank of the result is 0.
702 // Generate the result dimensions in order, rhs dimensions followed by lhs
703 // dimensions except the contracted and batch dimensions.
704 std::vector<int64_t> dimensions;
705 std::vector<bool> is_dynamic;
706 const auto& lhs_batch_dimensions = dimension_numbers.lhs_batch_dimensions();
707 const auto lhs_batch_dimensions_size =
708 lhs.rank() - dimension_numbers.lhs_contracting_dimensions().size() +
709 rhs.rank() - dimension_numbers.rhs_contracting_dimensions().size() -
710 dimension_numbers.rhs_batch_dimensions().size();
711 dimensions.reserve(lhs_batch_dimensions_size);
712 is_dynamic.reserve(lhs_batch_dimensions_size);
713 for (const int64_t lhs_dim : lhs_batch_dimensions) {
714 dimensions.push_back(lhs.dimensions(lhs_dim));
715 is_dynamic.push_back(lhs.is_dynamic_dimension(lhs_dim));
716 }
717 for (int64_t i = 0; i < lhs.rank(); i++) {
718 if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
719 i) &&
720 !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
721 dimensions.push_back(lhs.dimensions(i));
722 is_dynamic.push_back(lhs.is_dynamic_dimension(i));
723 }
724 }
725 for (int64_t i = 0; i < rhs.rank(); i++) {
726 if (!absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
727 i) &&
728 !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
729 dimensions.push_back(rhs.dimensions(i));
730 is_dynamic.push_back(rhs.is_dynamic_dimension(i));
731 }
732 }
733 TF_ASSIGN_OR_RETURN(
734 PrimitiveType type,
735 MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
736 preferred_element_type));
737 Shape result = ShapeUtil::MakeShape(type, dimensions, is_dynamic);
738
739 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
740 VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
741 return result;
742 }
743
744 /* static */ StatusOr<Shape>
InferDegenerateDimensionBroadcastShape(HloOpcode operation,const Shape & lhs,const Shape & rhs)745 ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
746 const Shape& lhs,
747 const Shape& rhs) {
748 TF_RET_CHECK(lhs.rank() == rhs.rank());
749
750 // The shapes have to be compatible. That is, if some dimension d has a
751 // different size in the two shapes, one of them has to be 1 (a "degenerate"
752 // dimension). In that case, the output shape has the non-1 dimension size
753 // from the lhs/rhs pair in every index.
754 std::vector<int64_t> output_dimensions(lhs.rank());
755 std::vector<bool> output_dimensions_is_dynamic(lhs.rank());
756 for (int64_t i = 0; i < lhs.rank(); ++i) {
757 if (lhs.dimensions(i) == rhs.dimensions(i)) {
758 output_dimensions[i] = lhs.dimensions(i);
759 } else if (lhs.dimensions(i) == 1) {
760 output_dimensions[i] = rhs.dimensions(i);
761 } else if (rhs.dimensions(i) == 1) {
762 output_dimensions[i] = lhs.dimensions(i);
763 } else {
764 return InvalidArgument(
765 "Binary op %s with incompatible shapes: %s and %s.",
766 HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
767 ShapeUtil::HumanString(rhs));
768 }
769 }
770
771 // Merge dynamic dimensions from two shapes.
772 for (int64_t i = 0; i < rhs.rank(); ++i) {
773 if (rhs.is_dynamic_dimension(i) || lhs.is_dynamic_dimension(i)) {
774 output_dimensions_is_dynamic[i] = true;
775 }
776 }
777
778 return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
779 output_dimensions, output_dimensions_is_dynamic);
780 }
781
/* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
    const Shape& smaller_shape, const Shape& larger_shape,
    absl::Span<const int64_t> broadcast_dimensions) {
  // Infers the shape of an "InDim" broadcast: smaller_shape is embedded into
  // larger_shape's rank, with broadcast_dimensions[i] naming the dimension of
  // larger_shape that dimension i of smaller_shape maps onto.
  if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
    // Reject "magic" inference for binops on different shapes, requiring
    // the user to provide an explicit broadcast dimension in this case.
    // See b/25177275 for more details.
    return InvalidArgument("Shapes must be equal rank, but are %s and %s",
                           ShapeUtil::HumanString(smaller_shape),
                           ShapeUtil::HumanString(larger_shape));
  } else if (broadcast_dimensions.size() != smaller_shape.rank()) {
    return InvalidArgument(
        "Size of broadcast_dimensions has to match lower-rank operand's "
        "rank; "
        " lower-rank operand's rank is %d, size of broadcast_dimensions is "
        "%u.",
        smaller_shape.rank(), broadcast_dimensions.size());
  }

  // broadcast_dimensions is a sequence of dimensions; its length is equal to
  // the rank of the lower-rank operand. The lower-rank operand's dimensions
  // have to be compatible with the higher-rank operand's dimensions at indices
  // specified by broadcast_dimensions. Here compatible means the dimension
  // sizes are equal or in one of the shapes the dimension size is
  // one. Examples:
  //
  // smaller_shape   larger_shape   broadcast_dimensions   output_shape
  //   []              [2, 3]          {}                    [2, 3]
  //   [3]             [4, 3]          {1}                   [4, 3]
  //   [2, 3]          [2, 3, 4]       {0, 1}                [2, 3, 4]
  //   [2, 1]          [2, 3, 4]       {0, 2}                [2, 3, 1]
  //   [2, 3]          [2, 1, 4]       {0, 1}                [2, 3, 4]
  //
  // The column output_shape may not be the final shape of the XLA
  // operation. After the "InDim" broadcasting implemented in this function
  // expands the rank, degenerate-dimension broadcasting (implemented in
  // InferDegenerateDimensionBroadcastShape) broadcasts dimensions of size one
  // up to match the dimension size of the other operand. For example, consider
  // the row in the table above with a smaller_shape of [2, 1]. The shape
  // returned by this function is [2, 3, 1] (output_shape) however, the result
  // shape of the XLA operation is [2, 3, 4] after degenerate-dimension
  // broadcasting.
  //
  // Invalid broadcasts:
  //
  // smaller_shape=[3], larger_shape=[4, 3], broadcast_dimensions={0}
  // Reason: Dimension zero** of larger_shape (size 4) is not compatible with
  //   dimension zero of smaller_shape(size 3). **Zero here comes from the value
  //   in broadcast_dimensions.
  //
  // smaller_shape=[2, 1], larger_shape=[2, 3, 4], broadcast_dimensions={1, 2}
  // Reason: Dimension one of larger_shape (size 3) is not compatible with
  //   dimension zero of smaller_shape(size 2)

  // The output shape is initially the larger_shape. Sizes of dimensions
  // specified in broadcast_dimensions are then changed to match the
  // corresponding dimension size in smaller_shape.
  Shape output_shape(larger_shape);
  output_shape.set_element_type(
      ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape));

  for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
    int64_t dimension_to_match = broadcast_dimensions.at(i);
    if (dimension_to_match < 0) {
      return InvalidArgument(
          "Broadcast dimension number (%d) cannot be negative.",
          dimension_to_match);
    }
    if (dimension_to_match >= larger_shape.dimensions_size()) {
      return InvalidArgument(
          "Broadcast dimension number (%d) too large; higher-rank "
          "operand has rank %d.",
          dimension_to_match, larger_shape.dimensions_size());
    }
    int64_t small_dimension_size = smaller_shape.dimensions(i);
    int64_t large_dimension_size = larger_shape.dimensions(dimension_to_match);
    bool small_is_dynamic = smaller_shape.is_dynamic_dimension(i);
    bool large_is_dynamic =
        larger_shape.is_dynamic_dimension(dimension_to_match);
    // Dimension sizes must be compatible: match or be degenerate (degenerate
    // case is handled by degenerate dimension broadcasting which occurs after
    // InDim broadcasting).
    if (small_dimension_size != large_dimension_size &&
        small_dimension_size != 1 && large_dimension_size != 1) {
      return InvalidArgument(
          "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i,
          small_dimension_size, large_dimension_size,
          ShapeUtil::HumanString(smaller_shape),
          ShapeUtil::HumanString(larger_shape));
    }
    // Dynamism must also agree, except when the mismatched side is a static
    // size-1 dimension (which will simply be broadcast away later).
    if (small_is_dynamic != large_is_dynamic) {
      if (small_dimension_size == large_dimension_size ||
          (small_dimension_size == 1 && !small_is_dynamic) ||
          (large_dimension_size == 1 && !large_is_dynamic)) {
        // Do nothing. It's OK when the size-1 dimension is not static.
      } else {
        return InvalidArgument(
            "Broadcast dimension %d dynamism mismatch: %s and %s.", i,
            ShapeUtil::HumanString(smaller_shape),
            ShapeUtil::HumanString(larger_shape));
      }
    }
    // Make sure the broadcast dimensions are listed in a strictly increasing
    // order.
    if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) {
      return InvalidArgument(
          "Broadcast dimensions order is wrong: %d comes after %d.",
          dimension_to_match, broadcast_dimensions.at(i - 1));
    }

    // Overwrite the matched output dimension (size and dynamism) with the
    // smaller operand's; any remaining size-1 mismatch is resolved by the
    // degenerate-dimension broadcast pass that follows.
    output_shape.set_dimensions(dimension_to_match, small_dimension_size);
    output_shape.set_dynamic_dimension(dimension_to_match, small_is_dynamic);
  }

  return output_shape;
}
898
InferElementwiseBinaryOpShape(HloOpcode operation,const Shape & lhs,const Shape & rhs,absl::Span<const int64_t> broadcast_dimensions)899 /* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
900 HloOpcode operation, const Shape& lhs, const Shape& rhs,
901 absl::Span<const int64_t> broadcast_dimensions) {
902 TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
903 TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
904
905 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
906 return InvalidArgument(
907 "Binary op %s with different element types: %s and %s.",
908 HloOpcodeString(operation), ShapeUtil::HumanString(lhs),
909 ShapeUtil::HumanString(rhs));
910 }
911
912 if (lhs.rank() == rhs.rank()) {
913 std::vector<int64_t> identity_dims(lhs.rank());
914 std::iota(identity_dims.begin(), identity_dims.end(), 0);
915 if (!broadcast_dimensions.empty() &&
916 broadcast_dimensions != identity_dims) {
917 return InvalidArgument(
918 "Broadcast dimensions field must either be not set or be the "
919 "identity on binary operations with operands of the same rank.");
920 }
921 }
922
923 if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
924 // If the shapes are the same other than layout, the output shape is the
925 // same (elementwise op).
926 Shape result = ShapeUtil::ChangeElementType(
927 lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
928
929 for (int64_t i = 0; i < rhs.rank(); ++i) {
930 if (rhs.is_dynamic_dimension(i)) {
931 result.set_dynamic_dimension(i, true);
932 }
933 }
934
935 return result;
936
937 } else if (lhs.rank() == rhs.rank()) {
938 return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
939 } else {
940 // Ranks do not match, so perform InDim broadcasting using
941 // broadcast_dimensions. Scalar broadcasting is a special case of this.
942 const Shape& larger_shape = lhs.rank() > rhs.rank() ? lhs : rhs;
943 const Shape& smaller_shape = lhs.rank() > rhs.rank() ? rhs : lhs;
944
945 // After InDim broadcasting, perform degenerate dimensions broadcasting.
946 TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape,
947 InferInDimBroadcastShape(smaller_shape, larger_shape,
948 broadcast_dimensions));
949
950 return InferDegenerateDimensionBroadcastShape(
951 operation, indim_broadcast_shape, larger_shape);
952 }
953 }
954
InferBinaryOpShape(HloOpcode opcode,const HloInstruction * lhs,const HloInstruction * rhs)955 /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
956 HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) {
957 return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(),
958 /*broadcast_dimensions=*/{});
959 }
960
/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
    HloOpcode opcode, const Shape& lhs, const Shape& rhs,
    absl::Span<const int64_t> broadcast_dimensions) {
  // Infers the result shape of a binary HLO op, dispatching on the opcode and
  // applying the opcode's element-type constraints before broadcasting.
  VLOG(2) << StrFormat(
      "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
      HloOpcodeString(opcode), ShapeUtil::HumanStringWithLayout(lhs),
      ShapeUtil::HumanStringWithLayout(rhs),
      StrJoin(broadcast_dimensions, ", "));

  // Shape validity (layout optional) is only verified in debug builds.
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));

  // Binary ops operate on arrays (including scalars), never tuples.
  TF_RETURN_IF_ERROR(ExpectArray(
      lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode))));
  TF_RETURN_IF_ERROR(ExpectArray(
      rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode))));
  switch (opcode) {
    // Accepted for any element type, including PRED.
    case HloOpcode::kAdd:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
      return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
                                           broadcast_dimensions);

    // Arithmetic and shift ops: PRED operands are rejected.
    case HloOpcode::kSubtract:
    case HloOpcode::kAtan2:
    case HloOpcode::kPower:
    case HloOpcode::kDivide:
    case HloOpcode::kRemainder:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      if (lhs.element_type() == PRED || rhs.element_type() == PRED) {
        return InvalidArgument(
            "Expected element type in shape to be arithmetic type for "
            "operation %s; got PRED.",
            HloOpcodeString(opcode));
      }
      return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
                                           broadcast_dimensions);

    case HloOpcode::kComplex: {
      // Composes a complex value from floating-point real/imaginary parts:
      // (F32, F32) -> C64 and (F64, F64) -> C128; other combinations are
      // unimplemented.
      if (!ShapeUtil::ElementIsFloating(lhs)) {
        return InvalidArgument(
            "Expected element type in shape to be floating for complex compose "
            "operation; got %s.",
            PrimitiveType_Name(lhs.element_type()));
      }
      TF_ASSIGN_OR_RETURN(const Shape& shape,
                          InferElementwiseBinaryOpShape(opcode, lhs, rhs,
                                                        broadcast_dimensions));
      if (lhs.element_type() == F32 && rhs.element_type() == F32) {
        return ShapeUtil::ChangeElementType(shape, C64);
      } else if (lhs.element_type() == F64 && rhs.element_type() == F64) {
        return ShapeUtil::ChangeElementType(shape, C128);
      } else {
        return Unimplemented("Complex component type is not implemented.");
      }
    }
    // Logical ops: operands must be PRED or integral.
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
      if (lhs.element_type() != PRED &&
          !primitive_util::IsIntegralType(lhs.element_type())) {
        return InvalidArgument(
            "Expected pred or integral type in argument to and/or operation; "
            "got %s.",
            PrimitiveType_Name(lhs.element_type()));
      }
      return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
                                           broadcast_dimensions);
    case HloOpcode::kCompare: {
      // Comparisons yield the broadcasted shape with PRED element type.
      TF_ASSIGN_OR_RETURN(const Shape& shape,
                          InferElementwiseBinaryOpShape(opcode, lhs, rhs,
                                                        broadcast_dimensions));
      return ShapeUtil::ChangeElementType(shape, PRED);
    }
    default:
      return Unimplemented(
          "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.",
          HloOpcodeString(opcode), lhs.ShortDebugString(),
          rhs.ShortDebugString());
  }
}
1045
InferTernaryOpShape(HloOpcode opcode,const HloInstruction * lhs,const HloInstruction * rhs,const HloInstruction * ehs)1046 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1047 HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs,
1048 const HloInstruction* ehs) {
1049 return InferTernaryOpShape(opcode, lhs->shape(), rhs->shape(), ehs->shape());
1050 }
1051
InferTernaryOpShape(HloOpcode opcode,const Shape & lhs,const Shape & rhs,const Shape & ehs)1052 /* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
1053 HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) {
1054 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
1055 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
1056 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs));
1057 switch (opcode) {
1058 case HloOpcode::kClamp:
1059 return InferClampShape(lhs, rhs, ehs);
1060 case HloOpcode::kSelect:
1061 return InferSelectShape(lhs, rhs, ehs);
1062 default:
1063 return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1064 }
1065 }
1066
InferVariadicOpShape(HloOpcode opcode,absl::Span<const HloInstruction * const> operands)1067 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1068 HloOpcode opcode, absl::Span<const HloInstruction* const> operands) {
1069 std::vector<const Shape*> operand_shapes;
1070 operand_shapes.reserve(operands.size());
1071 for (const HloInstruction* operand : operands) {
1072 operand_shapes.push_back(&operand->shape());
1073 }
1074 return InferVariadicOpShape(opcode, operand_shapes);
1075 }
1076
InferVariadicOpShape(HloOpcode opcode,absl::Span<const Shape * const> operand_shapes)1077 /* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
1078 HloOpcode opcode, absl::Span<const Shape* const> operand_shapes) {
1079 for (const Shape* shape : operand_shapes) {
1080 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
1081 }
1082 switch (opcode) {
1083 case HloOpcode::kTuple: {
1084 Shape result = ShapeUtil::MakeTupleShape({});
1085 result.mutable_tuple_shapes()->reserve(operand_shapes.size());
1086 for (const Shape* shape : operand_shapes) {
1087 ShapeUtil::AppendShapeToTuple(*shape, &result);
1088 }
1089 return result;
1090 }
1091 case HloOpcode::kSort: {
1092 if (operand_shapes.size() == 1) {
1093 return *operand_shapes[0];
1094 } else {
1095 for (int64_t operand = 1; operand < operand_shapes.size(); ++operand) {
1096 if (!ShapeUtil::SameDimensions(*operand_shapes[0],
1097 *operand_shapes[operand])) {
1098 return InvalidArgument(
1099 "Sort keys and values dimensions must match. "
1100 "Keys shape is: %s\n, Values shape (operand index %lld) is: %s",
1101 ShapeUtil::HumanString(*operand_shapes[0]), operand,
1102 ShapeUtil::HumanString(*operand_shapes[operand]));
1103 }
1104 }
1105 return ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes);
1106 }
1107 return InvalidArgument("Unexpected number of operands for sort");
1108 }
1109 default:
1110 return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode));
1111 }
1112 }
1113
InferMapShape(absl::Span<const Shape * const> arg_shapes,const ProgramShape & to_apply,absl::Span<const int64_t> dimensions)1114 /* static */ StatusOr<Shape> ShapeInference::InferMapShape(
1115 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
1116 absl::Span<const int64_t> dimensions) {
1117 if (arg_shapes.empty()) {
1118 return InvalidArgument("Map expects at least one argument.");
1119 }
1120
1121 // All arguments must have the same shape ignoring the element types.
1122 const Shape* arg_shape = arg_shapes[0];
1123 for (size_t i = 1; i < arg_shapes.size(); ++i) {
1124 TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map"));
1125
1126 if (ShapeUtil::CompatibleIgnoringElementType(*arg_shapes[i], *arg_shape)) {
1127 continue;
1128 }
1129 if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
1130 *arg_shape)) {
1131 if (ShapeUtil::IsScalar(*arg_shapes[i])) {
1132 continue;
1133 }
1134 if (ShapeUtil::IsScalar(*arg_shape)) {
1135 arg_shape = arg_shapes[i];
1136 continue;
1137 }
1138 }
1139
1140 std::vector<std::string> pieces;
1141 pieces.reserve(arg_shapes.size());
1142 for (const Shape* shape : arg_shapes) {
1143 pieces.push_back(ShapeUtil::HumanString(*shape));
1144 }
1145 return InvalidArgument(
1146 "Map operation requires all operands to have the same shape; got: "
1147 "%s.",
1148 StrJoin(pieces, ", "));
1149 }
1150
1151 // Check that dimensions.size == arg_shape.dimensions_size() (we currently
1152 // only support mapping across all dimensions: i.e. scalar map functions).
1153 if (dimensions.size() != arg_shape->dimensions_size()) {
1154 return InvalidArgument(
1155 "Map applied to a subset of dimensions currently not supported: "
1156 "arg_dimension_size: %d, requested_map_dimensions_size: %u.",
1157 arg_shape->dimensions_size(), dimensions.size());
1158 }
1159
1160 // Check that requested map dimensions numbers are monotonically increasing.
1161 for (int i = 0; i < dimensions.size(); ++i) {
1162 if (dimensions[i] != i) {
1163 return InvalidArgument(
1164 "Map requires monotonically increasing dimension numbers; got: %s.",
1165 StrJoin(dimensions, ", "));
1166 }
1167 }
1168
1169 // The applied function's arity equals the number of arguments.
1170 if (arg_shapes.size() != to_apply.parameters_size()) {
1171 return InvalidArgument(
1172 "Map applied function arity must match number of arguments; got: "
1173 "arity: %d, arguments: %u.",
1174 to_apply.parameters_size(), arg_shapes.size());
1175 }
1176
1177 // The parameters should all be scalars, and the output too.
1178 const Shape& output_shape = to_apply.result();
1179 if (!ShapeUtil::IsScalar(output_shape)) {
1180 return InvalidArgument(
1181 "Mapped computation's result has to be a scalar; got: %s.",
1182 ShapeUtil::HumanString(output_shape));
1183 }
1184
1185 for (int i = 0; i < to_apply.parameters_size(); ++i) {
1186 const Shape& parameter_shape = to_apply.parameters(i);
1187
1188 if (!ShapeUtil::IsScalar(parameter_shape)) {
1189 return InvalidArgument(
1190 "Mapped computation's parameter has to be a scalar; "
1191 "got parameter %d shape: %s.",
1192 i, ShapeUtil::HumanString(parameter_shape));
1193 }
1194
1195 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
1196 *arg_shapes[i])) {
1197 return InvalidArgument(
1198 "Mapped computation's parameter type has to match argument element "
1199 "type; got parameter %d shape: %s, argument shape: %s.",
1200 i, ShapeUtil::HumanString(parameter_shape),
1201 ShapeUtil::HumanString(*arg_shape));
1202 }
1203 }
1204
1205 return ShapeUtil::MakeShape(output_shape.element_type(),
1206 arg_shape->dimensions());
1207 }
1208
InferBatchNormTrainingShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & offset_shape,int64_t feature_index)1209 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormTrainingShape(
1210 const Shape& operand_shape, const Shape& scale_shape,
1211 const Shape& offset_shape, int64_t feature_index) {
1212 TF_RETURN_IF_ERROR(
1213 ExpectArray(operand_shape, "operand of batch norm training"));
1214 TF_RETURN_IF_ERROR(
1215 ExpectArray(offset_shape, "offset input of batch norm training"));
1216 TF_RETURN_IF_ERROR(
1217 ExpectArray(scale_shape, "scale input of batch norm training"));
1218
1219 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
1220 OkStatus());
1221 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) ==
1222 OkStatus());
1223 TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
1224 OkStatus());
1225
1226 if (feature_index >= operand_shape.rank()) {
1227 return InvalidArgument(
1228 "Expected feature_index of batch-norm-training to be "
1229 "smaller than the rank of operand_shape; "
1230 "got feature_index %d, and rank %d.",
1231 feature_index, operand_shape.rank());
1232 }
1233
1234 if (feature_index < 0) {
1235 return InvalidArgument(
1236 "Expected feature_index of batch-norm-training to "
1237 "be a non-negative number, got %d.",
1238 feature_index);
1239 }
1240
1241 if (operand_shape.rank() < 1) {
1242 return InvalidArgument(
1243 "Expected the rank of operand to "
1244 "batch-norm-training to be at least 1; got %d.",
1245 operand_shape.rank());
1246 }
1247
1248 if (offset_shape.rank() != 1) {
1249 return InvalidArgument(
1250 "Offset input of batch-norm-training must have"
1251 " rank 1, but has rank %d.",
1252 offset_shape.rank());
1253 }
1254
1255 if (scale_shape.rank() != 1) {
1256 return InvalidArgument(
1257 "Scale input of batch-norm-training must have"
1258 " rank 1, but has rank %d.",
1259 scale_shape.rank());
1260 }
1261
1262 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1263 return InvalidArgument(
1264 "The operand to batch-norm-training must have a floating point "
1265 "element type, but the shape is %s.",
1266 PrimitiveType_Name(operand_shape.element_type()));
1267 }
1268
1269 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1270 operand_shape)) {
1271 return InvalidArgument(
1272 "The inputs should have the same element type for batch-norm-training, "
1273 "but the shape of offset factor is %s "
1274 "and the shape of operand is %s.",
1275 PrimitiveType_Name(offset_shape.element_type()),
1276 PrimitiveType_Name(operand_shape.element_type()));
1277 }
1278
1279 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1280 operand_shape)) {
1281 return InvalidArgument(
1282 "The inputs should have the same element type for batch-norm-training, "
1283 "but the shape of scale factor is %s "
1284 "and the shape of operand is %s.",
1285 PrimitiveType_Name(scale_shape.element_type()),
1286 PrimitiveType_Name(operand_shape.element_type()));
1287 }
1288
1289 const int64_t feature_count = operand_shape.dimensions(feature_index);
1290 Shape output_shape_for_mean_and_var =
1291 ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1292
1293 if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1294 return InvalidArgument(
1295 "The size of offset factor should be the same as feature count,"
1296 "but the size of offset factor is %d "
1297 "and the feature count is %d.",
1298 ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1299 }
1300
1301 if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1302 return InvalidArgument(
1303 "The size of scale factor should be the same as feature count,"
1304 "but the size of scale factor is %d "
1305 "and the feature count is %d.",
1306 ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1307 }
1308
1309 return ShapeUtil::MakeTupleShapeWithPtrs({&operand_shape,
1310 &output_shape_for_mean_and_var,
1311 &output_shape_for_mean_and_var});
1312 }
1313
InferBatchNormInferenceShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & offset_shape,const Shape & mean_shape,const Shape & variance_shape,int64_t feature_index)1314 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
1315 const Shape& operand_shape, const Shape& scale_shape,
1316 const Shape& offset_shape, const Shape& mean_shape,
1317 const Shape& variance_shape, int64_t feature_index) {
1318 TF_RETURN_IF_ERROR(
1319 ExpectArray(operand_shape, "operand of batch norm inference"));
1320 TF_RETURN_IF_ERROR(
1321 ExpectArray(offset_shape, "offset input of batch norm inference"));
1322 TF_RETURN_IF_ERROR(
1323 ExpectArray(scale_shape, "scale input of batch norm inference"));
1324
1325 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1326 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape));
1327 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1328 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1329 TF_RETURN_IF_ERROR(
1330 ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape));
1331
1332 if (feature_index >= operand_shape.rank()) {
1333 return InvalidArgument(
1334 "Expected feature_index of batch-norm-inference to be "
1335 "smaller than the rank of operand_shape; "
1336 "got feature_index %d, and rank %d.",
1337 feature_index, operand_shape.rank());
1338 }
1339
1340 if (feature_index < 0) {
1341 return InvalidArgument(
1342 "Expected feature_index of batch-norm-inference to "
1343 "be a non-negative number, got %d.",
1344 feature_index);
1345 }
1346
1347 if (operand_shape.rank() < 1) {
1348 return InvalidArgument(
1349 "Expected the rank of operand to "
1350 "batch-norm-inference to be at least 1; got %d.",
1351 operand_shape.rank());
1352 }
1353
1354 if (offset_shape.rank() != 1) {
1355 return InvalidArgument(
1356 "Offset input of batch-norm-inference must have"
1357 " rank 1, but has rank %d.",
1358 offset_shape.rank());
1359 }
1360
1361 if (scale_shape.rank() != 1) {
1362 return InvalidArgument(
1363 "Scale input of batch-norm-inference must have"
1364 " rank 1, but has rank %d.",
1365 scale_shape.rank());
1366 }
1367
1368 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1369 return InvalidArgument(
1370 "The operand to batch-norm-inference must have a floating point "
1371 "element type, but the shape is %s.",
1372 PrimitiveType_Name(operand_shape.element_type()));
1373 }
1374
1375 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
1376 operand_shape)) {
1377 return InvalidArgument(
1378 "The inputs should have the same element type for "
1379 "batch-norm-inference, "
1380 "but the shape of offset factor is %s "
1381 "and the shape of operand is %s.",
1382 PrimitiveType_Name(offset_shape.element_type()),
1383 PrimitiveType_Name(operand_shape.element_type()));
1384 }
1385
1386 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1387 operand_shape)) {
1388 return InvalidArgument(
1389 "The inputs should have the same element type for "
1390 "batch-norm-inference, "
1391 "but the shape of scale factor is %s "
1392 "and the shape of operand is %s.",
1393 PrimitiveType_Name(scale_shape.element_type()),
1394 PrimitiveType_Name(operand_shape.element_type()));
1395 }
1396
1397 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1398 operand_shape)) {
1399 return InvalidArgument(
1400 "The inputs should have the same element type for "
1401 "batch-norm-inference, "
1402 "but the shape of mean is %s "
1403 "and the shape of operand is %s.",
1404 PrimitiveType_Name(mean_shape.element_type()),
1405 PrimitiveType_Name(operand_shape.element_type()));
1406 }
1407
1408 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
1409 operand_shape)) {
1410 return InvalidArgument(
1411 "The inputs should have the same element type for "
1412 "batch-norm-inference, "
1413 "but the shape of variance is %s "
1414 "and the shape of operand is %s.",
1415 PrimitiveType_Name(mean_shape.element_type()),
1416 PrimitiveType_Name(variance_shape.element_type()));
1417 }
1418
1419 const int64_t feature_count = operand_shape.dimensions(feature_index);
1420 Shape output_shape_for_mean_and_var =
1421 ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1422
1423 if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
1424 return InvalidArgument(
1425 "The size of offset factor should be the same as feature count,"
1426 "but the size of offset factor is %d "
1427 "and the feature count is %d.",
1428 ShapeUtil::GetDimension(offset_shape, 0), feature_count);
1429 }
1430
1431 if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1432 return InvalidArgument(
1433 "The size of scale factor should be the same as feature count,"
1434 "but the size of scale factor is %d "
1435 "and the feature count is %d.",
1436 ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1437 }
1438
1439 if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1440 return InvalidArgument(
1441 "The size of mean should be the same as feature count,"
1442 "but the size of mean is %d "
1443 "and the feature count is %d.",
1444 ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1445 }
1446
1447 if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
1448 return InvalidArgument(
1449 "The size of variance should be the same as feature count,"
1450 "but the size of variance is %d "
1451 "and the feature count is %d.",
1452 ShapeUtil::GetDimension(variance_shape, 0), feature_count);
1453 }
1454
1455 return operand_shape;
1456 }
1457
InferBatchNormGradShape(const Shape & operand_shape,const Shape & scale_shape,const Shape & mean_shape,const Shape & var_shape,const Shape & output_grad_shape,int64_t feature_index)1458 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
1459 const Shape& operand_shape, const Shape& scale_shape,
1460 const Shape& mean_shape, const Shape& var_shape,
1461 const Shape& output_grad_shape, int64_t feature_index) {
1462 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad"));
1463 TF_RETURN_IF_ERROR(
1464 ExpectArray(scale_shape, "scale input of batch norm grad"));
1465 TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad"));
1466 TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad"));
1467 TF_RETURN_IF_ERROR(
1468 ExpectArray(output_grad_shape, "output_grad input of batch norm grad"));
1469
1470 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
1471 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
1472 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape));
1473 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(var_shape));
1474 TF_RETURN_IF_ERROR(
1475 ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape));
1476
1477 if (feature_index >= operand_shape.rank()) {
1478 return InvalidArgument(
1479 "Expected feature_index of batch-norm-grad to be "
1480 "smaller than the rank of operand_shape; "
1481 "got feature_index %d, and rank %d.",
1482 feature_index, operand_shape.rank());
1483 }
1484
1485 if (operand_shape.rank() != output_grad_shape.rank()) {
1486 return InvalidArgument(
1487 "Expected operand_shape of batch-norm-grad to have the same rank as"
1488 " output_grad_shape; got rank(oprand_shape) %d, and"
1489 " rank(output_grad_shape) %d.",
1490 operand_shape.rank(), output_grad_shape.rank());
1491 }
1492
1493 if (mean_shape.rank() != 1) {
1494 return InvalidArgument(
1495 "Mean input of batch-norm-grad must have"
1496 " rank 1, but has rank %d.",
1497 mean_shape.rank());
1498 }
1499
1500 if (scale_shape.rank() != 1) {
1501 return InvalidArgument(
1502 "Scale input of batch-norm-grad must have"
1503 " rank 1, but has rank %d.",
1504 scale_shape.rank());
1505 }
1506
1507 if (var_shape.rank() != 1) {
1508 return InvalidArgument(
1509 "Var input of batch-norm-grad must have"
1510 " rank 1, but has rank %d.",
1511 var_shape.rank());
1512 }
1513
1514 if (!ShapeUtil::ElementIsFloating(operand_shape)) {
1515 return InvalidArgument(
1516 "The operand to batch-norm-grad must have a floating point "
1517 "element type, but the shape is %s.",
1518 PrimitiveType_Name(operand_shape.element_type()));
1519 }
1520
1521 if (!ShapeUtil::ElementIsFloating(output_grad_shape)) {
1522 return InvalidArgument(
1523 "The output_grad to batch-norm-grad must have a floating point "
1524 "element type, but the shape is %s.",
1525 PrimitiveType_Name(output_grad_shape.element_type()));
1526 }
1527
1528 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
1529 operand_shape)) {
1530 return InvalidArgument(
1531 "The inputs should have the same element type for batch-norm-grad, "
1532 "but the element type of output_grad is %s "
1533 "and the element type of operand is %s.",
1534 PrimitiveType_Name(output_grad_shape.element_type()),
1535 PrimitiveType_Name(operand_shape.element_type()));
1536 }
1537
1538 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
1539 operand_shape)) {
1540 return InvalidArgument(
1541 "The inputs should have the same element type for batch-norm-grad, "
1542 "but the element type of scale factor is %s "
1543 "and the element type of operand is %s.",
1544 PrimitiveType_Name(scale_shape.element_type()),
1545 PrimitiveType_Name(operand_shape.element_type()));
1546 }
1547
1548 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
1549 operand_shape)) {
1550 return InvalidArgument(
1551 "The inputs should have the same element type for batch-norm-grad, "
1552 "but the element type of mean is %s "
1553 "and the element type of operand is %s.",
1554 PrimitiveType_Name(mean_shape.element_type()),
1555 PrimitiveType_Name(operand_shape.element_type()));
1556 }
1557
1558 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
1559 operand_shape)) {
1560 return InvalidArgument(
1561 "The inputs should have the same element type for batch-norm-grad, "
1562 "but the element type of mean is %s "
1563 "and the element type of operand is %s.",
1564 PrimitiveType_Name(mean_shape.element_type()),
1565 PrimitiveType_Name(operand_shape.element_type()));
1566 }
1567
1568 const int64_t feature_count = operand_shape.dimensions(feature_index);
1569
1570 Shape feature_shape =
1571 ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
1572
1573 if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
1574 return InvalidArgument(
1575 "The size of mean should be the same as feature count,"
1576 "but the size of offset factor is %d "
1577 "and the feature count is %d.",
1578 ShapeUtil::GetDimension(mean_shape, 0), feature_count);
1579 }
1580
1581 if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
1582 return InvalidArgument(
1583 "The size of scale factor should be the same as feature count,"
1584 "but the size of scale factor is %d "
1585 "and the feature count is %d.",
1586 ShapeUtil::GetDimension(scale_shape, 0), feature_count);
1587 }
1588
1589 if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) {
1590 return InvalidArgument(
1591 "The size of variance should be the same as feature count,"
1592 "but the size of variance is %d "
1593 "and the feature count is %d.",
1594 ShapeUtil::GetDimension(var_shape, 0), feature_count);
1595 }
1596
1597 // Verify operand_shape and output_grad_shape have same bounds.
1598 for (int64_t i = 0; i < operand_shape.rank(); ++i) {
1599 if (ShapeUtil::GetDimension(operand_shape, i) !=
1600 ShapeUtil::GetDimension(output_grad_shape, i)) {
1601 return InvalidArgument(
1602 "The bounds of operand shape should be the same as output_grad's,"
1603 "but the bound of operand_shape at dimension %d is %d "
1604 "and the bound of output_grad_shape is %d.",
1605 i, ShapeUtil::GetDimension(operand_shape, i),
1606 ShapeUtil::GetDimension(output_grad_shape, i));
1607 }
1608 }
1609
1610 return ShapeUtil::MakeTupleShapeWithPtrs(
1611 {&operand_shape, &feature_shape, &feature_shape});
1612 }
1613
/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
    const Shape& lhs, const Shape& rhs, int64_t feature_group_count,
    int64_t batch_group_count, const Window& window,
    const ConvolutionDimensionNumbers& dnums,
    std::optional<PrimitiveType> preferred_element_type) {
  // Infers the result shape of an N-D convolution of `lhs` (input) with `rhs`
  // (kernel). Validates the group counts, checks that the dimension numbers
  // are permutations of [0, rank), checks feature/batch divisibility by the
  // group counts, and checks that `window` matches the kernel's spatial
  // dimensions. The result element type is the higher-precision of lhs/rhs,
  // possibly upcast to `preferred_element_type` via MaybeUpcast.
  TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
  TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));

  if (feature_group_count <= 0) {
    return InvalidArgument(
        "feature_group_count must be a positive number, got %d",
        feature_group_count);
  }

  if (batch_group_count <= 0) {
    return InvalidArgument(
        "batch_group_count must be a positive number, got %d",
        batch_group_count);
  }

  // Grouped and batch-grouped convolutions are mutually exclusive.
  if (batch_group_count > 1 && feature_group_count > 1) {
    return InvalidArgument(
        "both batch_group_count %d and feature_group_count %d cannot be "
        "greater than 1",
        batch_group_count, feature_group_count);
  }

  if (dnums.input_spatial_dimensions_size() !=
      dnums.kernel_spatial_dimensions_size()) {
    return InvalidArgument(
        "Both arguments to convolution must have same number of dimensions.\n"
        "Numbers: %s",
        dnums.DebugString());
  }

  if (dnums.input_spatial_dimensions_size() !=
      dnums.output_spatial_dimensions_size()) {
    return InvalidArgument(
        "Both input and output of convolution must have same number of "
        "dimensions.\nNumbers: %s",
        dnums.DebugString());
  }

  const int num_spatial_dims = dnums.input_spatial_dimensions_size();
  if (window.dimensions_size() != num_spatial_dims) {
    return InvalidArgument(
        "Window must have same number of dimensions as dimension numbers.\n"
        "Window: %s\nDimension numbers: %s.",
        window.DebugString(), dnums.DebugString());
  }

  // Every convolution operand has the spatial dims plus a batch and a
  // feature dimension.
  const int num_dims = num_spatial_dims + 2;
  if (lhs.rank() != num_dims) {
    return InvalidArgument(
        "The LHS argument to a convolution should have rank %d; lhs: %s.",
        num_dims, ShapeUtil::HumanString(lhs));
  }
  if (rhs.rank() != num_dims) {
    return InvalidArgument(
        "The RHS argument to a convolution should have rank %d; rhs: %s.",
        num_dims, ShapeUtil::HumanString(rhs));
  }
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));

  // Verifies that the input and window dimensions are a permutation of
  // the dimension numbers.
  std::vector<int64_t> input_dnums(num_dims);
  input_dnums[0] = dnums.input_batch_dimension();
  input_dnums[1] = dnums.input_feature_dimension();
  std::copy(dnums.input_spatial_dimensions().begin(),
            dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
  absl::c_sort(input_dnums);

  std::vector<int64_t> window_dnums(num_dims);
  window_dnums[0] = dnums.kernel_input_feature_dimension();
  window_dnums[1] = dnums.kernel_output_feature_dimension();
  std::copy(dnums.kernel_spatial_dimensions().begin(),
            dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
  absl::c_sort(window_dnums);

  std::vector<int64_t> output_dnums(num_dims);
  output_dnums[0] = dnums.output_batch_dimension();
  output_dnums[1] = dnums.output_feature_dimension();
  std::copy(dnums.output_spatial_dimensions().begin(),
            dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
  absl::c_sort(output_dnums);

  // After sorting, a valid permutation must equal [0, 1, ..., num_dims-1].
  std::vector<int64_t> expected_dnums(num_dims);
  std::iota(expected_dnums.begin(), expected_dnums.end(), 0);

  const auto in_range = [num_dims](int64_t i) {
    return 0 <= i && i < num_dims;
  };
  if (!absl::c_all_of(input_dnums, in_range) ||
      !absl::c_all_of(window_dnums, in_range) ||
      !absl::c_all_of(output_dnums, in_range)) {
    return InvalidArgument(
        "A dimension number is out of range in convolution: %s.",
        dnums.DebugString());
  }

  if (input_dnums != expected_dnums) {
    return InvalidArgument(
        "Input dimensions of convolution must contain each dimension exactly "
        "once: %s.",
        dnums.DebugString());
  }
  if (window_dnums != expected_dnums) {
    return InvalidArgument(
        "Window dimensions of convolution must contain each dimension exactly "
        "once: %s.",
        dnums.DebugString());
  }
  if (output_dnums != expected_dnums) {
    return InvalidArgument(
        "Output dimensions of convolution must contain each dimension exactly "
        "once: %s.",
        dnums.DebugString());
  }

  // Gather the concrete dimension sizes named by the dimension numbers.
  std::vector<int64_t> input_spatial_dims(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    input_spatial_dims[i] = lhs.dimensions(dnums.input_spatial_dimensions(i));
  }
  const int64_t input_features =
      lhs.dimensions(dnums.input_feature_dimension());
  const int64_t input_batch = lhs.dimensions(dnums.input_batch_dimension());

  std::vector<int64_t> kernel_spatial_dims(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    kernel_spatial_dims[i] = rhs.dimensions(dnums.kernel_spatial_dimensions(i));
  }
  const int64_t kernel_input_features =
      rhs.dimensions(dnums.kernel_input_feature_dimension());
  const int64_t kernel_output_features =
      rhs.dimensions(dnums.kernel_output_feature_dimension());

  if (kernel_output_features % batch_group_count != 0) {
    return InvalidArgument(
        "Expected output feature dimension size (value %d) to be a multiple of "
        "batch group count %d; got <conv>(%s, %s)\n"
        "Dimension numbers: {%s}.",
        kernel_output_features, batch_group_count, ShapeUtil::HumanString(lhs),
        ShapeUtil::HumanString(rhs), dnums.DebugString());
  }

  // In a grouped convolution the kernel only sees its group's slice of the
  // input features: lhs features / feature_group_count == rhs features.
  if (input_features % feature_group_count != 0 ||
      input_features / feature_group_count != kernel_input_features) {
    return InvalidArgument(
        "Expected LHS feature dimension (value %d) to be a multiple of "
        "feature_group_count (value %d), and LHS feature dimension / "
        "feature_group_count = RHS feature dimension (value %d); "
        "got <conv>(%s, %s)\n"
        "Dimension numbers: {%s}.",
        input_features, feature_group_count, kernel_input_features,
        ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
        dnums.DebugString());
  }

  if (kernel_output_features % feature_group_count > 0) {
    // A depthwise/grouped filter has the shape
    // [space0, .. spaceN, GROUP_SIZE, NUM_OUTPUT_FEATURES]. When
    // [space0, .. spaceN, GROUP_SIZE] is convolved with the input, a shape
    // [space0, .. spaceN, feature_group_count] is formed. Therefore, the output
    // feature count (which is equal to kernel output features) has to be a
    // multiple of feature_group_count.
    return InvalidArgument(
        "Expected output feature dimension (value %d) to be divisible by "
        "feature_group_count (value %d); "
        "got <conv>(%s, %s)\n"
        "Dimension numbers: {%s}.",
        kernel_output_features, feature_group_count,
        ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
        dnums.DebugString());
  }

  if (input_batch % batch_group_count != 0) {
    return InvalidArgument(
        "Expected input batch dimension (value %d) to be divisible by "
        "batch_group_count (value %d); "
        "got <conv>(%s, %s)\n"
        "Dimension numbers: {%s}.",
        input_batch, batch_group_count, ShapeUtil::HumanString(lhs),
        ShapeUtil::HumanString(rhs), dnums.DebugString());
  }

  // The window's per-dimension sizes must equal the kernel's spatial sizes.
  std::vector<int64_t> window_dims(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    window_dims[i] = window.dimensions(i).size();
  }
  if (kernel_spatial_dims != window_dims) {
    return InvalidArgument(
        "Window dimensions do not match RHS shape:\n\t"
        "RHS shape: %s\n\t"
        "Window: {%s}\n\t"
        "Dimension numbers: {%s}.",
        ShapeUtil::HumanString(rhs), window.ShortDebugString(),
        dnums.ShortDebugString());
  }

  // Compute the output spatial sizes by sliding the window over a shape made
  // of just the input spatial dimensions.
  Shape base_shape =
      ShapeUtil::MakeShape(lhs.element_type(), input_spatial_dims);
  TF_ASSIGN_OR_RETURN(
      Shape window_output_shape,
      InferWindowOutputShape(base_shape, window, lhs.element_type()));

  std::vector<int64_t> dimensions(num_dims);
  dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count;
  dimensions[dnums.output_feature_dimension()] = kernel_output_features;

  for (int i = 0; i < num_spatial_dims; ++i) {
    dimensions[dnums.output_spatial_dimensions(i)] =
        window_output_shape.dimensions(i);
  }
  // Propagate dynamic dimensions from the operands to the output.
  std::vector<bool> is_dynamic(num_dims);
  for (int i = 0; i < num_dims; i++) {
    if (lhs.is_dynamic_dimension(i)) {
      if (i == dnums.input_batch_dimension()) {
        is_dynamic[dnums.output_batch_dimension()] = true;
      } else if (i == dnums.input_feature_dimension()) {
        // Input feature dimension is a contracting dimension, which does not
        // affect the output dimension size. So we need to do nothing.
      } else {
        for (int64_t j = 0; j < dnums.output_spatial_dimensions_size(); ++j) {
          if (i == dnums.input_spatial_dimensions(j)) {
            // i is a spatial dimension, find corresponding output spatial
            // dimension.
            is_dynamic[dnums.output_spatial_dimensions(j)] = true;
          }
        }
      }
    }
    if (rhs.is_dynamic_dimension(i)) {
      if (i == dnums.kernel_input_feature_dimension()) {
        // Kernel feature dimension does not affect the output dimension size.
        // So we need to do nothing.
      } else if (i == dnums.kernel_output_feature_dimension()) {
        return InvalidArgument(
            "Dynamic output feature dim on convolution kernel is not "
            "supported: rhs shape is %s ",
            rhs.ToString());
      } else {
        for (int64_t j = 0; j < dnums.kernel_spatial_dimensions_size(); ++j) {
          if (i == dnums.kernel_spatial_dimensions(j)) {
            // i is a spatial dimension, find corresponding output spatial
            // dimension.
            is_dynamic[dnums.output_spatial_dimensions(j)] = true;
          }
        }
      }
    }
  }
  TF_ASSIGN_OR_RETURN(
      PrimitiveType type,
      MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
                  preferred_element_type));
  return ShapeUtil::MakeShape(type, dimensions, is_dynamic);
}
1873
/* static */ StatusOr<Shape> ShapeInference::InferFftShape(
    const Shape& in, const FftType fft_type,
    const absl::Span<const int64_t> fft_length) {
  // Infers the result shape of an FFT applied over the innermost
  // `fft_length.size()` dimensions of `in`.
  //   FFT/IFFT : complex -> complex, shape preserved.
  //   RFFT     : real (F32/F64) -> complex, innermost dim n -> n/2+1.
  //   IRFFT    : complex -> real, innermost dim n/2+1 -> n.
  const int64_t fft_rank = fft_length.size();
  if (fft_rank < 1 || fft_rank > 3) {
    return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank);
  }
// The input may carry extra (batch) dimensions, but must have at least
// `fft_rank` dimensions for the transform itself.
#define RET_CHECK_RANK(x)                            \
  if (x.dimensions_size() < fft_rank) {              \
    return InvalidArgument(                          \
        "FFT of rank %d requires input of at least " \
        "same rank; got input of rank %d",           \
        fft_rank, x.dimensions_size());              \
  }
  switch (fft_type) {
    case FFT:
    case IFFT:
      if (!primitive_util::IsComplexType(in.element_type())) {
        return InvalidArgument("%s requires complex input type, found %s.",
                               FftType_Name(fft_type),
                               PrimitiveType_Name(in.element_type()));
      }
      RET_CHECK_RANK(in);
      // Complex-to-complex transforms preserve the input shape.
      return in;
    case RFFT: {
      if (in.element_type() != F32 && in.element_type() != F64) {
        return InvalidArgument("RFFT requires F32 or F64 input type, found %s.",
                               PrimitiveType_Name(in.element_type()));
      }
      RET_CHECK_RANK(in);
      // All transformed dimensions must match fft_length exactly.
      for (int i = 0; i < fft_rank; i++) {
        if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
            fft_length[i]) {
          return InvalidArgument(
              "RFFT requires innermost dimensions match fft_length but "
              "dimension %d is %d and should be %d.",
              in.dimensions_size() - fft_rank + i,
              in.dimensions(in.dimensions_size() - fft_rank + i),
              fft_length[i]);
        }
      }
      // F32 -> C64, F64 -> C128; innermost dimension shrinks to n/2+1
      // because the real transform's spectrum is conjugate-symmetric.
      Shape result = ShapeUtil::ChangeElementType(
          in, in.element_type() == F32 ? C64 : C128);
      // Preserve the size of zero-sized dimensions.
      if (fft_length[fft_rank - 1] != 0) {
        result.set_dimensions(result.dimensions_size() - 1,
                              fft_length[fft_rank - 1] / 2 + 1);
      }
      return result;
    }
    case IRFFT: {
      if (!primitive_util::IsComplexType(in.element_type())) {
        return InvalidArgument("IRFFT requires complex input type, found %s.",
                               PrimitiveType_Name(in.element_type()));
      }
      RET_CHECK_RANK(in);
      // Result element type is the real component type (C64 -> F32, etc.).
      Shape result = ShapeUtil::ComplexComponentShape(in);
      // All but the innermost transformed dimension must match fft_length.
      for (int i = 0; i < fft_rank - 1; i++) {
        if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
            fft_length[i]) {
          return InvalidArgument(
              "IRFFT requires all but one innermost dimensions match "
              "fft_length, but dimension %d is %d and should be %d.",
              in.dimensions_size() - fft_rank + i,
              in.dimensions(in.dimensions_size() - fft_rank + i),
              fft_length[i]);
        }
      }
      // The innermost input dimension must be fft_length/2+1 (the RFFT
      // output size), except when both are zero-sized.
      // The size of zero-sized dimensions is preserved.
      if ((in.dimensions(in.dimensions_size() - 1) != 0 ||
           fft_length[fft_rank - 1] != 0) &&
          in.dimensions(in.dimensions_size() - 1) !=
              fft_length[fft_rank - 1] / 2 + 1) {
        return InvalidArgument(
            "IRFFT requires innermost dimension matches fft_length/2+1, but "
            "dimension %d is %d and should be %d.",
            in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1),
            fft_length[fft_rank - 1] / 2 + 1);
      }
      // Expand the innermost dimension back to the full (real) length.
      result.set_dimensions(result.dimensions_size() - 1,
                            fft_length[fft_rank - 1]);
      return result;
    }
    default:
      LOG(FATAL) << "Unexpected fft_type: " << fft_type;
  }
#undef RET_CHECK_RANK
}
1962
InferTriangularSolveShape(const Shape & a,const Shape & b,const TriangularSolveOptions & options)1963 /* static */ StatusOr<Shape> ShapeInference::InferTriangularSolveShape(
1964 const Shape& a, const Shape& b, const TriangularSolveOptions& options) {
1965 if ((!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) ||
1966 a.element_type() != b.element_type()) {
1967 return InvalidArgument(
1968 "Expected element types in shape to be floating or complex and "
1969 "identical for TriangularSolve; got %s and %s.",
1970 PrimitiveType_Name(a.element_type()),
1971 PrimitiveType_Name(b.element_type()));
1972 }
1973 if (a.rank() < 2) {
1974 return InvalidArgument(
1975 "The 'a' argument to TriangularSolve must have rank >= 2, got shape %s",
1976 a.ToString());
1977 }
1978 if (b.rank() != a.rank()) {
1979 return InvalidArgument(
1980 "Arguments to triangular solve must have equal rank; got %s and %s.",
1981 b.ToString(), a.ToString());
1982 }
1983 if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
1984 return InvalidArgument(
1985 "The two minor dimensions of 'a' must have equal size, got %s.",
1986 a.ToString());
1987 }
1988 if (a.dimensions(a.rank() - 1) !=
1989 b.dimensions(b.rank() - (options.left_side() ? 2 : 1))) {
1990 return InvalidArgument(
1991 "The shared dimension of 'a' and 'b' does not match, got shapes %s and "
1992 "%s",
1993 a.ToString(), b.ToString());
1994 }
1995 absl::Span<const int64_t> a_batch_dims(a.dimensions());
1996 absl::Span<const int64_t> b_batch_dims(b.dimensions());
1997 a_batch_dims.remove_suffix(2);
1998 b_batch_dims.remove_suffix(2);
1999 if (a_batch_dims != b_batch_dims) {
2000 return InvalidArgument(
2001 "The leading batch dimensions of the arguments to triangular solve "
2002 "must be equal; got %s and %s.",
2003 b.ToString(), a.ToString());
2004 }
2005 if (!TriangularSolveOptions_Transpose_IsValid(options.transpose_a()) ||
2006 options.transpose_a() == TriangularSolveOptions::TRANSPOSE_INVALID) {
2007 return InvalidArgument(
2008 "Invalid transpose option value for triangular solve (%d).\n",
2009 options.transpose_a());
2010 }
2011 return b;
2012 }
2013
InferCholeskyShape(const Shape & a)2014 /* static */ StatusOr<Shape> ShapeInference::InferCholeskyShape(
2015 const Shape& a) {
2016 if (!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) {
2017 return InvalidArgument(
2018 "Expected element type in shape to be floating or complex for "
2019 "Cholesky; got %s.",
2020 PrimitiveType_Name(a.element_type()));
2021 }
2022 if (a.rank() < 2) {
2023 return InvalidArgument(
2024 "The 'a' argument to Cholesky must have rank >= 2, got shape %s",
2025 a.ToString());
2026 }
2027 if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
2028 return InvalidArgument(
2029 "The two minor dimensions of 'a' must have equal size, got %s.",
2030 a.ToString());
2031 }
2032 return a;
2033 }
2034
InferAllGatherShape(absl::Span<const Shape * const> operand_shapes,int64_t all_gather_dimension,int64_t shard_count)2035 /* static */ StatusOr<Shape> ShapeInference::InferAllGatherShape(
2036 absl::Span<const Shape* const> operand_shapes, int64_t all_gather_dimension,
2037 int64_t shard_count) {
2038 TF_RET_CHECK(all_gather_dimension >= 0);
2039 TF_RET_CHECK(shard_count > 0);
2040
2041 std::vector<Shape> output_shapes;
2042 output_shapes.reserve(operand_shapes.size());
2043 for (const Shape* operand_shape : operand_shapes) {
2044 TF_RET_CHECK(all_gather_dimension < operand_shape->rank());
2045 TF_RETURN_IF_ERROR(ExpectArray(*operand_shape, "operand of all-gather"));
2046
2047 Shape output_shape = *operand_shape;
2048 output_shape.set_dimensions(
2049 all_gather_dimension,
2050 shard_count * output_shape.dimensions(all_gather_dimension));
2051 output_shapes.push_back(output_shape);
2052 }
2053 if (output_shapes.size() == 1) {
2054 return output_shapes[0];
2055 }
2056 return ShapeUtil::MakeTupleShape(output_shapes);
2057 }
2058
InferAllGatherStartShape(absl::Span<const Shape * const> operand_shapes,int64_t all_gather_dimension,int64_t shard_count)2059 /* static */ StatusOr<Shape> ShapeInference::InferAllGatherStartShape(
2060 absl::Span<const Shape* const> operand_shapes, int64_t all_gather_dimension,
2061 int64_t shard_count) {
2062 TF_ASSIGN_OR_RETURN(
2063 Shape ag_shape,
2064 InferAllGatherShape(operand_shapes, all_gather_dimension, shard_count));
2065 Shape input_shape;
2066 if (operand_shapes.size() == 1) {
2067 input_shape = *operand_shapes[0];
2068 } else {
2069 input_shape = ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes);
2070 }
2071 return ShapeUtil::MakeTupleShapeWithPtrs({&input_shape, &ag_shape});
2072 }
2073
/* static */ StatusOr<Shape> ShapeInference::InferAllGatherDoneShape(
    const Shape& all_gather_start_shape) {
  // all-gather-start produces an (input, output) tuple; -done unwraps and
  // returns the output element.
  return ShapeUtil::GetTupleElementShape(all_gather_start_shape, 1);
}
2078
InferAllReduceShape(absl::Span<const Shape * const> operand_shapes)2079 /* static */ StatusOr<Shape> ShapeInference::InferAllReduceShape(
2080 absl::Span<const Shape* const> operand_shapes) {
2081 for (const Shape* operand_shape : operand_shapes) {
2082 TF_RETURN_IF_ERROR(
2083 ExpectArray(*operand_shape, "operand of cross replica sum"));
2084 }
2085 if (operand_shapes.size() == 1) {
2086 return *operand_shapes[0];
2087 }
2088 return ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes);
2089 }
2090
InferReduceScatterShape(absl::Span<const Shape * const> operand_shapes,int64_t scatter_dimension,int64_t shard_count)2091 /* static */ StatusOr<Shape> ShapeInference::InferReduceScatterShape(
2092 absl::Span<const Shape* const> operand_shapes, int64_t scatter_dimension,
2093 int64_t shard_count) {
2094 TF_RET_CHECK(scatter_dimension >= 0);
2095 TF_RET_CHECK(shard_count > 0);
2096
2097 std::vector<Shape> output_shapes;
2098 output_shapes.reserve(operand_shapes.size());
2099 for (const Shape* operand_shape : operand_shapes) {
2100 TF_RET_CHECK(scatter_dimension < operand_shape->rank());
2101 TF_RETURN_IF_ERROR(
2102 ExpectArray(*operand_shape, "operand of reduce-scatter"));
2103
2104 int64_t scatter_dim_input_size =
2105 operand_shape->dimensions(scatter_dimension);
2106 if (scatter_dim_input_size % shard_count != 0) {
2107 return InvalidArgument(
2108 "ReduceScatter operand scatter dimension size %d must be "
2109 "dividable by shard_count "
2110 "%d.",
2111 scatter_dim_input_size, shard_count);
2112 }
2113
2114 Shape output_shape = *operand_shape;
2115 output_shape.set_dimensions(scatter_dimension,
2116 scatter_dim_input_size / shard_count);
2117 output_shapes.push_back(output_shape);
2118 }
2119
2120 if (output_shapes.size() == 1) {
2121 return output_shapes[0];
2122 }
2123 return ShapeUtil::MakeTupleShape(output_shapes);
2124 }
2125
/* static */ StatusOr<Shape> ShapeInference::InferAllReduceStartShape(
    absl::Span<const Shape* const> operand_shapes) {
  // all-reduce-start has the same result shape as a synchronous all-reduce.
  return InferAllReduceShape(operand_shapes);
}
2130
/* static */ StatusOr<Shape> ShapeInference::InferAllReduceDoneShape(
    const Shape& operand_shape) {
  // The returned value from AllReduceDone is the operand forwarded
  // unchanged, so the result shape equals the operand shape.
  return operand_shape;
}
2136
InferAllToAllShape(const Shape & shape,int64_t split_dimension,int64_t concat_dimension,int64_t split_count)2137 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllShape(
2138 const Shape& shape, int64_t split_dimension, int64_t concat_dimension,
2139 int64_t split_count) {
2140 TF_RET_CHECK(split_count > 0);
2141 if (split_dimension >= shape.rank() || split_dimension < 0) {
2142 return InvalidArgument(
2143 "AllToAll split_dimension %d is out-of-bounds in shape %s.",
2144 split_dimension, ShapeUtil::HumanString(shape));
2145 }
2146 if (concat_dimension >= shape.rank() || concat_dimension < 0) {
2147 return InvalidArgument(
2148 "AllToAll concat_dimension %d is out-of-bounds in shape %s.",
2149 concat_dimension, ShapeUtil::HumanString(shape));
2150 }
2151 if (shape.dimensions(split_dimension) % split_count != 0) {
2152 return InvalidArgument(
2153 "AllToAll split dimension size %d must be dividable by split_count "
2154 "%d.",
2155 shape.dimensions(split_dimension), split_count);
2156 }
2157 std::vector<int64_t> new_dimensions(shape.dimensions().begin(),
2158 shape.dimensions().end());
2159 new_dimensions[split_dimension] /= split_count;
2160 new_dimensions[concat_dimension] *= split_count;
2161 return ShapeUtil::MakeShape(shape.element_type(), new_dimensions);
2162 }
2163
InferAllToAllTupleShape(absl::Span<const Shape * const> operand_shapes)2164 /* static */ StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
2165 absl::Span<const Shape* const> operand_shapes) {
2166 // An Alltoall HLO instruction receives N operands (with the same shape) and
2167 // returns a tuple that contains N array shapes.
2168 TF_RET_CHECK(!operand_shapes.empty());
2169 for (int i = 0; i < operand_shapes.size(); i++) {
2170 if (!Shape::Equal().IgnoreMemorySpaceInLayout()(*operand_shapes[0],
2171 *operand_shapes[i])) {
2172 return InvalidArgument(
2173 "HLO all-to-all has operands with different shapes: the 0th "
2174 "operand shape %s, but the %dth operand has shape %s.",
2175 ShapeUtil::HumanString(*operand_shapes[0]), i,
2176 ShapeUtil::HumanString(*operand_shapes[i]));
2177 }
2178 }
2179
2180 return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes);
2181 }
2182
InferCollectivePermuteShape(absl::Span<const Shape * const> operand_shapes)2183 /* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteShape(
2184 absl::Span<const Shape* const> operand_shapes) {
2185 if (operand_shapes.size() == 1) {
2186 TF_RETURN_IF_ERROR(
2187 ExpectArray(*(operand_shapes[0]), "operand of collective-permute"));
2188 return *(operand_shapes[0]);
2189 } else {
2190 TF_RET_CHECK(operand_shapes.size() == 4);
2191 return *(operand_shapes[1]);
2192 }
2193 }
2194
InferCollectivePermuteStartShape(absl::Span<const Shape * const> operand_shapes)2195 /* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteStartShape(
2196 absl::Span<const Shape* const> operand_shapes) {
2197 const Shape u32_scalar = ShapeUtil::MakeShape(U32, {});
2198 if (operand_shapes.size() == 1) {
2199 TF_RETURN_IF_ERROR(ExpectArray(*(operand_shapes[0]),
2200 "operand of collective-permute-start"));
2201 return ShapeUtil::MakeTupleShapeWithPtrs(
2202 {operand_shapes[0], operand_shapes[0], &u32_scalar, &u32_scalar});
2203 } else {
2204 TF_RET_CHECK(operand_shapes.size() == 4);
2205 return ShapeUtil::MakeTupleShapeWithPtrs(
2206 {operand_shapes[0], operand_shapes[1], &u32_scalar, &u32_scalar});
2207 }
2208 }
2209
/* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteDoneShape(
    const Shape& operand_shape) {
  // collective-permute-start yields a tuple whose element 1 is the output
  // buffer; -done unwraps and returns that element's shape.
  TF_RET_CHECK(operand_shape.IsTuple());
  return ShapeUtil::GetTupleElementShape(operand_shape, 1);
}
2215
/* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
    absl::Span<const Shape* const> arg_shapes,
    absl::Span<const int64_t> dimensions_to_reduce,
    const ProgramShape& to_apply) {
  // Infers the result shape of a (possibly variadic) reduce. `arg_shapes`
  // holds the N input tensors followed by their N init values; `to_apply`
  // is the reducer computation. The result drops `dimensions_to_reduce`
  // from the input shape; for N > 1 it is a tuple of N such arrays.
  if (arg_shapes.empty()) {
    return InvalidArgument("Reduce must have at least 2 arguments, has 0");
  }
  if (arg_shapes.size() % 2) {
    return InvalidArgument(
        "Reduce must have an even number of arguments, has %lu",
        arg_shapes.size());
  }
  int64_t num_reduced_args = arg_shapes.size() / 2;
  auto reduced_args = arg_shapes.subspan(0, num_reduced_args);
  // Check that all of the reduced tensors have the same dimensions. The element
  // types may be different.
  for (int64_t i = 1; i < num_reduced_args; ++i) {
    if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) {
      return InvalidArgument(
          "All reduced tensors must have the same dimension. Tensor 0 has "
          "shape %s, Tensor %d has shape %s",
          ShapeUtil::HumanString(*reduced_args[0]), i,
          ShapeUtil::HumanString(*reduced_args[i]));
    }
  }
  // Check that the dimensions to reduce are in-bounds for the given shape.
  // We've already verified all reduced tensors have the same dimensions, so it
  // doesn't matter which one we choose.
  const Shape& arg = *reduced_args[0];
  for (int64_t dimension : dimensions_to_reduce) {
    if (dimension >= arg.rank() || dimension < 0) {
      return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.",
                             dimension, ShapeUtil::HumanString(arg));
    }
  }

  // The second half of arg_shapes are the init values (subspan clamps the
  // length to the remaining elements).
  auto init_values = arg_shapes.subspan(num_reduced_args, arg_shapes.size());
  std::vector<PrimitiveType> element_types;
  element_types.reserve(reduced_args.size());
  for (const Shape* arg : reduced_args) {
    element_types.push_back(arg->element_type());
  }
  // Validate the reducer's signature against the inputs and init values.
  TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types,
                                        num_reduced_args));

  // Reject duplicate reduction dimensions.
  absl::flat_hash_set<int64_t> dimensions_to_reduce_set;
  for (int64_t dim_to_reduce : dimensions_to_reduce) {
    if (!dimensions_to_reduce_set.insert(dim_to_reduce).second) {
      return InvalidArgument("Duplicate reduction dimension: %d",
                             dim_to_reduce);
    }
  }

  // Keep the dimensions (and their dynamism) that are not reduced away.
  std::vector<int64_t> new_dimensions;
  std::vector<bool> new_is_dynamic;
  for (int i = 0; i < arg.rank(); ++i) {
    if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) {
      new_dimensions.push_back(arg.dimensions(i));
      new_is_dynamic.push_back(arg.is_dynamic_dimension(i));
    }
  }

  // A scalar reducer result means a single-array reduce; a tuple-shaped
  // reducer result means one output array per tuple element.
  if (ShapeUtil::IsScalar(to_apply.result())) {
    return ShapeUtil::MakeShape(to_apply.result().element_type(),
                                new_dimensions, new_is_dynamic);
  } else {
    std::vector<Shape> result_subshapes;
    const auto& tuple_shapes = to_apply.result().tuple_shapes();
    result_subshapes.reserve(tuple_shapes.size());
    for (const Shape& subshape : tuple_shapes) {
      result_subshapes.push_back(ShapeUtil::MakeShape(
          subshape.element_type(), new_dimensions, new_is_dynamic));
    }
    return ShapeUtil::MakeTupleShape(result_subshapes);
  }
}
2292
/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
    const Shape& operand_shape, const Shape& init_value_shape,
    const Window& window, const ProgramShape& to_apply_shape) {
  // Single-operand reduce-window: validate the reducer's signature first,
  // then delegate to the reducer-free overload for the shape computation.
  TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
                                        {operand_shape.element_type()},
                                        /*inputs=*/1));
  return InferReduceWindowShape(operand_shape, init_value_shape, window);
}
2301
/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
    absl::Span<const Shape* const> operands,
    absl::Span<const Shape* const> init_values, const Window& window,
    const ProgramShape& to_apply_shape) {
  // Variadic reduce-window: all operands share one window and one reducer.
  // Returns a single array shape when the reducer result is scalar, and a
  // tuple of per-operand shapes otherwise.
  auto number_of_input = operands.size();
  // Check that all of the reduced tensors have the same dimensions. The element
  // types may be different.
  for (int64_t i = 1; i < number_of_input; ++i) {
    if (!ShapeUtil::SameDimensions(*operands[0], *operands[i])) {
      return InvalidArgument(
          "All reduced tensors must have the same dimension. Tensor 0 has "
          "shape %s, Tensor %d has shape %s",
          ShapeUtil::HumanString(*operands[0]), i,
          ShapeUtil::HumanString(*operands[i]));
    }
  }
  // Validate the reducer against all operand element types and init values.
  std::vector<PrimitiveType> operand_element_type_vec;
  operand_element_type_vec.reserve(operands.size());
  for (const Shape* s : operands) {
    operand_element_type_vec.push_back(s->element_type());
  }
  TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_values,
                                        operand_element_type_vec,
                                        /*inputs=*/number_of_input));
  // Compute the windowed output shape for each operand independently.
  std::vector<Shape> output_shape_vec;
  const size_t n = operands.size();
  output_shape_vec.reserve(n);
  for (size_t i = 0; i < operands.size(); ++i) {
    TF_ASSIGN_OR_RETURN(
        auto cur_output_shape,
        InferReduceWindowShape(*operands[i], *init_values[i], window));
    output_shape_vec.push_back(cur_output_shape);
  }
  if (ShapeUtil::IsScalar(to_apply_shape.result())) {
    // A scalar reducer result implies exactly one operand/output.
    CHECK_EQ(output_shape_vec.size(), 1);
    return output_shape_vec[0];
  } else {
    return ShapeUtil::MakeTupleShape(output_shape_vec);
  }
}
2342
InferReduceWindowShape(const Shape & operand_shape,const Shape & init_value_shape,const Window & window)2343 /* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
2344 const Shape& operand_shape, const Shape& init_value_shape,
2345 const Window& window) {
2346 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window"));
2347 return InferWindowOutputShape(operand_shape, window,
2348 init_value_shape.element_type());
2349 }
2350
// Infers the result shape of select-and-scatter. A window is scanned over
// `operand_shape`; the binary `select` computation picks one element per
// window position, and the corresponding element of `source_shape` is
// combined into the result at that position via the `scatter` reduction
// seeded with `init_value_shape`. The result always has the operand's shape.
/* static */ StatusOr<Shape> ShapeInference::InferSelectAndScatterShape(
    const Shape& operand_shape, const ProgramShape& select_shape,
    const Window& window, const Shape& source_shape,
    const Shape& init_value_shape, const ProgramShape& scatter_shape) {
  TF_RETURN_IF_ERROR(
      ExpectArray(operand_shape, "operand of select-and-scatter"));

  // Check if the select function has a proper shape of (T,T) -> PRED.
  if (select_shape.parameters_size() != 2) {
    return InvalidArgument(
        "Select function must take 2 parameters, but "
        "takes %d parameter(s).",
        select_shape.parameters_size());
  }
  const Shape& select_result_shape = select_shape.result();
  if (!ShapeUtil::Compatible(select_result_shape,
                             ShapeUtil::MakeShape(PRED, {}))) {
    return InvalidArgument("Select function must have rank-0 PRED result.");
  }
  // Both select parameters must be scalars of the operand's element type
  // (floating-point precision differences are tolerated).
  const Shape& operand_element_shape =
      ShapeUtil::MakeShape(operand_shape.element_type(), {});
  if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
                                                select_shape.parameters(0))) {
    return InvalidArgument(
        "Select function's first parameter shape currently must "
        "match the operand element shape, but got %s vs %s.",
        ShapeUtil::HumanString(select_shape.parameters(0)),
        ShapeUtil::HumanString(operand_element_shape));
  }
  if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
                                                select_shape.parameters(1))) {
    return InvalidArgument(
        "Select function's second parameter shape currently must "
        "match the operand element shape, but got %s vs %s.",
        ShapeUtil::HumanString(select_shape.parameters(1)),
        ShapeUtil::HumanString(operand_element_shape));
  }

  // Check if the scatter function has a proper shape as a reduction.
  TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape},
                                        {source_shape.element_type()},
                                        /*inputs=*/1));

  // Check if the result shape of window operation matches the source shape.
  // Each source element corresponds to exactly one window position.
  TF_ASSIGN_OR_RETURN(const Shape& window_result_shape,
                      InferWindowOutputShape(operand_shape, window,
                                             operand_shape.element_type()));
  if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape,
                                                window_result_shape)) {
    return InvalidArgument(
        "Source shape does not match the shape of window-reduced operand: "
        "source(%s), window-reduced operand(%s).",
        ShapeUtil::HumanString(source_shape),
        ShapeUtil::HumanString(window_result_shape));
  }

  // Scattering writes back into an operand-shaped result.
  return operand_shape;
}
2409
InferGetDimensionSizeShape(const Shape & shape,int64_t dimension)2410 /* static */ StatusOr<Shape> ShapeInference::InferGetDimensionSizeShape(
2411 const Shape& shape, int64_t dimension) {
2412 if (dimension < 0 || dimension >= shape.rank()) {
2413 return InvalidArgument("GetDimensionSize dimension out of bounds: %d.",
2414 dimension);
2415 }
2416
2417 // TODO(b/119580730): Remove this restriction when very large dimension size
2418 // is needed.
2419 if (shape.dimensions(dimension) > std::numeric_limits<int32_t>::max()) {
2420 return InvalidArgument(
2421 "GetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2422 "INT_MAX limit.",
2423 ShapeUtil::HumanString(shape), dimension);
2424 }
2425
2426 return ShapeUtil::MakeShape(S32, {});
2427 }
2428
InferSetDimensionSizeShape(const Shape & shape,const Shape & val_shape,int64_t dimension)2429 /* static */ StatusOr<Shape> ShapeInference::InferSetDimensionSizeShape(
2430 const Shape& shape, const Shape& val_shape, int64_t dimension) {
2431 if (dimension < 0 || dimension >= shape.rank()) {
2432 return InvalidArgument("SetDimensionSize dimension out of bounds: %d.",
2433 dimension);
2434 }
2435
2436 if (val_shape.rank() != 0 || val_shape.element_type() != S32) {
2437 return InvalidArgument(
2438 "SetDimensionSize's value has to be S32 scalar, got %s",
2439 val_shape.ToString());
2440 }
2441 // TODO(b/119580730): Remove this restriction when very large dimension size
2442 // is needed.
2443 if (shape.dimensions(dimension) > std::numeric_limits<int32_t>::max()) {
2444 return InvalidArgument(
2445 "SetDimensionSize's input shape is %s, the %dth dimension exceeds the "
2446 "INT_MAX limit.",
2447 ShapeUtil::HumanString(shape), dimension);
2448 }
2449
2450 Shape result = shape;
2451 result.set_dynamic_dimension(dimension, true);
2452 return result;
2453 }
2454
InferWindowFromDimensions(absl::Span<const int64_t> window_dimensions,absl::Span<const int64_t> window_strides,absl::Span<const std::pair<int64_t,int64_t>> padding,absl::Span<const int64_t> lhs_dilation,absl::Span<const int64_t> rhs_dilation,std::optional<std::vector<bool>> window_reversal)2455 /* static */ StatusOr<Window> ShapeInference::InferWindowFromDimensions(
2456 absl::Span<const int64_t> window_dimensions,
2457 absl::Span<const int64_t> window_strides,
2458 absl::Span<const std::pair<int64_t, int64_t>> padding,
2459 absl::Span<const int64_t> lhs_dilation,
2460 absl::Span<const int64_t> rhs_dilation,
2461 std::optional<std::vector<bool>> window_reversal) {
2462 const auto verify_size = [&](const size_t x, const char* x_name) {
2463 if (x == 0 || x == window_dimensions.size()) {
2464 return OkStatus();
2465 } else {
2466 return InvalidArgument(
2467 "%s", absl::StrCat(
2468 "Window has different number of window dimensions than of ",
2469 x_name,
2470 "\nNumber of window dimensions: ", window_dimensions.size(),
2471 "\nNumber of ", x_name, ": ", x, "\n"));
2472 }
2473 };
2474 TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides"));
2475 TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries"));
2476 TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors"));
2477 TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors"));
2478 if (window_reversal.has_value()) {
2479 TF_RETURN_IF_ERROR(verify_size(window_reversal->size(), "window reversal"));
2480 }
2481
2482 Window window;
2483 for (size_t i = 0; i < window_dimensions.size(); i++) {
2484 auto dim = window.add_dimensions();
2485 dim->set_size(window_dimensions[i]);
2486 if (!window_strides.empty()) {
2487 dim->set_stride(window_strides[i]);
2488 } else {
2489 dim->set_stride(1);
2490 }
2491 if (!padding.empty()) {
2492 dim->set_padding_low(padding[i].first);
2493 dim->set_padding_high(padding[i].second);
2494 } else {
2495 dim->set_padding_low(0);
2496 dim->set_padding_high(0);
2497 }
2498 if (!lhs_dilation.empty()) {
2499 dim->set_base_dilation(lhs_dilation[i]);
2500 } else {
2501 dim->set_base_dilation(1);
2502 }
2503 if (!rhs_dilation.empty()) {
2504 dim->set_window_dilation(rhs_dilation[i]);
2505 } else {
2506 dim->set_window_dilation(1);
2507 }
2508 if (window_reversal.has_value() && !window_reversal->empty()) {
2509 dim->set_window_reversal(window_reversal->at(i));
2510 } else {
2511 dim->set_window_reversal(false);
2512 }
2513 }
2514 return window;
2515 }
2516
// Infers the result shape of a strided slice. For each dimension the output
// size is ceil((limit - start) / stride); starts/limits/strides must each
// have one entry per operand dimension. Dynamic dimensions are preserved
// unless the slice takes exactly one element from them.
/* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
    const Shape& arg, absl::Span<const int64_t> starts,
    absl::Span<const int64_t> limits, absl::Span<const int64_t> strides) {
  // Shared formatter so every error carries the full slice configuration.
  auto error = [&](const std::string& message) {
    return InvalidArgument(
        "%s in slice operation; argument shape: %s; starts: {%s}; limits: "
        "{%s}; strides: {%s}.",
        message, ShapeUtil::HumanString(arg), StrJoin(starts, ","),
        StrJoin(limits, ","), StrJoin(strides, ","));
  };
  TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice"));
  VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}",
                       ShapeUtil::HumanString(arg), StrJoin(starts, ", "),
                       StrJoin(limits, ", "));

  if (starts.size() != limits.size()) {
    return error(StrFormat("slice start and limit sizes differ: %u vs %u",
                           starts.size(), limits.size()));
  }

  if (starts.size() != strides.size()) {
    return error(StrFormat("slice start and strides sizes differ: %u vs %u",
                           starts.size(), strides.size()));
  }

  if (starts.size() != arg.rank()) {
    return InvalidArgument(
        "Slice index count does not match argument rank: %u vs %d.",
        starts.size(), arg.rank());
  }

  std::vector<int64_t> sizes;
  const auto starts_size = starts.size();
  sizes.reserve(starts_size);
  for (int64_t dimension = 0; dimension < starts_size; ++dimension) {
    int64_t start_index = starts[dimension];
    int64_t limit_index = limits[dimension];
    int64_t stride = strides[dimension];
    if (start_index < 0) {
      return InvalidArgument("Negative start index to slice: %d.", start_index);
    }
    if (limit_index > arg.dimensions(dimension)) {
      return error(
          StrFormat("limit index (%d) must be less than or equal to dimension "
                    "size (%d)",
                    limit_index, arg.dimensions(dimension)));
    }
    VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index);
    VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index);
    if (start_index > limit_index) {
      return error(
          StrFormat("limit index (%d) must be greater or equal to "
                    "start index (%d) in slice with positive stride",
                    limit_index, start_index));
    }
    if (stride <= 0) {
      return InvalidArgument("Stride (%d) must be positive.", stride);
    }
    // Ceiling division: number of elements selected in this dimension.
    sizes.push_back((limit_index - start_index + stride - 1) / stride);
  }

  std::vector<bool> is_dynamic(arg.rank());
  for (int64_t i = 0; i < arg.dimensions_size(); ++i) {
    // Slicing 1 out of a dynamic dimension eliminates the dynamic dimension.
    if (sizes[i] == 1) {
      continue;
    }
    is_dynamic[i] = arg.is_dynamic_dimension(i);
  }

  return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic);
}
2589
// Infers the result shape of dynamic-slice: a slice of static size
// `slice_sizes` taken at runtime-provided start indices. Two index styles
// are supported:
//   - a single rank-1 integral tensor whose length equals the operand rank
//     (legacy path, see the TODO below), or
//   - one scalar integral index per operand dimension (when
//     `allow_scalar_indices` is set).
// The result shape is `slice_sizes` with the operand's element type.
/* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
    const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
    absl::Span<const int64_t> slice_sizes, bool allow_scalar_indices) {
  TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
  auto number_of_indices = start_index_shapes.size();
  // TODO(b/118437727): Remove this path.
  if (!allow_scalar_indices ||
      (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
    // Legacy path: all start indices are packed into one rank-1 tensor.
    if (number_of_indices != 1) {
      return InvalidArgument(
          "Dynamic slice should have exactly 1 index operand, has %d.",
          number_of_indices);
    }

    const Shape& start_indices_shape = start_index_shapes[0];
    VLOG(2) << StrFormat(
        "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
        ShapeUtil::HumanString(operand_shape),
        ShapeUtil::HumanString(start_indices_shape),
        StrJoin(slice_sizes, ", "));

    TF_RETURN_IF_ERROR(
        ExpectArray(start_indices_shape, "start indices of dynamic slice"));

    if (start_indices_shape.rank() != 1) {
      return InvalidArgument(
          "Dynamic slice start indices of rank %d must be rank1.",
          start_indices_shape.rank());
    }

    if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
      return InvalidArgument(
          "Dynamic slice start indices must be of integral type.");
    }

    // The index vector must supply one start index per operand dimension.
    const int64_t start_num_dims = start_indices_shape.dimensions(0);
    if (operand_shape.rank() != start_num_dims) {
      return InvalidArgument(
          "Dynamic slice start number of dimensions %d (%s) must match rank "
          "%d of slice input (%s).",
          start_num_dims, ShapeUtil::HumanString(start_indices_shape),
          operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
    }
  } else {
    // Scalar-indices path: one scalar integral index per operand dimension,
    // all of the same shape.
    VLOG(2) << StrFormat("slicing shape %s a with slice_sizes={%s}",
                         ShapeUtil::HumanString(operand_shape),
                         StrJoin(slice_sizes, ", "));

    if (operand_shape.rank() != number_of_indices) {
      return InvalidArgument(
          "Dynamic slice start number of dimensions %d must match rank "
          "%d of slice input (%s).",
          number_of_indices, operand_shape.rank(),
          ShapeUtil::HumanString(operand_shape));
    }

    if (number_of_indices > 0) {
      const Shape& first_index_shape = start_index_shapes[0];
      if (!ShapeUtil::IsScalar(first_index_shape)) {
        return InvalidArgument("Dynamic slice indices must be scalar, not %s.",
                               ShapeUtil::HumanString(first_index_shape));
      }
      if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
        return InvalidArgument(
            "Dynamic slice start indices must be of integral type.");
      }
      for (const Shape& index_shape : start_index_shapes) {
        if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
          return InvalidArgument(
              "Dynamic slice start indices must all have the same shape, got "
              "mismatching indices with shapes %s and %s.",
              ShapeUtil::HumanString(first_index_shape),
              ShapeUtil::HumanString(index_shape));
        }
      }
    }
  }

  // The requested slice must fit inside the operand in every dimension.
  if (slice_sizes.size() != operand_shape.rank()) {
    return InvalidArgument(
        "Dynamic slice index count does not match argument rank: %u vs %d.",
        slice_sizes.size(), operand_shape.rank());
  }

  for (int64_t dim = 0; dim < slice_sizes.size(); ++dim) {
    const int64_t input_dim_size = operand_shape.dimensions(dim);
    const int64_t slice_dim_size = slice_sizes[dim];
    if (slice_dim_size < 0) {
      return InvalidArgument("Negative size index to dynamic slice: %d.",
                             slice_dim_size);
    }
    if (slice_dim_size > input_dim_size) {
      return InvalidArgument(
          "Slice dim size %d greater than dynamic slice dimension: %d.",
          slice_dim_size, input_dim_size);
    }
    VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size);
  }

  return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
}
2691
// Infers the result shape of dynamic-update-slice: `update_shape` is written
// into `operand_shape` at runtime-provided start indices. Index operands
// follow the same two styles as dynamic-slice (a single rank-1 integral
// tensor on the legacy path, or one scalar per operand dimension). The
// result has the operand's shape, with dynamic dimensions propagated from
// both operand and (for full-dimension updates) the update.
/* static */ StatusOr<Shape> ShapeInference::InferDynamicUpdateSliceShape(
    const Shape& operand_shape, const Shape& update_shape,
    absl::Span<const Shape> start_index_shapes, bool allow_scalar_indices) {
  TF_RETURN_IF_ERROR(
      ExpectArray(operand_shape, "operand of dynamic update slice"));
  TF_RETURN_IF_ERROR(
      ExpectArray(update_shape, "update of dynamic update slice"));

  auto number_of_indices = start_index_shapes.size();
  // TODO(b/118437727): Remove this path.
  if (!allow_scalar_indices ||
      (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) {
    // Legacy path: all start indices are packed into one rank-1 tensor.
    if (number_of_indices != 1) {
      return InvalidArgument(
          "Dynamic update slice should have exactly 1 index operand, has %d.",
          number_of_indices);
    }
    const Shape& start_indices_shape = start_index_shapes[0];
    TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape,
                                   "start indices of dynamic update slice"));

    VLOG(2) << StrFormat(
        "updating slice of shape %s at dynamic start_indices %s with update "
        "shape %s",
        ShapeUtil::HumanString(operand_shape),
        ShapeUtil::HumanString(start_indices_shape),
        ShapeUtil::HumanString(update_shape));

    if (start_indices_shape.rank() != 1) {
      return InvalidArgument(
          "Dynamic update slice start indices of rank %d must be rank1.",
          start_indices_shape.rank());
    }

    if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
      return InvalidArgument(
          "Dynamic update slice start indices must be of integral type.");
    }

    // The index vector must supply one start index per operand dimension.
    const int64_t start_num_dims = start_indices_shape.dimensions(0);
    if (operand_shape.rank() != start_num_dims) {
      return InvalidArgument(
          "Dynamic update slice start number of dimensions %d (%s) must match "
          "rank %d of slice input (%s).",
          start_num_dims, ShapeUtil::HumanString(start_indices_shape),
          operand_shape.rank(), ShapeUtil::HumanString(operand_shape));
    }
  } else {
    // Scalar-indices path: one scalar integral index per operand dimension,
    // all of the same shape.
    VLOG(2) << StrFormat("updating slice of shape %s with update shape %s",
                         ShapeUtil::HumanString(operand_shape),
                         ShapeUtil::HumanString(update_shape));

    if (operand_shape.rank() != number_of_indices) {
      return InvalidArgument(
          "Dynamic update slice start number of dimensions %d must match "
          "rank %d of slice input (%s).",
          number_of_indices, operand_shape.rank(),
          ShapeUtil::HumanString(operand_shape));
    }

    if (number_of_indices > 0) {
      const Shape& first_index_shape = start_index_shapes[0];
      if (!ShapeUtil::IsScalar(first_index_shape)) {
        return InvalidArgument(
            "Dynamic update slice indices must be scalar, not %s.",
            ShapeUtil::HumanString(first_index_shape));
      }
      if (!ShapeUtil::ElementIsIntegral(first_index_shape)) {
        return InvalidArgument(
            "Dynamic update slice start indices must be of integral type.");
      }
      for (const Shape& index_shape : start_index_shapes) {
        if (!ShapeUtil::Compatible(first_index_shape, index_shape)) {
          return InvalidArgument(
              "Dynamic update slice start indices must all have the same "
              "shape, got mismatching indices with shapes %s and %s.",
              ShapeUtil::HumanString(first_index_shape),
              ShapeUtil::HumanString(index_shape));
        }
      }
    }
  }

  // The update must have the operand's rank, a compatible element type, and
  // must fit inside the operand in every dimension.
  if (update_shape.rank() != operand_shape.rank()) {
    return InvalidArgument(
        "Dynamic update slice update rank does not match argument rank: "
        "%d vs %d.",
        update_shape.rank(), operand_shape.rank());
  }

  if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
                                                     update_shape)) {
    return InvalidArgument(
        "Dynamic update slice update element type does not match argument. "
        "operand.element_type: %s vs update.element_type: %s.",
        PrimitiveType_Name(operand_shape.element_type()),
        PrimitiveType_Name(update_shape.element_type()));
  }

  for (int64_t dim = 0; dim < operand_shape.rank(); ++dim) {
    const int64_t input_dim_size = operand_shape.dimensions(dim);
    const int64_t update_dim_size = update_shape.dimensions(dim);
    if (update_dim_size < 0) {
      return InvalidArgument(
          "Size index %d to dynamic update slice must be >= 0.",
          update_dim_size);
    }
    if (update_dim_size > input_dim_size) {
      return InvalidArgument(
          "Update dim size %d greater than dynamic slice dimension: %d.",
          update_dim_size, input_dim_size);
    }
    VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size);
  }

  auto result_shape = operand_shape;

  // If any of the operand shape is dynamic, the result dimension is also
  // dynamic.
  // If update shape is dynamic, only propagate dynamic dimension to result if
  // the update is a full update (update_shape[i] == operand_shape[i]).
  for (int64_t i = 0; i < update_shape.rank(); ++i) {
    if (operand_shape.is_dynamic_dimension(i)) {
      result_shape.set_dynamic_dimension(i, true);
    }

    if (update_shape.is_dynamic_dimension(i) &&
        update_shape.dimensions(i) == operand_shape.dimensions(i)) {
      // When update/replace a full dimension, propagate dynamic dimension to
      // the result.
      result_shape.set_dynamic_dimension(i, true);
    }
  }

  return result_shape;
}
2828
InferReverseShape(const Shape & operand_shape,absl::Span<const int64_t> dimensions)2829 /*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
2830 const Shape& operand_shape, absl::Span<const int64_t> dimensions) {
2831 TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
2832 if (!AllUnique(dimensions)) {
2833 return InvalidArgument("a dimension number is duplicated in reverse");
2834 }
2835 for (int64_t dimension : dimensions) {
2836 if (dimension >= operand_shape.rank() || dimension < 0) {
2837 return InvalidArgument(
2838 "One of the reverse dimensions (%d) is out-of-bounds in shape %s.",
2839 dimension, ShapeUtil::HumanString(operand_shape));
2840 }
2841 }
2842 return operand_shape;
2843 }
2844
InferGetTupleElementShape(const Shape & arg,int64_t index)2845 /* static */ StatusOr<Shape> ShapeInference::InferGetTupleElementShape(
2846 const Shape& arg, int64_t index) {
2847 if (!arg.IsTuple()) {
2848 return InvalidArgument(
2849 "Cannot infer shape: attempting to index into non-tuple: %s.",
2850 ShapeUtil::HumanString(arg));
2851 }
2852
2853 if (index < 0 || index >= arg.tuple_shapes_size()) {
2854 return InvalidArgument(
2855 "Cannot infer shape: attempt to index out of tuple bounds: %d "
2856 ">= %d in shape %s.",
2857 index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg));
2858 }
2859
2860 return arg.tuple_shapes(index);
2861 }
2862
InferWhileShape(const ProgramShape & condition,const ProgramShape & body,const Shape & init)2863 /* static */ StatusOr<Shape> ShapeInference::InferWhileShape(
2864 const ProgramShape& condition, const ProgramShape& body,
2865 const Shape& init) {
2866 // Check the number of parameters for given computations.
2867 if (condition.parameters_size() != 1) {
2868 return InvalidArgument("Condition must take 1 arguments; got %d.",
2869 condition.parameters_size());
2870 }
2871 if (body.parameters_size() != 1) {
2872 return InvalidArgument("Body must take 1 arguments; got %d.",
2873 body.parameters_size());
2874 }
2875
2876 auto shape_string = [&]() {
2877 return StrFormat(
2878 "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition),
2879 ShapeUtil::HumanString(body), ShapeUtil::HumanString(init));
2880 };
2881
2882 // Check the shapes of computation parameters and return types.
2883 if (!ShapeUtil::Compatible(condition.result(),
2884 ShapeUtil::MakeShape(PRED, {}))) {
2885 return InvalidArgument("Condition must return a boolean; got %s.",
2886 shape_string());
2887 }
2888 if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) ||
2889 !ShapeUtil::Compatible(body.result(), body.parameters(0)) ||
2890 !ShapeUtil::Compatible(body.result(), init)) {
2891 return InvalidArgument(
2892 "The parameter of condition and body, the result of the body, and init "
2893 "must all have the same shape; got %s.",
2894 shape_string());
2895 }
2896
2897 return init;
2898 }
2899
// Infers the result shape of a conditional. The branch selector is either a
// scalar PRED (exactly two branches) or a scalar S32 (N >= 1 branches). Each
// branch computation must be unary, with its parameter matching the
// corresponding branch operand, and all branches must produce compatible
// result shapes. Per subshape/dimension, if any branch's result is dynamic,
// the conditional's result is dynamic.
/* static */ StatusOr<Shape> ShapeInference::InferConditionalShape(
    const Shape& branch_index,
    absl::Span<const ProgramShape> branch_computations,
    absl::Span<const Shape> branch_operands) {
  if (!ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(PRED, {})) &&
      !ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(S32, {}))) {
    return InvalidArgument("branch_index must be bool or int32_t; got %s.",
                           ShapeUtil::HumanString(branch_index));
  }
  // A PRED selector implies the classic true/false form with two branches.
  if (branch_index.element_type() == PRED) {
    TF_RET_CHECK(2 == branch_computations.size());
  } else {
    TF_RET_CHECK(!branch_computations.empty());
  }
  TF_RET_CHECK(branch_computations.size() == branch_operands.size());
  // Branch 0's result is the template; the loop below verifies all other
  // branches are compatible with it and then dynamism is merged in.
  Shape result = branch_computations[0].result();
  for (int j = 0; j < branch_computations.size(); ++j) {
    if (branch_computations[j].parameters_size() != 1) {
      return InvalidArgument(
          "branch computation %d must take 1 argument; got %d.", j,
          branch_computations[j].parameters_size());
    }
    if (!ShapeUtil::Compatible(branch_computations[j].parameters(0),
                               branch_operands[j])) {
      auto shape_string = [&]() {
        return StrFormat("operand: %s; computation: %s",
                         ShapeUtil::HumanString(branch_operands[j]),
                         ShapeUtil::HumanString(branch_computations[j]));
      };
      return InvalidArgument(
          "branch operand %d must match the shape of the only parameter of "
          "branch computation %d: got %s.",
          j, j, shape_string());
    }

    if (!ShapeUtil::Compatible(branch_computations[0].result(),
                               branch_computations[j].result())) {
      auto shape_string = [&]() {
        return StrFormat(
            "branch 0 computation result: %s; branch %d computation result: %s",
            ShapeUtil::HumanString(branch_computations[0].result()), j,
            ShapeUtil::HumanString(branch_computations[j].result()));
      };
      return InvalidArgument(
          "the result of branch 0 computation and branch %d computation must "
          "have the same shape: got %s.",
          j, shape_string());
    }
  }
  // For each subshape, If any of the branch is dynamic, we say result is
  // dynamic:
  //
  // true_branch  (s32[<=4])
  // false_branch (s32[4])
  //
  // Result is s32[<=4].
  ShapeUtil::ForEachMutableSubshape(
      &result, [&](Shape* subshape, const ShapeIndex& index) {
        if (!subshape->IsArray()) {
          return;
        }
        for (int j = 0; j < branch_computations.size(); ++j) {
          auto branch_subshape =
              ShapeUtil::GetSubshape(branch_computations[j].result(), index);
          for (int64_t i = 0; i < branch_subshape.rank(); ++i) {
            if (branch_subshape.is_dynamic_dimension(i)) {
              subshape->set_dynamic_dimension(i, true);
            }
          }
        }
      });

  return result;
}
2974
InferBroadcastShape(const Shape & operand,absl::Span<const int64_t> broadcast_sizes)2975 /* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
2976 const Shape& operand, absl::Span<const int64_t> broadcast_sizes) {
2977 TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
2978 for (int64_t size : broadcast_sizes) {
2979 if (size < 0) {
2980 return InvalidArgument("Broadcast with negative dimension size %d.",
2981 size);
2982 }
2983 }
2984
2985 std::vector<int64_t> dimensions(operand.dimensions_size() +
2986 broadcast_sizes.size());
2987 std::copy(broadcast_sizes.begin(), broadcast_sizes.end(), dimensions.begin());
2988 std::copy(operand.dimensions().begin(), operand.dimensions().end(),
2989 dimensions.begin() + broadcast_sizes.size());
2990
2991 Shape result = ShapeUtil::MakeShape(operand.element_type(), dimensions);
2992 for (int64_t i = 0; i < operand.dimensions_size(); ++i) {
2993 result.set_dynamic_dimension(broadcast_sizes.size() + i,
2994 operand.is_dynamic_dimension(i));
2995 }
2996 return result;
2997 }
2998
// Infers the result shape of an in-dim (BroadcastInDim-style) broadcast:
// `broadcast_dimensions[i]` names the output dimension that operand
// dimension i maps to. Each mapped operand dimension must either equal the
// output dimension or be 1 (degenerate broadcast), dynamism must agree, and
// the mapping must be strictly increasing. The result is `output_shape`.
/* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
    const Shape& operand_shape, const Shape& output_shape,
    absl::Span<const int64_t> broadcast_dimensions) {
  TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast"));
  TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast"));
  const int64_t operand_rank = operand_shape.rank();
  const int64_t output_rank = output_shape.rank();
  if (operand_rank > output_rank) {
    return InvalidArgument(
        "InDim style broadcast must be to an equal or higher ranked shape; "
        "operand rank: %lld; output rank: %lld",
        operand_rank, output_rank);
  }
  // Every operand dimension needs a mapping entry.
  if (operand_rank != broadcast_dimensions.size()) {
    return InvalidArgument(
        "Size of broadcast_dimensions has to match operand's rank; operand "
        "rank: %lld, size of broadcast_dimensions %u.",
        operand_rank, broadcast_dimensions.size());
  }
  for (int64_t i = 0; i < operand_rank; i++) {
    if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] >= output_rank) {
      return InvalidArgument("Broadcast dimension %lld is out of bound",
                             broadcast_dimensions[i]);
    }
    // An operand dimension either matches its output dimension exactly or is
    // a degenerate (size-1) dimension being broadcast up.
    if (operand_shape.dimensions(i) !=
            output_shape.dimensions(broadcast_dimensions[i]) &&
        operand_shape.dimensions(i) != 1) {
      return InvalidArgument(
          "Input dimension should be either 1 or equal to the output dimension "
          "it is broadcasting into; the %lldth operand dimension is %lld, the "
          "%lldth output dimension is %lld.",
          i, operand_shape.dimensions(i), broadcast_dimensions[i],
          output_shape.dimensions(broadcast_dimensions[i]));
    }
    if (operand_shape.is_dynamic_dimension(i) !=
        output_shape.is_dynamic_dimension(broadcast_dimensions[i])) {
      return InvalidArgument(
          "Broadcast input and output dynamism mismatch: %s and %s",
          operand_shape.ToString(), output_shape.ToString());
    }
    // Make sure the broadcast dimensions are listed in a strictly increasing
    // order.
    if (i > 0 && broadcast_dimensions[i - 1] >= broadcast_dimensions[i]) {
      return InvalidArgument(
          "Broadcast dimensions order is wrong: %d comes after %d.",
          broadcast_dimensions[i], broadcast_dimensions.at(i - 1));
    }
  }

  return output_shape;
}
3050
InferDynamicReshapeShape(const Shape & operand,absl::Span<const Shape * const> dim_size_shapes,absl::Span<const int64_t> new_size_bounds,const std::vector<bool> & dims_are_dynamic)3051 /* static */ StatusOr<Shape> ShapeInference::InferDynamicReshapeShape(
3052 const Shape& operand, absl::Span<const Shape* const> dim_size_shapes,
3053 absl::Span<const int64_t> new_size_bounds,
3054 const std::vector<bool>& dims_are_dynamic) {
3055 if (new_size_bounds.size() != dims_are_dynamic.size()) {
3056 return InvalidArgument(
3057 "DynamicReshape has to have the same number of elements in new_sizes "
3058 "(%d) and dims_are_dynamic (%d)",
3059 new_size_bounds.size(), dims_are_dynamic.size());
3060 }
3061
3062 for (const Shape* dim_size_shape : dim_size_shapes) {
3063 if (dim_size_shape->element_type() != S32 && dim_size_shape->rank() != 0) {
3064 return InvalidArgument(
3065 "DynamicReshape's dim size has to be scalar S32, got (%s): ",
3066 dim_size_shape->ToString());
3067 }
3068 }
3069
3070 Shape inferred_shape = ShapeUtil::MakeShape(
3071 operand.element_type(), new_size_bounds, dims_are_dynamic);
3072 if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
3073 return InvalidArgument(
3074 "Reshape operation has mismatched element counts: from=%d (%s) "
3075 "to=%d (%s).",
3076 ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
3077 ShapeUtil::ElementsIn(inferred_shape),
3078 ShapeUtil::HumanString(inferred_shape));
3079 }
3080 return inferred_shape;
3081 }
3082
// Infers the result shape of a reshape and propagates dynamic dimensions of
// `operand` into the output. `dimensions` must be a permutation of the
// operand's dimension indices; `new_sizes` gives the output bounds; and
// `inferred_dimension` (-1 if unset) is a caller-provided tie-breaker naming
// which output dimension becomes dynamic when several are possible.
/* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
    const Shape& operand, absl::Span<const int64_t> dimensions,
    absl::Span<const int64_t> new_sizes, int64_t inferred_dimension) {
  TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));

  Shape inferred_shape =
      ShapeUtil::MakeShape(operand.element_type(), new_sizes);
  VLOG(3) << "Reshape inferred shape: "
          << ShapeUtil::HumanString(inferred_shape);

  // A reshape never changes the number of elements.
  if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
    return InvalidArgument(
        "Reshape operation has mismatched element counts: from=%d (%s) "
        "to=%d (%s).",
        ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
        ShapeUtil::ElementsIn(inferred_shape),
        ShapeUtil::HumanString(inferred_shape));
  }

  // `dimensions` must mention every operand dimension exactly once.
  std::vector<int64_t> indices(operand.rank());
  std::iota(indices.begin(), indices.end(), 0);
  if (dimensions.size() != operand.rank() ||
      !std::is_permutation(dimensions.begin(), dimensions.end(),
                           indices.begin())) {
    return InvalidArgument(
        "Reshape dimensions [%s] are not a permutation of the operand "
        "dimensions (operand shape is %s).",
        StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
  }

  // Propagate dynamic dimension.
  // CommonFactors partitions the input and output dimensions into maximal
  // groups whose element-count products agree; each dynamic input dimension
  // must map to exactly one output dimension within its group.
  auto common_factors = CommonFactors(operand.dimensions(), new_sizes);
  for (int64_t input_dim = 0; input_dim < operand.rank(); ++input_dim) {
    if (!operand.is_dynamic_dimension(input_dim)) {
      continue;
    }

    std::string reshape_debug_str = absl::StrFormat(
        "output: %s, input: %s, input_dim: "
        "%lld",
        ShapeUtil::HumanString(inferred_shape), ShapeUtil::HumanString(operand),
        input_dim);

    // [input_dim_start, input_dim_end) x [output_dim_start, output_dim_end)
    // will describe the common-factor group containing input_dim; -1 means
    // the group was not found.
    int64_t input_dim_start = -1;
    int64_t input_dim_end = -1;
    int64_t output_dim_start = -1;
    int64_t output_dim_end = -1;
    // Find common_factors that the input_dim belongs to.
    for (int64_t i = 0; i < common_factors.size() - 1; ++i) {
      auto start = common_factors[i];
      auto end = common_factors[i + 1];
      if (input_dim >= start.first && input_dim < end.first) {
        input_dim_start = start.first;
        input_dim_end = end.first;
        output_dim_start = start.second;
        output_dim_end = end.second;
        break;
      }
    }
    if ((input_dim_end - input_dim_start) > 1 &&
        (output_dim_end - output_dim_start) > 1) {
      // We don't support the case when a dynamic dimension is both combined
      // with and splitted into other dimensions:
      //
      //  [x, yz]
      //     | Reshape
      //  [xy, z]
      return Unimplemented(
          "Dynamic input dimension to reshape that is both splitted and "
          "combined is not supported: %s",
          reshape_debug_str);
    }

    for (auto common_factor : common_factors) {
      //
      // For reshapes like:
      //  [<=5]
      //    | Reshape
      //  [1, 5]
      //
      //  The input dynamic dimension can go into either dynamic dimensions.
      //  However, the return value of common factors only returns
      //  input: 5
      //  output: 5
      //
      //  We need to expand common factor to include degenerated output
      //  dimensions:
      //  input: 5
      //  output: 1, 5
      //
      //  such that in the logic later on we can consider both dimensions as
      //  candidate.
      if (common_factor.first == input_dim_start) {
        output_dim_start = std::min(output_dim_start, common_factor.second);
      }
      if (common_factor.first == input_dim_end) {
        output_dim_end = std::max(output_dim_end, common_factor.second);
      }
    }

    // Calculate output dynamic reshape dimension.
    int64_t output_dynamic_dimension = -1;

    if (operand.dimensions(input_dim) == 1 && !new_sizes.empty()) {
      // If dynamic dimension is size 1, it can only be most-major or
      // most-minor.
      if (input_dim == 0) {
        output_dynamic_dimension = 0;
      }
      if (input_dim == operand.rank() - 1) {
        output_dynamic_dimension = new_sizes.size() - 1;
      }

      if (output_dynamic_dimension == -1) {
        return Unimplemented(
            "Dynamic degenerated dimension that's not most-minor nor "
            "most-major is not supported: %s",
            reshape_debug_str);
      }
    }

    if (output_dynamic_dimension == -1 &&
        output_dim_end - output_dim_start == 1) {
      // Only one possible output dimension.
      output_dynamic_dimension = output_dim_start;
    }
    if (output_dynamic_dimension == -1 &&
        output_dim_end - output_dim_start > 1) {
      // Multiple outputs can be dynamic, use "inferred_dimension" to tie-break.
      output_dynamic_dimension = inferred_dimension;
    }

    if (output_dynamic_dimension != -1) {
      // TODO(yunxing): Turn this into a CHECK.
      inferred_shape.set_dynamic_dimension(output_dynamic_dimension, true);
    } else {
      // No tie-breaker was given: fall back to marking the unique
      // non-degenerate (size != 1) output dimension in the group, if any.
      std::vector<int64_t> output_non_degenerated;
      output_non_degenerated.reserve(output_dim_end);
      for (int64_t i = output_dim_start; i < output_dim_end; ++i) {
        if (new_sizes[i] != 1) {
          output_non_degenerated.push_back(i);
        }
      }
      if (output_non_degenerated.size() == 1) {
        inferred_shape.set_dynamic_dimension(output_non_degenerated[0], true);
      }
    }
  }

  return inferred_shape;
}
3234
InferTransposeShape(const Shape & operand,absl::Span<const int64_t> dimensions)3235 /* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
3236 const Shape& operand, absl::Span<const int64_t> dimensions) {
3237 TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
3238
3239 if (dimensions.size() != operand.rank() || !IsPermutation(dimensions)) {
3240 return InvalidArgument(
3241 "Transpose dimensions [%s] are not a permutation of the operand "
3242 "dimensions (operand shape is %s).",
3243 StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
3244 }
3245
3246 // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
3247 // we need output[i]=input[dimensions[i]] which is
3248 // Permute(Inverse(dimensions),input).
3249 return ShapeUtil::PermuteDimensions(dimensions, operand);
3250 }
3251
InferClampShape(const Shape & min,const Shape & operand,const Shape & max)3252 /* static */ StatusOr<Shape> ShapeInference::InferClampShape(
3253 const Shape& min, const Shape& operand, const Shape& max) {
3254 TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min"));
3255 TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand"));
3256 TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max"));
3257
3258 if (!ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
3259 !ShapeUtil::CompatibleIgnoringFpPrecision(max, operand)) {
3260 return InvalidArgument(
3261 "Clamp with different shapes: %s, %s, %s.", ShapeUtil::HumanString(min),
3262 ShapeUtil::HumanString(operand), ShapeUtil::HumanString(max));
3263 }
3264 return operand;
3265 }
3266
InferSelectShape(const Shape & pred,const Shape & on_true,const Shape & on_false)3267 /* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
3268 const Shape& pred, const Shape& on_true, const Shape& on_false) {
3269 TF_RETURN_IF_ERROR(ExpectArray(pred, "select pred"));
3270 TF_RETURN_IF_ERROR(ExpectArray(on_true, "select on-true"));
3271 TF_RETURN_IF_ERROR(ExpectArray(on_false, "select on-false"));
3272
3273 if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) {
3274 return InvalidArgument(
3275 "Operands to select must be the same shape; got %s and %s.",
3276 ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false));
3277 }
3278 if (pred.element_type() != PRED) {
3279 return InvalidArgument(
3280 "Select's pred operand must have PRED element type; got %s.",
3281 ShapeUtil::HumanString(pred));
3282 }
3283 if (!Shape::Equal()
3284 .IgnoreElementType()
3285 .IgnoreLayout()
3286 .IgnoreDynamicDimension()(pred, on_true)) {
3287 return InvalidArgument(
3288 "Operands to select and predicate must be the same shape; got %s and "
3289 "%s.",
3290 ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(pred));
3291 }
3292
3293 return ShapeUtil::ChangeElementType(
3294 pred, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
3295 }
3296
InferCallShape(absl::Span<const Shape * const> arg_shapes,const ProgramShape & to_apply)3297 /* static */ StatusOr<Shape> ShapeInference::InferCallShape(
3298 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply) {
3299 // The applied function's arity equals the number of arguments.
3300 if (arg_shapes.size() != to_apply.parameters_size()) {
3301 std::string computation_signature = ShapeUtil::HumanString(to_apply);
3302 std::string argument_shapes =
3303 StrJoin(arg_shapes, ", ", [](std::string* out, const Shape* shape) {
3304 absl::StrAppend(out, ShapeUtil::HumanString(*shape));
3305 });
3306 return InvalidArgument(
3307 "Call applied function arity must match number of arguments; got: "
3308 "arity: %d, arguments: %u; computation signature: %s; argument "
3309 "shapes: [%s].",
3310 to_apply.parameters_size(), arg_shapes.size(), computation_signature,
3311 argument_shapes);
3312 }
3313
3314 // All arguments must be compatible with the program shape.
3315 for (int i = 0; i < arg_shapes.size(); ++i) {
3316 const Shape& arg_shape = *arg_shapes[i];
3317 const Shape& param_shape = to_apply.parameters(i);
3318 if (!ShapeUtil::Compatible(arg_shape, param_shape)) {
3319 return InvalidArgument(
3320 "Call parameter must match argument; got parameter %d shape: %s, "
3321 "argument shape: %s.",
3322 i, ShapeUtil::HumanString(param_shape),
3323 ShapeUtil::HumanString(arg_shape));
3324 }
3325 }
3326
3327 return to_apply.result();
3328 }
3329
// Checks that `dim_numbers` is internally consistent with the gather operand
// (`input_shape`) and the already-expanded start-indices bounds
// (`start_indices_shape`, see InferGatherShape). Returns OkStatus() when all
// constraints hold; the checks run in a fixed order, so the first violated
// constraint determines the error reported.
static Status ValidateGatherDimensionNumbers(
    const Shape& input_shape, absl::Span<const int64_t> start_indices_shape,
    const GatherDimensionNumbers& dim_numbers) {
  // offset_dims must be sorted ascending...
  if (!absl::c_is_sorted(dim_numbers.offset_dims())) {
    return InvalidArgument(
        "Output window dimensions in gather op must be ascending; got: %s.",
        StrJoin(dim_numbers.offset_dims(), ", "));
  }

  // ...and contain no repeated entries (adjacent_find on a sorted range).
  if (absl::c_adjacent_find(dim_numbers.offset_dims()) !=
      dim_numbers.offset_dims().end()) {
    return InvalidArgument(
        "Output window dimensions in gather op must not repeat; got: %s.",
        StrJoin(dim_numbers.offset_dims(), ", "));
  }

  // Output rank = one dim per offset_dims entry plus one batch dim per
  // start_indices dim, excluding the index-vector dim itself.
  const int64_t output_offset_dim_count = dim_numbers.offset_dims_size();
  const int64_t output_shape_rank =
      output_offset_dim_count + start_indices_shape.size() - 1;

  // Each offset dim must name a valid output dimension.
  for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) {
    int64_t offset_dim = dim_numbers.offset_dims(i);
    if (offset_dim < 0 || offset_dim >= output_shape_rank) {
      return InvalidArgument(
          "Offset dimension %d in gather op is out of bounds; got %d, but "
          "should "
          "have been in [0,%d).",
          i, offset_dim, output_shape_rank);
    }
  }

  // start_index_map needs one entry per component of an index vector.
  if (dim_numbers.start_index_map_size() !=
      start_indices_shape[dim_numbers.index_vector_dim()]) {
    return InvalidArgument(
        "Gather op has %d elements in start_index_map and the "
        "bound of dimension index_vector_dim=%d of start_indices is "
        "%d. These two numbers must be equal.",
        dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(),
        start_indices_shape[dim_numbers.index_vector_dim()]);
  }

  // Each start_index_map entry must name a valid operand dimension.
  for (int i = 0; i < dim_numbers.start_index_map_size(); i++) {
    int64_t operand_dim_for_start_index_i = dim_numbers.start_index_map(i);
    if (operand_dim_for_start_index_i < 0 ||
        operand_dim_for_start_index_i >= input_shape.dimensions_size()) {
      return InvalidArgument(
          "Invalid start_index_map; domain is [0, %d), got: %d->%d.",
          input_shape.dimensions_size(), i, operand_dim_for_start_index_i);
    }
  }

  // Uniqueness is checked on a sorted copy so the original (possibly
  // unsorted) map is preserved for the error message.
  std::vector<int64_t> sorted_start_index_map(
      dim_numbers.start_index_map().begin(),
      dim_numbers.start_index_map().end());

  absl::c_sort(sorted_start_index_map);

  if (absl::c_adjacent_find(sorted_start_index_map) !=
      sorted_start_index_map.end()) {
    return InvalidArgument(
        "Repeated dimensions are not allowed in start_index_map; "
        "got: %s.",
        StrJoin(dim_numbers.start_index_map(), ", "));
  }

  // collapsed_slice_dims must name valid operand dimensions...
  for (int64_t collapsed_dim : dim_numbers.collapsed_slice_dims()) {
    if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) {
      return InvalidArgument(
          "Invalid collapsed_slice_dims set in gather op; valid range is [0, "
          "%d), got: %d.",
          input_shape.dimensions_size(), collapsed_dim);
    }
  }

  // ...be sorted...
  if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) {
    return InvalidArgument(
        "collapsed_slice_dims in gather op must be sorted; got: %s",
        StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
  }

  // ...and contain no repeats.
  if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) !=
      dim_numbers.collapsed_slice_dims().end()) {
    return InvalidArgument(
        "Repeated dimensions not allowed in collapsed_slice_dims in gather op; "
        "got: %s.",
        StrJoin(dim_numbers.collapsed_slice_dims(), ", "));
  }

  return OkStatus();
}
3420
// Infers the result shape of a gather op and forwards dynamic dimensions from
// both the input (for fully-gathered dynamic dims) and the start indices (for
// batch dims).
/*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
    const Shape& input_shape, const Shape& start_indices_shape,
    const GatherDimensionNumbers& gather_dim_numbers,
    absl::Span<const int64_t> slice_sizes) {
  TF_RETURN_IF_ERROR(
      ExpectArray(input_shape, "input tensor operand of gather op"));
  TF_RETURN_IF_ERROR(
      ExpectArray(start_indices_shape, "gather indices operand of gather op"));

  if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
    return InvalidArgument(
        "Gather indices parameter must be an integral tensor; got %s.",
        ShapeUtil::HumanString(start_indices_shape));
  }

  // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if
  // index_vector_dim is rank(P). The bounds of this expanded shape is
  // stored in expanded_start_indices_shape.

  // index_vector_dim may equal rank(start_indices) (the implicit trailing-1
  // case above), hence the `<` rather than `<=` comparison.
  if (start_indices_shape.dimensions_size() <
          gather_dim_numbers.index_vector_dim() ||
      gather_dim_numbers.index_vector_dim() < 0) {
    return InvalidArgument(
        "Gather index leaf dimension must be within [0, rank(start_indices) + "
        "1). rank(start_indices) is %d and gather index leaf dimension is "
        "%d.",
        start_indices_shape.dimensions_size(),
        gather_dim_numbers.index_vector_dim());
  }

  std::vector<int64_t> expanded_start_indices_shape;
  // Also tracks if an output dimension is dynamic.
  std::vector<bool> expanded_start_indices_shape_dynamic_dimensions;
  expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size());
  expanded_start_indices_shape_dynamic_dimensions.reserve(
      start_indices_shape.dimensions_size());
  absl::c_copy(start_indices_shape.dimensions(),
               std::back_inserter(expanded_start_indices_shape));
  absl::c_copy(
      start_indices_shape.dynamic_dimensions(),
      std::back_inserter(expanded_start_indices_shape_dynamic_dimensions));
  // Append the implicit trailing dimension of bound 1 (always static).
  if (expanded_start_indices_shape.size() ==
      gather_dim_numbers.index_vector_dim()) {
    expanded_start_indices_shape.push_back(1);
    expanded_start_indices_shape_dynamic_dimensions.push_back(false);
  }

  TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
      input_shape, expanded_start_indices_shape, gather_dim_numbers));

  // slice_sizes must cover every input dimension exactly once...
  if (slice_sizes.size() != input_shape.dimensions_size()) {
    return InvalidArgument(
        "Gather op must have one slice size for every input dimension; got: "
        "len(slice_sizes)=%lu, input_shape.rank=%d.",
        slice_sizes.size(), input_shape.dimensions_size());
  }

  // ...and each slice dimension is either kept (offset dim) or collapsed.
  if (slice_sizes.size() !=
      gather_dim_numbers.offset_dims_size() +
          gather_dim_numbers.collapsed_slice_dims_size()) {
    return InvalidArgument(
        "All components of the offset index in a gather op must either be a "
        "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, "
        "output_slice_sizes=%s, collapsed_slice_dims=%s.",
        slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","),
        StrJoin(gather_dim_numbers.collapsed_slice_dims(), ","));
  }

  // Each slice size must fit within the corresponding input bound (inclusive,
  // hence the `+ 1` in the printed half-open range).
  for (int i = 0; i < slice_sizes.size(); i++) {
    int64_t slice_size = slice_sizes[i];
    int64_t corresponding_input_size = input_shape.dimensions(i);
    if (slice_size < 0 || slice_size > corresponding_input_size) {
      return InvalidArgument(
          "Slice size at index %d in gather op is out of range, must be "
          "within [0, %d), got %d.",
          i, corresponding_input_size + 1, slice_size);
    }
  }

  // Collapsed slice dims must have a trivial (0 or 1) slice size.
  for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) {
    if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] > 1) {
      return InvalidArgument(
          "Gather op can only collapse slice dims with bound 1 or 0, but bound "
          "is %d for index %d at position %d.",
          slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)],
          gather_dim_numbers.collapsed_slice_dims(i), i);
    }
  }

  // Build the output bounds: offset dims come from slice_sizes (skipping
  // collapsed dims), batch dims come from the expanded start-indices bounds
  // (skipping the index-vector dim).
  int64_t result_rank = gather_dim_numbers.offset_dims_size() +
                        (expanded_start_indices_shape.size() - 1);
  int64_t offset_dims_seen = 0;
  int64_t gather_dims_seen = 0;
  std::vector<int64_t> output_dim_bounds;
  output_dim_bounds.reserve(result_rank);

  std::vector<bool> output_dim_is_dynamic;
  output_dim_is_dynamic.reserve(result_rank);
  for (int64_t i = 0; i < result_rank; i++) {
    int64_t current_bound;
    bool dim_dynamic = false;
    bool is_window_index =
        absl::c_binary_search(gather_dim_numbers.offset_dims(), i);
    if (is_window_index) {
      // Skip collapsed input dims; they do not appear in the output.
      while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(),
                                   offset_dims_seen)) {
        offset_dims_seen++;
      }
      // Gathering an entire dynamic dimension creates dynamic dimension.
      //
      // e.g.,:
      //
      // gather(input: [1,<=2,1], slice_sizes={1,2,1})
      //
      // creates
      //
      // [<=2, 1]
      if (slice_sizes[offset_dims_seen] ==
          input_shape.dimensions(offset_dims_seen)) {
        dim_dynamic = input_shape.is_dynamic_dimension(offset_dims_seen);
      }
      current_bound = slice_sizes[offset_dims_seen++];
    } else {
      // Skip the index-vector dim; it does not appear in the output.
      if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) {
        gather_dims_seen++;
      }
      // Forward dynamic dimensions from indices.
      dim_dynamic =
          expanded_start_indices_shape_dynamic_dimensions[gather_dims_seen];

      current_bound = expanded_start_indices_shape[gather_dims_seen++];
    }
    output_dim_is_dynamic.push_back(dim_dynamic);
    output_dim_bounds.push_back(current_bound);
  }

  return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds,
                              output_dim_is_dynamic);
}
3560
3561 namespace {
3562
// Checks that `dim_numbers` is internally consistent with the scatter operand
// (`operand_shape`), the already-expanded scatter-indices bounds
// (`scatter_indices_shape`, see InferScatterShape), and the updates shape.
// Checks run in a fixed order, so the first violated constraint determines
// the error reported.
Status ValidateScatterDimensionNumbers(
    const Shape& operand_shape, absl::Span<const int64_t> scatter_indices_shape,
    const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
  // Validate update_window_dims in ScatterDimensionNumbers.
  // Must be sorted, unique, and within the updates' rank.
  if (!absl::c_is_sorted(dim_numbers.update_window_dims())) {
    return InvalidArgument(
        "update_window_dims in scatter op must be sorted; got: %s.",
        StrJoin(dim_numbers.update_window_dims(), ", "));
  }
  if (absl::c_adjacent_find(dim_numbers.update_window_dims()) !=
      dim_numbers.update_window_dims().end()) {
    return InvalidArgument(
        "update_window_dims in scatter op must not repeat; got: %s.",
        StrJoin(dim_numbers.update_window_dims(), ", "));
  }
  const int64_t updates_rank = updates_shape.rank();
  for (int64_t window_dim : dim_numbers.update_window_dims()) {
    if (window_dim < 0 || window_dim >= updates_rank) {
      return InvalidArgument(
          "Invalid update_window_dims set in scatter op; valid range is [0, "
          "%d). got: %d.",
          updates_rank, window_dim);
    }
  }

  // Validate inserted_window_dims in ScatterDimensionNumbers.
  // Must be sorted, unique, and within the operand's rank.
  if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) {
    return InvalidArgument(
        "inserted_window_dims in scatter op must be sorted; got: %s.",
        StrJoin(dim_numbers.inserted_window_dims(), ", "));
  }
  if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) !=
      dim_numbers.inserted_window_dims().end()) {
    return InvalidArgument(
        "inserted_window_dims in scatter op must not repeat; got: %s.",
        StrJoin(dim_numbers.inserted_window_dims(), ", "));
  }
  for (int64_t inserted_dim : dim_numbers.inserted_window_dims()) {
    if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) {
      return InvalidArgument(
          "Invalid inserted_window_dims set in scatter op; valid range is [0, "
          "%d), got: %d.",
          operand_shape.dimensions_size(), inserted_dim);
    }
  }

  // Validate window size.
  // Together, update-window dims and inserted-window dims account for every
  // operand dimension exactly once.
  auto window_size = dim_numbers.update_window_dims_size() +
                     dim_numbers.inserted_window_dims_size();
  if (window_size != operand_shape.rank()) {
    return InvalidArgument(
        "Scatter op has window of size %d; doesn't match operand of rank %d.",
        window_size, operand_shape.rank());
  }

  // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers.
  // One entry per component of an index vector.
  if (dim_numbers.scatter_dims_to_operand_dims_size() !=
      scatter_indices_shape[dim_numbers.index_vector_dim()]) {
    return InvalidArgument(
        "Scatter op has %d elements in scatter_dims_to_operand_dims and the "
        "bound of dimension index_vector_dim=%d of scatter_indices is %d. "
        "These two numbers must be equal.",
        dim_numbers.scatter_dims_to_operand_dims_size(),
        dim_numbers.index_vector_dim(),
        scatter_indices_shape[dim_numbers.index_vector_dim()]);
  }
  // Each mapping entry must name a valid operand dimension.
  for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) {
    int64_t scatter_dim_to_operand_dim =
        dim_numbers.scatter_dims_to_operand_dims(i);
    if (scatter_dim_to_operand_dim < 0 ||
        scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) {
      return InvalidArgument(
          "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), "
          "got: %d->%d.",
          operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim);
    }
  }
  // Uniqueness is checked on a sorted copy so the original (possibly
  // unsorted) mapping is preserved for the error message.
  std::vector<int64_t> sorted_scatter_dims_to_operand_dims(
      dim_numbers.scatter_dims_to_operand_dims().begin(),
      dim_numbers.scatter_dims_to_operand_dims().end());
  absl::c_sort(sorted_scatter_dims_to_operand_dims);
  if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
      sorted_scatter_dims_to_operand_dims.end()) {
    return InvalidArgument(
        "Repeated dimensions not allowed in scatter_dims_to_operand_dims; "
        "got: %s.",
        StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", "));
  }

  return OkStatus();
}
3654
3655 } // namespace
3656
// Infers the result shape of a (variadic) scatter. `arg_shapes` is laid out
// as: operand_count operand shapes, then the scatter-indices shape, then
// operand_count updates shapes (so arg_shapes.size() == 2 * operand_count +
// 1). The result is the single operand shape, or a tuple of all operand
// shapes when operand_count > 1.
/*static*/ StatusOr<Shape> ShapeInference::InferScatterShape(
    absl::Span<const Shape* const> arg_shapes,
    const ProgramShape& to_apply_shape,
    const ScatterDimensionNumbers& scatter_dim_numbers) {
  // Derive operand_count from the layout above and reject even arg counts.
  const int64_t operand_count = arg_shapes.size() / 2;
  if (operand_count * 2 + 1 != arg_shapes.size()) {
    return InvalidArgument(
        "Invalid argument count of scatter op: Expected %d, saw %d",
        operand_count * 2 + 1, arg_shapes.size());
  }

  // Scatter indices sit in the middle of arg_shapes and must be integral.
  const Shape& scatter_indices_shape = *arg_shapes[operand_count];
  TF_RETURN_IF_ERROR(
      ExpectArray(scatter_indices_shape, "scatter indices of scatter op"));
  if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) {
    return InvalidArgument(
        "Scatter indices parameter must be an integral tensor; got %s.",
        ShapeUtil::HumanString(scatter_indices_shape));
  }
  // index_vector_dim may equal rank(scatter_indices) — the implicit
  // trailing-1 case handled below — hence `<` rather than `<=`.
  if (scatter_indices_shape.dimensions_size() <
          scatter_dim_numbers.index_vector_dim() ||
      scatter_dim_numbers.index_vector_dim() < 0) {
    return InvalidArgument(
        "Scatter index leaf dimension must be within [0, rank(scatter_indices)"
        " + 1). rank(scatter_indices) is %d and scatter index leaf dimension "
        "is %d.",
        scatter_indices_shape.dimensions_size(),
        scatter_dim_numbers.index_vector_dim());
  }

  // Implicitly reshape indices of shape [A,B,C] to [A,B,C,1] when
  // index_vector_dim == rank(scatter_indices).
  std::vector<int64_t> expanded_scatter_indices_shape =
      SpanToVector(scatter_indices_shape.dimensions());
  if (expanded_scatter_indices_shape.size() ==
      scatter_dim_numbers.index_vector_dim()) {
    expanded_scatter_indices_shape.push_back(1);
  }

  // Validate each (operand, updates) pair against the shared dimension
  // numbers and indices.
  auto operand_shapes = arg_shapes.first(operand_count);
  auto updates_shapes = arg_shapes.last(operand_count);
  for (int64_t operand_i = 0; operand_i < operand_count; ++operand_i) {
    const Shape& operand_shape = *operand_shapes[operand_i];
    const Shape& updates_shape = *updates_shapes[operand_i];
    TF_RETURN_IF_ERROR(ExpectArray(
        operand_shape, absl::StrCat("operand ", operand_i, " of scatter op")));
    TF_RETURN_IF_ERROR(ExpectArray(
        updates_shape, absl::StrCat("updates ", operand_i, " of scatter op")));

    // Max allowed size for each update window dim = the corresponding operand
    // dim bound, skipping operand dims listed in inserted_window_dims.
    int64_t inserted_dims_seen = 0;
    std::vector<int64_t> max_update_slice_sizes;
    const auto dimensions_size = operand_shape.dimensions_size();
    max_update_slice_sizes.reserve(dimensions_size);
    for (int i = 0; i < dimensions_size; ++i) {
      if (inserted_dims_seen <
              scatter_dim_numbers.inserted_window_dims_size() &&
          scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) {
        ++inserted_dims_seen;
      } else {
        max_update_slice_sizes.push_back(operand_shape.dimensions(i));
      }
    }
    // Updates rank = batch dims of the (expanded) indices + window dims.
    int64_t expected_updates_rank =
        expanded_scatter_indices_shape.size() - 1 +
        scatter_dim_numbers.update_window_dims_size();
    if (updates_shape.rank() != expected_updates_rank) {
      return InvalidArgument("Updates tensor must be of rank %d; got %d.",
                             expected_updates_rank, updates_shape.rank());
    }

    TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers(
        operand_shape, expanded_scatter_indices_shape, updates_shape,
        scatter_dim_numbers));

    // Window dims of updates must fit within the operand bounds computed
    // above.
    for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) {
      auto update_window_dim = scatter_dim_numbers.update_window_dims(i);
      if (updates_shape.dimensions(update_window_dim) >
          max_update_slice_sizes[i]) {
        return InvalidArgument(
            "Bounds of the window dimensions of updates must not exceed the "
            "bounds of the corresponding dimensions of operand. For dimension "
            "%d, updates bound is %d, operand bound is %d.",
            update_window_dim, updates_shape.dimensions(update_window_dim),
            max_update_slice_sizes[i]);
      }
    }

    // Remaining (scatter/batch) dims of updates must match the indices'
    // bounds, skipping the index-vector dim.
    int64_t scatter_dims_seen = 0;
    for (int64_t i = 0; i < updates_shape.rank(); ++i) {
      bool is_update_window_dim =
          absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i);
      if (is_update_window_dim) {
        continue;
      }
      if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) {
        ++scatter_dims_seen;
      }
      if (updates_shape.dimensions(i) !=
          expanded_scatter_indices_shape[scatter_dims_seen]) {
        return InvalidArgument(
            "Bounds of the scatter dimensions of updates must be same as the "
            "bounds of the corresponding dimensions of scatter indices. For "
            "scatter dimension %d, updates bound is %d, scatter_indices "
            "bound is %d.",
            i, updates_shape.dimensions(i),
            expanded_scatter_indices_shape[scatter_dims_seen]);
      }
      ++scatter_dims_seen;
    }
  }

  // Check if the update computation has a proper shape as a reduction.
  // It must accept one scalar per operand plus one scalar per update value.
  absl::InlinedVector<Shape, 1> init_element_shapes;
  absl::InlinedVector<const Shape*, 1> init_element_shape_ptrs;
  absl::InlinedVector<PrimitiveType, 1> updates_element_types;
  init_element_shapes.reserve(operand_count);
  init_element_shape_ptrs.reserve(operand_count);
  updates_element_types.reserve(operand_count);
  for (int64_t i = 0; i < operand_count; ++i) {
    init_element_shapes.push_back(
        ShapeUtil::MakeShape(operand_shapes[i]->element_type(), {}));
    init_element_shape_ptrs.push_back(&init_element_shapes.back());
    updates_element_types.push_back(updates_shapes[i]->element_type());
  }
  TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_element_shape_ptrs,
                                        updates_element_types, operand_count));

  // Scatter preserves the operand shape(s); variadic scatter returns a tuple.
  return operand_count == 1 ? *operand_shapes[0]
                            : ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes);
}
3785
3786 } // namespace xla
3787