/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Shape inference is used by the XLA service as the user builds up
// computation requests.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_

#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// For a given operation and input shapes, infers what the resulting shape is
// for the operation. With this functionality, the user does not need to specify
// the expected result type for computations that are built up via the API --
// the shape that results from an operation is inferred. Some methods have
// overloads for inferring shape at the HLO level.
//
// TODO(b/73352135): Shape inference does not issue very good error messages, in
// part because HloInstruction::ToString() is not available since shape
// inference runs before the HloInstruction object is created.
We need a 44 // solution for this. 45 class ShapeInference { 46 public: 47 // Infers the shape produced by applying the given unary operation to the 48 // given input shape. 49 static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode, 50 const Shape& shape); 51 static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode, 52 const HloInstruction* operand); 53 54 // Infers the shape produced by applying the given binary operation to the 55 // given input shapes. 56 static StatusOr<Shape> InferBinaryOpShape( 57 HloOpcode opcode, const Shape& lhs, const Shape& rhs, 58 absl::Span<const int64> broadcast_dimensions); 59 static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode, 60 const HloInstruction* lhs, 61 const HloInstruction* rhs); 62 63 // Infers the shape produced by applying the given ternary operation to the 64 // given input shapes. 65 static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode, const Shape& lhs, 66 const Shape& rhs, 67 const Shape& ehs); 68 static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode, 69 const HloInstruction* lhs, 70 const HloInstruction* rhs, 71 const HloInstruction* ehs); 72 73 // Infers the shape produced by applying the given variadic operation to the 74 // given input operand shapes. 75 static StatusOr<Shape> InferVariadicOpShape( 76 HloOpcode opcode, absl::Span<const Shape* const> operand_shapes); 77 static StatusOr<Shape> InferVariadicOpShape( 78 HloOpcode opcode, absl::Span<const HloInstruction* const> operands); 79 80 // Infers the shape produced by applying the given mapping computation shape 81 // to the given operand shapes. 82 static StatusOr<Shape> InferMapShape( 83 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply, 84 absl::Span<const int64> dimensions); 85 86 // Infers the shape produced by InferBatchNormTraining with the given 87 // operands. 
88 static StatusOr<Shape> InferBatchNormTrainingShape(const Shape& operand_shape, 89 const Shape& scale_shape, 90 const Shape& offset_shape, 91 int64 feature_index); 92 93 // Infers the shape produced by InferBatchNormInference with the given 94 // operands. 95 static StatusOr<Shape> InferBatchNormInferenceShape( 96 const Shape& operand_shape, const Shape& scale_shape, 97 const Shape& offset_shape, const Shape& mean_shape, 98 const Shape& variance_shape, int64 feature_index); 99 100 // Infers the shape produced by InferBatchNormGrad with the given operands. 101 static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape, 102 const Shape& scale_shape, 103 const Shape& mean_shape, 104 const Shape& var_shape, 105 const Shape& output_grad_shape, 106 int64 feature_index); 107 108 // Infers the shape produced by applying the given convolutional filter (rhs) 109 // to lhs in the way specified by the fields on window. An optional 110 // preferred_element_type can be specified to upcast the element type. 111 static StatusOr<Shape> InferConvolveShape( 112 const Shape& lhs, const Shape& rhs, int64 feature_group_count, 113 int64 batch_group_count, const Window& window, 114 const ConvolutionDimensionNumbers& dimension_numbers, 115 absl::optional<PrimitiveType> preferred_element_type); 116 117 // Infers the shape produced by the given FFT type on the given operand. 118 static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type, 119 absl::Span<const int64> fft_length); 120 121 // Infers the shape produced by the given triangular solve operation. 122 static StatusOr<Shape> InferTriangularSolveShape( 123 const Shape& a, const Shape& b, const TriangularSolveOptions& options); 124 125 // Infers the shape produced by the given triangular solve operation. 126 static StatusOr<Shape> InferCholeskyShape(const Shape& a); 127 128 // Infers the shape produced by an all-gather with the given operand shape, 129 // concat dimension, and shard count. 
130 static StatusOr<Shape> InferAllGatherShape(const Shape& operand_shape, 131 int64 all_gather_dimension, 132 int64 shard_count); 133 134 // Infers the shape produced by a cross replica sum with the given operand 135 // shapes. 136 static StatusOr<Shape> InferAllReduceShape( 137 absl::Span<const Shape* const> operand_shapes); 138 139 // Infers final shape of an Alltoall operation that is created by the xla 140 // builder. 141 static StatusOr<Shape> InferAllToAllShape(const Shape& shape, 142 int64 split_dimension, 143 int64 concat_dimension, 144 int64 split_count); 145 146 // Infers the shape of an HLO all-to-all instruction. 147 static StatusOr<Shape> InferAllToAllTupleShape( 148 absl::Span<const Shape* const> operand_shapes); 149 150 // Infers the shape of a collective permute operation. 151 static StatusOr<Shape> InferCollectivePermuteShape(const Shape& shape); 152 153 // Infers the shape produced by applying the given reduction computation 154 // shape to the given input operand shape. 155 // 156 // If pass_index is true, the reduce function is invoked with the element 157 // index as the leading parameter, and the program shape should match 158 // accordingly (or an error will result). 159 static StatusOr<Shape> InferReduceShape( 160 absl::Span<const Shape* const> arg_shapes, 161 absl::Span<const int64> dimensions_to_reduce, 162 const ProgramShape& to_apply); 163 164 // Infers the shape produced by applying the given computation to the operand 165 // shape with the given window and stride dimensions. 
166 static StatusOr<Shape> InferReduceWindowShape( 167 const Shape& operand_shape, const Shape& init_value, const Window& window, 168 const ProgramShape& to_apply_shape); 169 static StatusOr<Shape> InferReduceWindowShape(const Shape& operand_shape, 170 const Shape& init_value, 171 const Window& window); 172 static StatusOr<Shape> InferReduceWindowShape( 173 absl::Span<const Shape* const> operands, 174 absl::Span<const Shape* const> init_values, const Window& window, 175 const ProgramShape& to_apply_shape); 176 177 static StatusOr<Shape> InferReduceWindowShape( 178 absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values, 179 const Window& window); 180 181 // Infers the shape produced by scattering the given source shape to the 182 // selected indices of each window on the operand shape. 183 static StatusOr<Shape> InferSelectAndScatterShape( 184 const Shape& operand_shape, const ProgramShape& select_shape, 185 const Window& window, const Shape& source_shape, 186 const Shape& init_value_shape, const ProgramShape& scatter_shape); 187 188 // Infers the shape produced by a reverse operation that reverses the order 189 // of the elements in the given dimensions. 190 static StatusOr<Shape> InferReverseShape(const Shape& operand_shape, 191 absl::Span<const int64> dimensions); 192 193 // Infers the shape produced by a slice operation spanning from the starts to 194 // the limits in the original shape's dimensions. 195 // 196 // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16] 197 static StatusOr<Shape> InferSliceShape(const Shape& arg, 198 absl::Span<const int64> starts, 199 absl::Span<const int64> limits, 200 absl::Span<const int64> strides); 201 202 // Infers the shape produced by a dynamic slice operation of size specified 203 // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. 
204 static StatusOr<Shape> InferDynamicSliceShape( 205 const Shape& operand_shape, absl::Span<const Shape> start_index_shapes, 206 absl::Span<const int64> slice_sizes, bool allow_scalar_indices = true); 207 208 // Infers the shape produced by a dynamic update slice operation based 209 // on the shape of operand and update. 210 static StatusOr<Shape> InferDynamicUpdateSliceShape( 211 const Shape& operand_shape, const Shape& update_shape, 212 absl::Span<const Shape> start_index_shapes, 213 bool allow_scalar_indices = true); 214 215 // Infers the shape produced by doing a compile-time-constant indexing into 216 // the given input shape. This is essential for operations on tuples, because 217 // it is impossible to infer the type that comes out of the tuple indexing if 218 // it is not a compile time constant. 219 static StatusOr<Shape> InferGetTupleElementShape(const Shape& arg, 220 int64 index); 221 222 // Infers the shape produced from a while node. condition and body are the 223 // shapes of computations for the condition and the body of a while node, and 224 // init is the shape of data initially passed in to the body as an argument. 225 // The shapes must match; condition: T -> PRED, body: T -> T, init: T 226 static StatusOr<Shape> InferWhileShape(const ProgramShape& condition, 227 const ProgramShape& body, 228 const Shape& init); 229 230 // Infers the shape produced by a predicated or indexed conditional operation. 231 static StatusOr<Shape> InferConditionalShape( 232 const Shape& branch_index, 233 absl::Span<const ProgramShape> branch_computations, 234 absl::Span<const Shape> branch_operands); 235 236 // Infers the shape produced by a broadcast operation. 237 static StatusOr<Shape> InferBroadcastShape( 238 const Shape& operand, absl::Span<const int64> broadcast_sizes); 239 240 // Checks whether the given parameters can form a broadcast. Returns the same 241 // output_shape if it's legal. 
242 static StatusOr<Shape> InferBroadcastShape( 243 const Shape& operand_shape, const Shape& output_shape, 244 absl::Span<const int64> broadcast_dimensions); 245 246 // Infers the shape produced by a reshape operation from the element type of 247 // its operand and the new dimension sizes specified. 248 static StatusOr<Shape> InferReshapeShape(const Shape& operand, 249 absl::Span<const int64> dimensions, 250 absl::Span<const int64> new_sizes, 251 int64 inferred_dimension); 252 253 // Infers the shape produced by a dynamic reshape operation from the element 254 // type of its operand and the new dimension sizes specified. The result shape 255 // will have dynamic dimensions as specific in `dim_is_dynamic` and bound 256 // `new_size_bounds`. 257 static StatusOr<Shape> InferDynamicReshapeShape( 258 const Shape& operand, absl::Span<const Shape* const> dim_size_shapes, 259 absl::Span<const int64> new_size_bounds, 260 const std::vector<bool>& dims_are_dynamic); 261 262 // Infers the shape produced by a transpose operation from the element type of 263 // its operand and its dimensions field. 264 static StatusOr<Shape> InferTransposeShape( 265 const Shape& operand, absl::Span<const int64> dimensions); 266 267 // Helper that infers the shape produced by performing a concatenate operation 268 // with the given operand shapes. 269 static StatusOr<Shape> InferConcatOpShape( 270 absl::Span<const Shape* const> arg_shapes, int64 dimension); 271 272 // Helper that validates the given operand shape can be converted to the 273 // target output_shape via a convert instruction -- the requirement is that 274 // the shape is identical except for the element type. 
275 static StatusOr<Shape> InferConvertShape(const Shape& operand_shape, 276 PrimitiveType new_element_type); 277 278 // Helper that validates the given operand shape can be bitcast converted to 279 // the target output_shape via a bitcast convert instruction -- the 280 // requirement is that the shape is identical except for the element type and 281 // the element types have identical bit-widths. 282 static StatusOr<Shape> InferBitcastConvertShape( 283 const Shape& operand_shape, PrimitiveType new_element_type); 284 285 // Helper that validates the input data type for a reduce-precision operation, 286 // and returns the result shape. 287 static StatusOr<Shape> InferReducePrecisionShape(const Shape& operand_shape, 288 const int exponent_bits, 289 const int mantissa_bits); 290 291 // Helper that infers the shape produced by a pad operation based on the 292 // padding configuration. 293 static StatusOr<Shape> InferPadShape(const Shape& operand_shape, 294 const Shape& padding_value_shape, 295 const PaddingConfig& padding_config); 296 297 // Helper that validates the given arg_shapes are compatible with the shape of 298 // the to_apply parameters, and returns the to_apply result shape. 299 static StatusOr<Shape> InferCallShape( 300 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply); 301 302 // Helper that infers the shape produced by performing a dot operation with 303 // the given LHS and RHS shapes. An optional preferred_element_type can be 304 // specified to upcast the element type. 305 static StatusOr<Shape> InferDotOpShape( 306 const Shape& lhs, const Shape& rhs, 307 const DotDimensionNumbers& dimension_numbers, 308 absl::optional<PrimitiveType> preferred_element_type); 309 310 // Helper that infers the shape of the tensor produced by a gather operation 311 // with the given input shape, gather indices shape and gather dimension 312 // numbers. 
313 static StatusOr<Shape> InferGatherShape( 314 const Shape& input_shape, const Shape& start_indices_shape, 315 const GatherDimensionNumbers& gather_dim_numbers, 316 absl::Span<const int64> slice_sizes); 317 318 // Helper that validates the given input shape, scatter indices shape, updates 319 // shape, and scatter dimension numbers that constitute a scatter operation, 320 // and returns the result shape of the scatter operation. 321 static StatusOr<Shape> InferScatterShape( 322 const Shape& operand_shape, const Shape& scatter_indices_shape, 323 const Shape& updates_shape, const ProgramShape& to_apply_shape, 324 const ScatterDimensionNumbers& scatter_dim_numbers); 325 326 // Helper that validates the given input shape to GetDimensionSize. 327 static StatusOr<Shape> InferGetDimensionSizeShape(const Shape& shape, 328 int64 dimension); 329 330 // Helper that validates the given input shape to SetDimensionSize. 331 static StatusOr<Shape> InferSetDimensionSizeShape(const Shape& operand_shape, 332 const Shape& val_shape, 333 int64 dimension); 334 335 // Helper function for creating a Window proto from user-supplied data. 336 // Returns error if the user-supplied data was invalid. 337 static StatusOr<Window> InferWindowFromDimensions( 338 absl::Span<const int64> window_dimensions, 339 absl::Span<const int64> window_strides, 340 absl::Span<const std::pair<int64, int64>> padding, 341 absl::Span<const int64> lhs_dilation, 342 absl::Span<const int64> rhs_dilation); 343 344 private: 345 // Helper that infers the shape produced by performing an element-wise binary 346 // operation with the given LHS and RHS shapes. 347 // Note: By "element-wise" we mean operations that look at a single element in 348 // the LHS and a single element in the RHS to produce a single output element, 349 // even in the presence of broadcasting of one of the operands over the other. 
350 static StatusOr<Shape> InferElementwiseBinaryOpShape( 351 HloOpcode operation, const Shape& lhs, const Shape& rhs, 352 absl::Span<const int64> broadcast_dimensions); 353 354 // Helper for inferring the shape of Clamp ops. 355 static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand, 356 const Shape& max); 357 358 // Helper for inferring the shape of Select ops. 359 static StatusOr<Shape> InferSelectShape(const Shape& pred, 360 const Shape& on_true, 361 const Shape& on_false); 362 // Helper for inferring the shape of TupleSelect ops. 363 static StatusOr<Shape> InferTupleSelectShape(const Shape& pred, 364 const Shape& on_true, 365 const Shape& on_false); 366 367 // Helper for inferring shapes of binary operations which use degenerate 368 // dimension broadcasting (a dimension of size 1 in one operand is broadcast 369 // up to match the size of the dimension in the other operand). 370 static StatusOr<Shape> InferDegenerateDimensionBroadcastShape( 371 HloOpcode operation, const Shape& lhs, const Shape& rhs); 372 373 // Helper for inferring shapes of binary operations using "InDim" 374 // broadcasting. This is the broadcasting used in the *InDim binary operations 375 // (for example ComputationBuilder::AddInDim). smaller_shape must be a 376 // lower-rank shape than larger_shape. Returns the shape that the 377 // smaller_shape is broadcast to. 378 static StatusOr<Shape> InferInDimBroadcastShape( 379 const Shape& smaller_shape, const Shape& larger_shape, 380 absl::Span<const int64> broadcast_dimensions); 381 382 TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference); 383 }; 384 385 } // namespace xla 386 387 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_ 388