1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Shape inference is used by the XLA service as the user builds up 17 // computation requests. 18 19 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_ 20 #define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_ 21 22 #include <vector> 23 24 #include "absl/types/span.h" 25 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 26 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 27 #include "tensorflow/compiler/xla/statusor.h" 28 #include "tensorflow/compiler/xla/types.h" 29 #include "tensorflow/compiler/xla/xla_data.pb.h" 30 #include "tensorflow/core/platform/macros.h" 31 #include "tensorflow/core/platform/types.h" 32 33 namespace xla { 34 35 // For a given operation and input shapes, infers what the resulting shape is 36 // for the operation. With this functionality, the user does not need to specify 37 // the expected result type for computations that are built up via the API -- 38 // the shape that results from an operation is inferred. Some methods have 39 // overloads for inferring shape at the HLO level. 40 // 41 // TODO(b/73352135): Shape inference does not issue very good error messages, in 42 // part because HloInstruction::ToString() is not available since shape 43 // inference runs before the HloInstruction object is created. We need a 44 // solution for this. 45 class ShapeInference { 46 public: 47 // Infers the shape produced by applying the given unary operation to the 48 // given input shape. 49 static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode, 50 const Shape& shape); 51 static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode, 52 const HloInstruction* operand); 53 54 // Infers the shape produced by applying the given binary operation to the 55 // given input shapes. 56 static StatusOr<Shape> InferBinaryOpShape( 57 HloOpcode opcode, const Shape& lhs, const Shape& rhs, 58 absl::Span<const int64> broadcast_dimensions); 59 static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode, 60 const HloInstruction* lhs, 61 const HloInstruction* rhs); 62 63 // Infers the shape produced by applying the given ternary operation to the 64 // given input shapes. 65 static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode, const Shape& lhs, 66 const Shape& rhs, 67 const Shape& ehs); 68 static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode, 69 const HloInstruction* lhs, 70 const HloInstruction* rhs, 71 const HloInstruction* ehs); 72 73 // Infers the shape produced by applying the given variadic operation to the 74 // given input operand shapes. 75 static StatusOr<Shape> InferVariadicOpShape( 76 HloOpcode opcode, absl::Span<const Shape* const> operand_shapes); 77 static StatusOr<Shape> InferVariadicOpShape( 78 HloOpcode opcode, absl::Span<const HloInstruction* const> operands); 79 80 // Infers the shape produced by applying the given mapping computation shape 81 // to the given operand shapes. 82 static StatusOr<Shape> InferMapShape( 83 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply, 84 absl::Span<const int64> dimensions); 85 86 // Infers the shape produced by InferBatchNormTraining with the given 87 // operands. 88 static StatusOr<Shape> InferBatchNormTrainingShape(const Shape& operand_shape, 89 const Shape& scale_shape, 90 const Shape& offset_shape, 91 int64_t feature_index); 92 93 // Infers the shape produced by InferBatchNormInference with the given 94 // operands. 95 static StatusOr<Shape> InferBatchNormInferenceShape( 96 const Shape& operand_shape, const Shape& scale_shape, 97 const Shape& offset_shape, const Shape& mean_shape, 98 const Shape& variance_shape, int64_t feature_index); 99 100 // Infers the shape produced by InferBatchNormGrad with the given operands. 101 static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape, 102 const Shape& scale_shape, 103 const Shape& mean_shape, 104 const Shape& var_shape, 105 const Shape& output_grad_shape, 106 int64_t feature_index); 107 108 // Infers the shape produced by applying the given convolutional filter (rhs) 109 // to lhs in the way specified by the fields on window. An optional 110 // preferred_element_type can be specified to upcast the element type. 111 static StatusOr<Shape> InferConvolveShape( 112 const Shape& lhs, const Shape& rhs, int64_t feature_group_count, 113 int64_t batch_group_count, const Window& window, 114 const ConvolutionDimensionNumbers& dimension_numbers, 115 absl::optional<PrimitiveType> preferred_element_type); 116 117 // Infers the shape produced by the given FFT type on the given operand. 118 static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type, 119 absl::Span<const int64> fft_length); 120 121 // Infers the shape produced by the given triangular solve operation. 122 static StatusOr<Shape> InferTriangularSolveShape( 123 const Shape& a, const Shape& b, const TriangularSolveOptions& options); 124 125 // Infers the shape produced by the given triangular solve operation. 126 static StatusOr<Shape> InferCholeskyShape(const Shape& a); 127 128 // Infers the shape produced by an all-gather with the given operand shape, 129 // concat dimension, and shard count. 130 static StatusOr<Shape> InferAllGatherShape( 131 absl::Span<const Shape* const> operand_shapes, 132 int64_t all_gather_dimension, int64_t shard_count); 133 134 // Infers the shape produced by an all-gather-start with the given operand 135 // shape, concat dimension, and shard count. 136 static StatusOr<Shape> InferAllGatherStartShape( 137 absl::Span<const Shape* const> operand_shapes, 138 int64_t all_gather_dimension, int64_t shard_count); 139 140 // Infers the shape produced by an all-gather-done given a certain 141 // all-gather-start shape. 142 static StatusOr<Shape> InferAllGatherDoneShape( 143 const Shape& all_gather_start_shape); 144 145 // Infers the shape produced by a cross replica sum with the given operand 146 // shapes. 147 static StatusOr<Shape> InferAllReduceShape( 148 absl::Span<const Shape* const> operand_shapes); 149 150 // Infers the shape produced by a reduce-scatter with the given operand 151 // shape, scatter dimension, and shard count. 152 static StatusOr<Shape> InferReduceScatterShape( 153 absl::Span<const Shape* const> operand_shapes, int64_t scatter_dimension, 154 int64_t shard_count); 155 156 // Infers the shape produced by a cross replica sum start. 157 static StatusOr<Shape> InferAllReduceStartShape( 158 absl::Span<const Shape* const> operand_shapes); 159 160 // Infers the shape produced by a cross replica sum done. 161 static StatusOr<Shape> InferAllReduceDoneShape(const Shape& operand_shape); 162 163 // Infers final shape of an Alltoall operation that is created by the xla 164 // builder. 165 static StatusOr<Shape> InferAllToAllShape(const Shape& shape, 166 int64_t split_dimension, 167 int64_t concat_dimension, 168 int64_t split_count); 169 170 // Infers the shape of an HLO all-to-all instruction. 171 static StatusOr<Shape> InferAllToAllTupleShape( 172 absl::Span<const Shape* const> operand_shapes); 173 174 // Infers the shape of a collective permute operation. 175 static StatusOr<Shape> InferCollectivePermuteShape( 176 absl::Span<const Shape* const> operand_shapes); 177 178 // Infers the shape of a collective permute start operation. 179 static StatusOr<Shape> InferCollectivePermuteStartShape( 180 absl::Span<const Shape* const> operand_shapes); 181 182 // Infers the shape of a collective permute operation. 183 static StatusOr<Shape> InferCollectivePermuteDoneShape( 184 const Shape& operand_shape); 185 186 // Infers the shape produced by applying the given reduction computation 187 // shape to the given input operand shape. 188 // 189 // If pass_index is true, the reduce function is invoked with the element 190 // index as the leading parameter, and the program shape should match 191 // accordingly (or an error will result). 192 static StatusOr<Shape> InferReduceShape( 193 absl::Span<const Shape* const> arg_shapes, 194 absl::Span<const int64> dimensions_to_reduce, 195 const ProgramShape& to_apply); 196 197 // Infers the shape produced by applying the given computation to the operand 198 // shape with the given window and stride dimensions. 199 static StatusOr<Shape> InferReduceWindowShape( 200 const Shape& operand_shape, const Shape& init_value, const Window& window, 201 const ProgramShape& to_apply_shape); 202 static StatusOr<Shape> InferReduceWindowShape(const Shape& operand_shape, 203 const Shape& init_value, 204 const Window& window); 205 static StatusOr<Shape> InferReduceWindowShape( 206 absl::Span<const Shape* const> operands, 207 absl::Span<const Shape* const> init_values, const Window& window, 208 const ProgramShape& to_apply_shape); 209 210 static StatusOr<Shape> InferReduceWindowShape( 211 absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values, 212 const Window& window); 213 214 // Infers the shape produced by scattering the given source shape to the 215 // selected indices of each window on the operand shape. 216 static StatusOr<Shape> InferSelectAndScatterShape( 217 const Shape& operand_shape, const ProgramShape& select_shape, 218 const Window& window, const Shape& source_shape, 219 const Shape& init_value_shape, const ProgramShape& scatter_shape); 220 221 // Infers the shape produced by a reverse operation that reverses the order 222 // of the elements in the given dimensions. 223 static StatusOr<Shape> InferReverseShape(const Shape& operand_shape, 224 absl::Span<const int64> dimensions); 225 226 // Infers the shape produced by a slice operation spanning from the starts to 227 // the limits in the original shape's dimensions. 228 // 229 // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16] 230 static StatusOr<Shape> InferSliceShape(const Shape& arg, 231 absl::Span<const int64> starts, 232 absl::Span<const int64> limits, 233 absl::Span<const int64> strides); 234 235 // Infers the shape produced by a dynamic slice operation of size specified 236 // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. 237 static StatusOr<Shape> InferDynamicSliceShape( 238 const Shape& operand_shape, absl::Span<const Shape> start_index_shapes, 239 absl::Span<const int64> slice_sizes, bool allow_scalar_indices = true); 240 241 // Infers the shape produced by a dynamic update slice operation based 242 // on the shape of operand and update. 243 static StatusOr<Shape> InferDynamicUpdateSliceShape( 244 const Shape& operand_shape, const Shape& update_shape, 245 absl::Span<const Shape> start_index_shapes, 246 bool allow_scalar_indices = true); 247 248 // Infers the shape produced by doing a compile-time-constant indexing into 249 // the given input shape. This is essential for operations on tuples, because 250 // it is impossible to infer the type that comes out of the tuple indexing if 251 // it is not a compile time constant. 252 static StatusOr<Shape> InferGetTupleElementShape(const Shape& arg, 253 int64_t index); 254 255 // Infers the shape produced from a while node. condition and body are the 256 // shapes of computations for the condition and the body of a while node, and 257 // init is the shape of data initially passed in to the body as an argument. 258 // The shapes must match; condition: T -> PRED, body: T -> T, init: T 259 static StatusOr<Shape> InferWhileShape(const ProgramShape& condition, 260 const ProgramShape& body, 261 const Shape& init); 262 263 // Infers the shape produced by a predicated or indexed conditional operation. 264 static StatusOr<Shape> InferConditionalShape( 265 const Shape& branch_index, 266 absl::Span<const ProgramShape> branch_computations, 267 absl::Span<const Shape> branch_operands); 268 269 // Infers the shape produced by a broadcast operation. 270 static StatusOr<Shape> InferBroadcastShape( 271 const Shape& operand, absl::Span<const int64> broadcast_sizes); 272 273 // Checks whether the given parameters can form a broadcast. Returns the same 274 // output_shape if it's legal. 275 static StatusOr<Shape> InferBroadcastShape( 276 const Shape& operand_shape, const Shape& output_shape, 277 absl::Span<const int64> broadcast_dimensions); 278 279 // Infers the shape produced by a reshape operation from the element type of 280 // its operand and the new dimension sizes specified. 281 static StatusOr<Shape> InferReshapeShape(const Shape& operand, 282 absl::Span<const int64> dimensions, 283 absl::Span<const int64> new_sizes, 284 int64_t inferred_dimension); 285 286 // Infers the shape produced by a dynamic reshape operation from the element 287 // type of its operand and the new dimension sizes specified. The result shape 288 // will have dynamic dimensions as specific in `dim_is_dynamic` and bound 289 // `new_size_bounds`. 290 static StatusOr<Shape> InferDynamicReshapeShape( 291 const Shape& operand, absl::Span<const Shape* const> dim_size_shapes, 292 absl::Span<const int64> new_size_bounds, 293 const std::vector<bool>& dims_are_dynamic); 294 295 // Infers the shape produced by a transpose operation from the element type of 296 // its operand and its dimensions field. 297 static StatusOr<Shape> InferTransposeShape( 298 const Shape& operand, absl::Span<const int64> dimensions); 299 300 // Helper that infers the shape produced by performing a concatenate operation 301 // with the given operand shapes. 302 static StatusOr<Shape> InferConcatOpShape( 303 absl::Span<const Shape* const> arg_shapes, int64_t dimension); 304 305 // Helper that validates the given operand shape can be converted to the 306 // target output_shape via a convert instruction -- the requirement is that 307 // the shape is identical except for the element type. 308 static StatusOr<Shape> InferConvertShape(const Shape& operand_shape, 309 PrimitiveType new_element_type); 310 311 // Helper that validates the given operand shape can be bitcast converted to 312 // the target output_shape via a bitcast convert instruction -- the 313 // requirement is that the shape is identical except for the element type and 314 // the element types have identical bit-widths. 315 static StatusOr<Shape> InferBitcastConvertShape( 316 const Shape& operand_shape, PrimitiveType new_element_type); 317 318 // Helper that validates the input data type for a reduce-precision operation, 319 // and returns the result shape. 320 static StatusOr<Shape> InferReducePrecisionShape(const Shape& operand_shape, 321 const int exponent_bits, 322 const int mantissa_bits); 323 324 // Helper that infers the shape produced by a pad operation based on the 325 // padding configuration. 326 static StatusOr<Shape> InferPadShape(const Shape& operand_shape, 327 const Shape& padding_value_shape, 328 const PaddingConfig& padding_config); 329 330 // Helper that validates the given arg_shapes are compatible with the shape of 331 // the to_apply parameters, and returns the to_apply result shape. 332 static StatusOr<Shape> InferCallShape( 333 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply); 334 335 // Helper that infers the shape produced by performing a dot operation with 336 // the given LHS and RHS shapes. An optional preferred_element_type can be 337 // specified to upcast the element type. 338 static StatusOr<Shape> InferDotOpShape( 339 const Shape& lhs, const Shape& rhs, 340 const DotDimensionNumbers& dimension_numbers, 341 absl::optional<PrimitiveType> preferred_element_type); 342 343 // Helper that infers the shape of the tensor produced by a gather operation 344 // with the given input shape, gather indices shape and gather dimension 345 // numbers. 346 static StatusOr<Shape> InferGatherShape( 347 const Shape& input_shape, const Shape& start_indices_shape, 348 const GatherDimensionNumbers& gather_dim_numbers, 349 absl::Span<const int64> slice_sizes); 350 351 // Helper that validates the given input shape, scatter indices shape, updates 352 // shape, and scatter dimension numbers that constitute a scatter operation, 353 // and returns the result shape of the scatter operation. 354 static StatusOr<Shape> InferScatterShape( 355 const Shape& operand_shape, const Shape& scatter_indices_shape, 356 const Shape& updates_shape, const ProgramShape& to_apply_shape, 357 const ScatterDimensionNumbers& scatter_dim_numbers); 358 359 // Helper that validates the given input shape to GetDimensionSize. 360 static StatusOr<Shape> InferGetDimensionSizeShape(const Shape& shape, 361 int64_t dimension); 362 363 // Helper that validates the given input shape to SetDimensionSize. 364 static StatusOr<Shape> InferSetDimensionSizeShape(const Shape& operand_shape, 365 const Shape& val_shape, 366 int64_t dimension); 367 368 // Helper function for creating a Window proto from user-supplied data. 369 // Returns error if the user-supplied data was invalid. 370 static StatusOr<Window> InferWindowFromDimensions( 371 absl::Span<const int64> window_dimensions, 372 absl::Span<const int64> window_strides, 373 absl::Span<const std::pair<int64, int64>> padding, 374 absl::Span<const int64> lhs_dilation, 375 absl::Span<const int64> rhs_dilation); 376 377 private: 378 // Helper that infers the shape produced by performing an element-wise binary 379 // operation with the given LHS and RHS shapes. 380 // Note: By "element-wise" we mean operations that look at a single element in 381 // the LHS and a single element in the RHS to produce a single output element, 382 // even in the presence of broadcasting of one of the operands over the other. 383 static StatusOr<Shape> InferElementwiseBinaryOpShape( 384 HloOpcode operation, const Shape& lhs, const Shape& rhs, 385 absl::Span<const int64> broadcast_dimensions); 386 387 // Helper for inferring the shape of Clamp ops. 388 static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand, 389 const Shape& max); 390 391 // Helper for inferring the shape of Select ops. 392 static StatusOr<Shape> InferSelectShape(const Shape& pred, 393 const Shape& on_true, 394 const Shape& on_false); 395 // Helper for inferring the shape of TupleSelect ops. 396 static StatusOr<Shape> InferTupleSelectShape(const Shape& pred, 397 const Shape& on_true, 398 const Shape& on_false); 399 400 // Helper for inferring shapes of binary operations which use degenerate 401 // dimension broadcasting (a dimension of size 1 in one operand is broadcast 402 // up to match the size of the dimension in the other operand). 403 static StatusOr<Shape> InferDegenerateDimensionBroadcastShape( 404 HloOpcode operation, const Shape& lhs, const Shape& rhs); 405 406 // Helper for inferring shapes of binary operations using "InDim" 407 // broadcasting. This is the broadcasting used in the *InDim binary operations 408 // (for example ComputationBuilder::AddInDim). smaller_shape must be a 409 // lower-rank shape than larger_shape. Returns the shape that the 410 // smaller_shape is broadcast to. 411 static StatusOr<Shape> InferInDimBroadcastShape( 412 const Shape& smaller_shape, const Shape& larger_shape, 413 absl::Span<const int64> broadcast_dimensions); 414 415 TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference); 416 }; 417 418 } // namespace xla 419 420 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_ 421