1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Shape inference is used by the XLA service as the user builds up 17 // computation requests. 18 19 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_ 20 #define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_ 21 22 #include <vector> 23 24 #include "absl/types/span.h" 25 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 26 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 27 #include "tensorflow/compiler/xla/statusor.h" 28 #include "tensorflow/compiler/xla/types.h" 29 #include "tensorflow/compiler/xla/xla_data.pb.h" 30 #include "tensorflow/core/platform/macros.h" 31 #include "tensorflow/core/platform/types.h" 32 33 namespace xla { 34 35 // For a given operation and input shapes, infers what the resulting shape is 36 // for the operation. With this functionality, the user does not need to specify 37 // the expected result type for computations that are built up via the API -- 38 // the shape that results from an operation is inferred. Some methods have 39 // overloads for inferring shape at the HLO level. 40 // 41 // TODO(b/73352135): Shape inference does not issue very good error messages, in 42 // part because HloInstruction::ToString() is not available since shape 43 // inference runs before the HloInstruction object is created. We need a 44 // solution for this. 45 class ShapeInference { 46 public: 47 // Infers the shape produced by applying the given unary operation to the 48 // given input shape. 49 static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode, 50 const Shape& shape); 51 static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode, 52 const HloInstruction* operand); 53 54 // Infers the shape produced by applying the given binary operation to the 55 // given input shapes. 56 static StatusOr<Shape> InferBinaryOpShape( 57 HloOpcode opcode, const Shape& lhs, const Shape& rhs, 58 absl::Span<const int64> broadcast_dimensions); 59 static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode, 60 const HloInstruction* lhs, 61 const HloInstruction* rhs); 62 63 // Infers the shape produced by applying the given ternary operation to the 64 // given input shapes. 65 static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode, const Shape& lhs, 66 const Shape& rhs, 67 const Shape& ehs); 68 static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode, 69 const HloInstruction* lhs, 70 const HloInstruction* rhs, 71 const HloInstruction* ehs); 72 73 // Infers the shape produced by applying the given variadic operation to the 74 // given input operand shapes. 75 static StatusOr<Shape> InferVariadicOpShape( 76 HloOpcode opcode, absl::Span<const Shape* const> operand_shapes); 77 static StatusOr<Shape> InferVariadicOpShape( 78 HloOpcode opcode, absl::Span<const HloInstruction* const> operands); 79 80 // Infers the shape produced by applying the given mapping computation shape 81 // to the given operand shapes. 82 static StatusOr<Shape> InferMapShape( 83 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply, 84 absl::Span<const int64> dimensions); 85 86 // Infers the shape produced by InferBatchNormTraining with the given 87 // operands. 88 static StatusOr<Shape> InferBatchNormTrainingShape(const Shape& operand_shape, 89 const Shape& scale_shape, 90 const Shape& offset_shape, 91 int64 feature_index); 92 93 // Infers the shape produced by InferBatchNormInference with the given 94 // operands. 95 static StatusOr<Shape> InferBatchNormInferenceShape( 96 const Shape& operand_shape, const Shape& scale_shape, 97 const Shape& offset_shape, const Shape& mean_shape, 98 const Shape& variance_shape, int64 feature_index); 99 100 // Infers the shape produced by InferBatchNormGrad with the given operands. 101 static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape, 102 const Shape& scale_shape, 103 const Shape& mean_shape, 104 const Shape& var_shape, 105 const Shape& output_grad_shape, 106 int64 feature_index); 107 108 // Infers the shape produced by applying the given convolutional 109 // filter (rhs) to lhs in the way specified by the fields on window. 110 static StatusOr<Shape> InferConvolveShape( 111 const Shape& lhs, const Shape& rhs, int64 feature_group_count, 112 int64 batch_group_count, const Window& window, 113 const ConvolutionDimensionNumbers& dimension_numbers); 114 115 // Infers the shape produced by the given FFT type on the given operand. 116 static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type, 117 absl::Span<const int64> fft_length); 118 119 // Infers the shape produced by the given triangular solve operation. 120 static StatusOr<Shape> InferTriangularSolveShape( 121 const Shape& a, const Shape& b, const TriangularSolveOptions& options); 122 123 // Infers the shape produced by the given triangular solve operation. 124 static StatusOr<Shape> InferCholeskyShape(const Shape& a); 125 126 // Infers the shape produced by a cross replica sum with the given operand 127 // shapes. 128 static StatusOr<Shape> InferAllReduceShape( 129 absl::Span<const Shape* const> operand_shapes); 130 131 // Infers final shape of an Alltoall operation that is created by the xla 132 // builder. 133 static StatusOr<Shape> InferAllToAllShape(const Shape& shape, 134 int64 split_dimension, 135 int64 concat_dimension, 136 int64 split_count); 137 138 // Infers the shape of an HLO all-to-all instruction. 139 static StatusOr<Shape> InferAllToAllTupleShape( 140 absl::Span<const Shape* const> operand_shapes); 141 142 // Infers the shape of a collective permute operation. 143 static StatusOr<Shape> InferCollectivePermuteShape(const Shape& shape); 144 145 // Infers the shape produced by applying the given reduction computation 146 // shape to the given input operand shape. 147 // 148 // If pass_index is true, the reduce function is invoked with the element 149 // index as the leading parameter, and the program shape should match 150 // accordingly (or an error will result). 151 static StatusOr<Shape> InferReduceShape( 152 absl::Span<const Shape* const> arg_shapes, 153 absl::Span<const int64> dimensions_to_reduce, 154 const ProgramShape& to_apply); 155 156 // Infers the shape produced by applying the given computation to the operand 157 // shape with the given window and stride dimensions. 158 static StatusOr<Shape> InferReduceWindowShape( 159 const Shape& operand_shape, const Shape& init_value, const Window& window, 160 const ProgramShape& to_apply_shape); 161 162 // Infers the shape produced by scattering the given source shape to the 163 // selected indices of each window on the operand shape. 164 static StatusOr<Shape> InferSelectAndScatterShape( 165 const Shape& operand_shape, const ProgramShape& select_shape, 166 const Window& window, const Shape& source_shape, 167 const Shape& init_value_shape, const ProgramShape& scatter_shape); 168 169 // Infers the shape produced by a reverse operation that reverses the order 170 // of the elements in the given dimensions. 171 static StatusOr<Shape> InferReverseShape(const Shape& operand_shape, 172 absl::Span<const int64> dimensions); 173 174 // Infers the shape produced by a slice operation spanning from the starts to 175 // the limits in the original shape's dimensions. 176 // 177 // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16] 178 static StatusOr<Shape> InferSliceShape(const Shape& arg, 179 absl::Span<const int64> starts, 180 absl::Span<const int64> limits, 181 absl::Span<const int64> strides); 182 183 // Infers the shape produced by a dynamic slice operation of size specified 184 // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. 185 static StatusOr<Shape> InferDynamicSliceShape( 186 const Shape& operand_shape, absl::Span<const Shape> start_index_shapes, 187 absl::Span<const int64> slice_sizes, bool allow_scalar_indices = true); 188 189 // Infers the shape produced by a dynamic update slice operation based 190 // on the shape of operand and update. 191 static StatusOr<Shape> InferDynamicUpdateSliceShape( 192 const Shape& operand_shape, const Shape& update_shape, 193 absl::Span<const Shape> start_index_shapes, 194 bool allow_scalar_indices = true); 195 196 // Infers the shape produced by doing a compile-time-constant indexing into 197 // the given input shape. This is essential for operations on tuples, because 198 // it is impossible to infer the type that comes out of the tuple indexing if 199 // it is not a compile time constant. 200 static StatusOr<Shape> InferGetTupleElementShape(const Shape& arg, 201 int64 index); 202 203 // Infers the shape produced from a while node. condition and body are the 204 // shapes of computations for the condition and the body of a while node, and 205 // init is the shape of data initially passed in to the body as an argument. 206 // The shapes must match; condition: T -> PRED, body: T -> T, init: T 207 static StatusOr<Shape> InferWhileShape(const ProgramShape& condition, 208 const ProgramShape& body, 209 const Shape& init); 210 211 // Infers the shape produced by a predicated or indexed conditional operation. 212 static StatusOr<Shape> InferConditionalShape( 213 const Shape& branch_index, 214 absl::Span<const ProgramShape> branch_computations, 215 absl::Span<const Shape> branch_operands); 216 217 // Infers the shape produced by a broadcast operation. 218 static StatusOr<Shape> InferBroadcastShape( 219 const Shape& operand, absl::Span<const int64> broadcast_sizes); 220 221 // Checks whether the given parameters can form a broadcast. Returns the same 222 // output_shape if it's legal. 223 static StatusOr<Shape> InferBroadcastShape( 224 const Shape& operand_shape, const Shape& output_shape, 225 absl::Span<const int64> broadcast_dimensions); 226 227 // Infers the shape produced by a reshape operation from the element type of 228 // its operand and the new dimension sizes specified. 229 static StatusOr<Shape> InferReshapeShape(const Shape& operand, 230 absl::Span<const int64> dimensions, 231 absl::Span<const int64> new_sizes); 232 233 // Infers the shape produced by a transpose operation from the element type of 234 // its operand and its dimensions field. 235 static StatusOr<Shape> InferTransposeShape( 236 const Shape& operand, absl::Span<const int64> dimensions); 237 238 // Helper that infers the shape produced by performing a concatenate operation 239 // with the given operand shapes. 240 static StatusOr<Shape> InferConcatOpShape( 241 absl::Span<const Shape* const> arg_shapes, int64 dimension); 242 243 // Helper that validates the given operand shape can be converted to the 244 // target output_shape via a convert instruction -- the requirement is that 245 // the shape is identical except for the element type. 246 static StatusOr<Shape> InferConvertShape(const Shape& operand_shape, 247 PrimitiveType new_element_type); 248 249 // Helper that validates the given operand shape can be bitcast converted to 250 // the target output_shape via a bitcast convert instruction -- the 251 // requirement is that the shape is identical except for the element type and 252 // the element types have identical bit-widths. 253 static StatusOr<Shape> InferBitcastConvertShape( 254 const Shape& operand_shape, PrimitiveType new_element_type); 255 256 // Helper that validates the input data type for a reduce-precision operation, 257 // and returns the result shape. 258 static StatusOr<Shape> InferReducePrecisionShape(const Shape& operand_shape, 259 const int exponent_bits, 260 const int mantissa_bits); 261 262 // Helper that infers the shape produced by a pad operation based on the 263 // padding configuration. 264 static StatusOr<Shape> InferPadShape(const Shape& operand_shape, 265 const Shape& padding_value_shape, 266 const PaddingConfig& padding_config); 267 268 // Helper that validates the given arg_shapes are compatible with the shape of 269 // the to_apply parameters, and returns the to_apply result shape. 270 static StatusOr<Shape> InferCallShape( 271 absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply); 272 273 // Helper that infers the shape produced by performing a dot operation with 274 // the given LHS and RHS shapes. 275 static StatusOr<Shape> InferDotOpShape( 276 const Shape& lhs, const Shape& rhs, 277 const DotDimensionNumbers& dimension_numbers); 278 279 // Helper that infers the shape of the tensor produced by a gather operation 280 // with the given input shape, gather indices shape and gather dimension 281 // numbers. 282 static StatusOr<Shape> InferGatherShape( 283 const Shape& input_shape, const Shape& start_indices_shape, 284 const GatherDimensionNumbers& gather_dim_numbers, 285 absl::Span<const int64> slice_sizes); 286 287 // Helper that validates the given input shape, scatter indices shape, updates 288 // shape, and scatter dimension numbers that constitute a scatter operation, 289 // and returns the result shape of the scatter operation. 290 static StatusOr<Shape> InferScatterShape( 291 const Shape& operand_shape, const Shape& scatter_indices_shape, 292 const Shape& updates_shape, const ProgramShape& to_apply_shape, 293 const ScatterDimensionNumbers& scatter_dim_numbers); 294 295 static StatusOr<Shape> InferGetDimensionSizeShape(const Shape& shape, 296 int64 dimension); 297 298 private: 299 // Helper that infers the shape produced by performing an element-wise binary 300 // operation with the given LHS and RHS shapes. 301 // Note: By "element-wise" we mean operations that look at a single element in 302 // the LHS and a single element in the RHS to produce a single output element, 303 // even in the presence of broadcasting of one of the operands over the other. 304 static StatusOr<Shape> InferElementwiseBinaryOpShape( 305 HloOpcode operation, const Shape& lhs, const Shape& rhs, 306 absl::Span<const int64> broadcast_dimensions); 307 308 // Helper for inferring the shape of Clamp ops. 309 static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand, 310 const Shape& max); 311 312 // Helper for inferring the shape of Select ops. 313 static StatusOr<Shape> InferSelectShape(const Shape& pred, 314 const Shape& on_true, 315 const Shape& on_false); 316 // Helper for inferring the shape of TupleSelect ops. 317 static StatusOr<Shape> InferTupleSelectShape(const Shape& pred, 318 const Shape& on_true, 319 const Shape& on_false); 320 321 // Helper for inferring shapes of binary operations which use degenerate 322 // dimension broadcasting (a dimension of size 1 in one operand is broadcast 323 // up to match the size of the dimension in the other operand). 324 static StatusOr<Shape> InferDegenerateDimensionBroadcastShape( 325 HloOpcode operation, const Shape& lhs, const Shape& rhs); 326 327 // Helper for inferring shapes of binary operations using "InDim" 328 // broadcasting. This is the broadcasting used in the *InDim binary operations 329 // (for example ComputationBuilder::AddInDim). smaller_shape must be a 330 // lower-rank shape than larger_shape. Returns the shape that the 331 // smaller_shape is broadcast to. 332 static StatusOr<Shape> InferInDimBroadcastShape( 333 const Shape& smaller_shape, const Shape& larger_shape, 334 absl::Span<const int64> broadcast_dimensions); 335 336 TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference); 337 }; 338 339 } // namespace xla 340 341 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_ 342