1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Shape inference is used by the XLA service as the user builds up 17 // computation requests. 18 19 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_ 20 #define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_ 21 22 #include <vector> 23 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 26 #include "tensorflow/compiler/xla/statusor.h" 27 #include "tensorflow/compiler/xla/types.h" 28 #include "tensorflow/compiler/xla/xla_data.pb.h" 29 #include "tensorflow/core/lib/gtl/array_slice.h" 30 #include "tensorflow/core/platform/macros.h" 31 #include "tensorflow/core/platform/types.h" 32 33 namespace xla { 34 35 // For a given operation and input shapes, infers what the resulting shape is 36 // for the operation. With this functionality, the user does not need to specify 37 // the expected result type for computations that are built up via the API -- 38 // the shape that results from an operation is inferred. Some methods have 39 // overloads for inferring shape at the HLO level. 40 // 41 // TODO(b/73352135): Shape inference does not issue very good error messages, in 42 // part because HloInstruction::ToString() is not available since shape 43 // inference runs before the HloInstruction object is created. We need a 44 // solution for this. 45 class ShapeInference { 46 public: 47 // Infers the shape produced by applying the given unary operation to the 48 // given input shape. 49 static StatusOr<Shape> InferUnaryOpShape(UnaryOperation operation, 50 const Shape& arg); 51 static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode, 52 const HloInstruction* operand); 53 54 // Infers the shape produced by applying the given binary operation to the 55 // given input shapes. 56 static StatusOr<Shape> InferBinaryOpShape( 57 BinaryOperation operation, const Shape& lhs, const Shape& rhs, 58 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); 59 static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode, 60 const HloInstruction* lhs, 61 const HloInstruction* rhs); 62 63 // Infers the shape produced by applying the given ternary operation to the 64 // given input shapes. 65 static StatusOr<Shape> InferTernaryOpShape(TernaryOperation operation, 66 const Shape& lhs, const Shape& rhs, 67 const Shape& ehs); 68 static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode, 69 const HloInstruction* lhs, 70 const HloInstruction* rhs, 71 const HloInstruction* ehs); 72 73 // Infers the shape produced by applying the given variadic operation to the 74 // given input operand shapes. 75 static StatusOr<Shape> InferVariadicOpShape( 76 VariadicOperation operation, 77 tensorflow::gtl::ArraySlice<const Shape*> operand_shapes); 78 static StatusOr<Shape> InferVariadicOpShape( 79 HloOpcode opcode, 80 tensorflow::gtl::ArraySlice<const HloInstruction*> operands); 81 82 // Infers the shape produced by applying the given mapping computation shape 83 // to the given operand shapes. 84 static StatusOr<Shape> InferMapShape( 85 tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, 86 const ProgramShape& to_apply, 87 tensorflow::gtl::ArraySlice<int64> dimensions); 88 89 // Infers the shape produced by InferBatchNormTraining with the given 90 // operands. 91 static StatusOr<Shape> InferBatchNormTrainingShape(const Shape& operand_shape, 92 const Shape& scale_shape, 93 const Shape& offset_shape, 94 int64 feature_index); 95 96 // Infers the shape produced by InferBatchNormInference with the given 97 // operands. 98 static StatusOr<Shape> InferBatchNormInferenceShape( 99 const Shape& operand_shape, const Shape& scale_shape, 100 const Shape& offset_shape, const Shape& mean_shape, 101 const Shape& variance_shape, int64 feature_index); 102 103 // Infers the shape produced by InferBatchNormGrad with the given operands. 104 static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape, 105 const Shape& scale_shape, 106 const Shape& mean_shape, 107 const Shape& var_shape, 108 const Shape& output_grad_shape, 109 int64 feature_index); 110 111 // Infers the shape produced by applying the given convolutional 112 // filter (rhs) to lhs in the way specified by the fields on window. 113 static StatusOr<Shape> InferConvolveShape( 114 const Shape& lhs, const Shape& rhs, const Window& window, 115 const ConvolutionDimensionNumbers& dimension_numbers); 116 117 // Infers the shape produced by the given FFT type on the given operand. 118 static StatusOr<Shape> InferFftShape( 119 const Shape& in, FftType fft_type, 120 tensorflow::gtl::ArraySlice<int64> fft_length); 121 122 // Infers the shape produced a cross replica sum with the given operand 123 // shapes. 124 static StatusOr<Shape> InferCrossReplicaSumShape( 125 tensorflow::gtl::ArraySlice<const Shape*> operand_shapes); 126 127 // Infers the shape produced by applying the given reduction computation 128 // shape to the given input operand shape. 129 // 130 // If pass_index is true, the reduce function is invoked with the element 131 // index as the leading parameter, and the program shape should match 132 // accordingly (or an error will result). 133 static StatusOr<Shape> InferReduceShape( 134 const Shape& arg, const Shape& init_value, 135 tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce, 136 const ProgramShape& to_apply); 137 138 // Infers the shape produced by applying the given computation to the operand 139 // shape with the given window and stride dimensions. 140 static StatusOr<Shape> InferReduceWindowShape( 141 const Shape& operand_shape, const Shape& init_value, const Window& window, 142 const ProgramShape& to_apply_shape); 143 144 // Infers the shape produced by scattering the given source shape to the 145 // selected indices of each window on the operand shape. 146 static StatusOr<Shape> InferSelectAndScatterShape( 147 const Shape& operand_shape, const ProgramShape& select_shape, 148 const Window& window, const Shape& source_shape, 149 const Shape& init_value_shape, const ProgramShape& scatter_shape); 150 151 // Infers the shape produced by a reverse operation that reverses the order 152 // of the elements in the given dimensions. 153 static StatusOr<Shape> InferReverseShape( 154 const Shape& operand_shape, 155 tensorflow::gtl::ArraySlice<int64> dimensions); 156 157 // Infers the shape produced by a slice operation spanning from the starts to 158 // the limits in the original shape's dimensions. 159 // 160 // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16] 161 static StatusOr<Shape> InferSliceShape( 162 const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts, 163 tensorflow::gtl::ArraySlice<int64> limits, 164 tensorflow::gtl::ArraySlice<int64> strides); 165 166 // Infers the shape produced by a dynamic slice operation of size specified 167 // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. 168 static StatusOr<Shape> InferDynamicSliceShape( 169 const Shape& operand_shape, const Shape& start_indices_shape, 170 tensorflow::gtl::ArraySlice<int64> slice_sizes); 171 172 // Infers the shape produced by a dynamic update slice operation based 173 // on the shape of operand and update. 174 static StatusOr<Shape> InferDynamicUpdateSliceShape( 175 const Shape& operand_shape, const Shape& update_shape, 176 const Shape& start_indices_shape); 177 178 // Infers the shape produced by doing a compile-time-constant indexing into 179 // the given input shape. This is essential for operations on tuples, because 180 // it is impossible to infer the type that comes out of the tuple indexing if 181 // it is not a compile time constant. 182 static StatusOr<Shape> InferGetTupleElementShape(const Shape& arg, 183 int64 index); 184 185 // Infers the shape produced from a while node. condition and body are the 186 // shapes of computations for the condition and the body of a while node, and 187 // init is the shape of data initially passed in to the body as an argument. 188 // The shapes must match; condition: T -> PRED, body: T -> T, init: T 189 static StatusOr<Shape> InferWhileShape(const ProgramShape& condition, 190 const ProgramShape& body, 191 const Shape& init); 192 193 // Infers the shape produced by a conditional operation. 194 static StatusOr<Shape> InferConditionalShape( 195 const Shape& predicate, const Shape& true_operand, 196 const Shape& false_operand, const ProgramShape& true_computation, 197 const ProgramShape& false_computation); 198 199 // Infers the shape produced by a broadcast operation. 200 static StatusOr<Shape> InferBroadcastShape( 201 const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes); 202 203 // Infers the shape produced by a reshape operation from the element type of 204 // its operand and the new dimension sizes specified. 205 static StatusOr<Shape> InferReshapeShape( 206 const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions, 207 tensorflow::gtl::ArraySlice<int64> new_sizes); 208 209 // Infers the shape produced by a transpose operation from the element type of 210 // its operand and its dimensions field. 211 static StatusOr<Shape> InferTransposeShape( 212 const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions); 213 214 // Helper that infers the shape produced by performing a concatenate operation 215 // with the given operand shapes. 216 static StatusOr<Shape> InferConcatOpShape( 217 tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, int64 dimension); 218 219 // Helper that validates the given operand shape can be converted to the 220 // target output_shape via a convert instruction -- the requirement is that 221 // the shape is identical except for the element type. 222 static StatusOr<Shape> InferConvertShape(const Shape& operand_shape, 223 PrimitiveType new_element_type); 224 225 // Helper that validates the given operand shape can be bitcast converted to 226 // the target output_shape via a bitcast convert instruction -- the 227 // requirement is that the shape is identical except for the element type and 228 // the element types have identical bit-widths. 229 static StatusOr<Shape> InferBitcastConvertShape( 230 const Shape& operand_shape, PrimitiveType new_element_type); 231 232 // Helper that validates the input data type for a reduce-precision operation, 233 // and returns the result shape. 234 static StatusOr<Shape> InferReducePrecisionShape(const Shape& operand_shape, 235 const int exponent_bits, 236 const int mantissa_bits); 237 238 // Helper that infers the shape produced by a pad operation based on the 239 // padding configuration. 240 static StatusOr<Shape> InferPadShape(const Shape& operand_shape, 241 const Shape& padding_value_shape, 242 const PaddingConfig& padding_config); 243 244 // Helper that validates the given arg_shapes are compatible with the shape of 245 // the to_apply parameters, and returns the to_apply result shape. 246 static StatusOr<Shape> InferCallShape( 247 tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, 248 const ProgramShape& to_apply); 249 250 // Helper that infers the shape produced by performing a dot operation with 251 // the given LHS and RHS shapes. 252 static StatusOr<Shape> InferDotOpShape( 253 const Shape& lhs, const Shape& rhs, 254 const DotDimensionNumbers& dimension_numbers); 255 256 // Helper that infers the shape of the tensor produced by a gather operation 257 // with the given input shape, gather indices shape and gather dimension 258 // numbers. 259 static StatusOr<Shape> InferGatherShape( 260 const Shape& input_shape, const Shape& gather_indices_shape, 261 const GatherDimensionNumbers& gather_dim_numbers, 262 tensorflow::gtl::ArraySlice<int64> window_bounds); 263 264 private: 265 // Helper that infers the shape produced by performing an element-wise binary 266 // operation with the given LHS and RHS shapes. 267 // Note: By "element-wise" we mean operations that look at a single element in 268 // the LHS and a single element in the RHS to produce a single output element, 269 // even in the presence of broadcasting of one of the operands over the other. 270 static StatusOr<Shape> InferElementwiseBinaryOpShape( 271 BinaryOperation operation, const Shape& lhs, const Shape& rhs, 272 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); 273 274 // Helper for inferring the shape of Clamp ops. 275 static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand, 276 const Shape& max); 277 278 // Helper for inferring the shape of Select ops. 279 static StatusOr<Shape> InferSelectShape(const Shape& pred, 280 const Shape& on_true, 281 const Shape& on_false); 282 283 // Helper for inferring shapes of binary operations which use degenerate 284 // dimension broadcasting (a dimension of size 1 in one operand is broadcast 285 // up to match the size of the dimension in the other operand). 286 static StatusOr<Shape> InferDegenerateDimensionBroadcastShape( 287 BinaryOperation operation, const Shape& lhs, const Shape& rhs); 288 289 // Helper for inferring shapes of binary operations using "InDim" 290 // broadcasting. This is the broadcasting used in the *InDim binary operations 291 // (for example ComputationBuilder::AddInDim). smaller_shape must be a 292 // lower-rank shape than larger_shape. Returns the shape that the 293 // smaller_shape is broadcast to. 294 static StatusOr<Shape> InferInDimBroadcastShape( 295 BinaryOperation operation, const Shape& smaller_shape, 296 const Shape& larger_shape, 297 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); 298 299 TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference); 300 }; 301 302 } // namespace xla 303 304 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_ 305