/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMISSION_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMISSION_UTILS_H_

#include <array>
#include <utility>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

// TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they
// don't belong in "ir_emission_utils".

namespace xla {
namespace gpu {

// Different types of convolutions supported by cudnn.
//
// A way to think about these is that a convolution is defined by three arrays
// -- the "input", the "filter", and the "output" -- and given any two of these,
// we can compute the third.  For example, a backward-input convolution takes as
// input a filter and an "output" and produces an "input" such that if one were
// to do a forward convolution of "input" using filter, the result would be
// something with the same shape as "output".
//
// This way of thinking is not correct if you look at the values produced. For
// example, a backward-input convolution is not actually the mathematical
// inverse of a forward convolution.  But it's right as far as the shapes and
// "connectivity" (i.e. which elements of the input affect which elements of
// the output) are concerned.
enum class CudnnConvKind {
  kForward,            // input  + filter => output
  kBackwardInput,      // filter + output => input
  kBackwardFilter,     // input  + output => filter
  kForwardActivation,  // activation(conv(input, filter) + broadcast(bias) +
                       // (optionally) side_input) => output
};
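
// Illustrative shape arithmetic (a sketch, not part of this header): for a 1D
// convolution with stride 1 and no padding, the three kinds relate the sizes
// of the three arrays as follows (ForwardOutputSize etc. are hypothetical
// helpers):
//   int64 ForwardOutputSize(int64 input, int64 filter) {
//     return input - filter + 1;  // kForward: input + filter => output
//   }
//   int64 BackwardInputSize(int64 filter, int64 output) {
//     return output + filter - 1;  // kBackwardInput: filter + output => input
//   }
//   int64 BackwardFilterSize(int64 input, int64 output) {
//     return input - output + 1;  // kBackwardFilter: input + output => filter
//   }
// Note that BackwardInputSize(f, ForwardOutputSize(i, f)) == i: given any two
// of the arrays, the shape of the third is determined.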

// Returns the CudnnConvKind of `instr`, which must be a cudnn convolution
// CustomCall.
StatusOr<CudnnConvKind> GetCudnnConvKind(const HloCustomCallInstruction* instr);

// Returns the stream_executor convolution kind corresponding to `instr`, which
// must be a cudnn convolution CustomCall.
StatusOr<se::dnn::ConvolutionKind> GetDnnConvolutionKind(
    const HloCustomCallInstruction* instr);

// Returns the stream_executor data type of the convolution `conv`.
StatusOr<se::dnn::DataType> GetDnnDataType(
    const HloCustomCallInstruction* conv);

// Converts a CudnnConvKind value to a string.
string CudnnConvKindToString(CudnnConvKind kind);

// Returns true if `dot` is a matrix multiplication that has not yet been
// rewritten into a GEMM custom call.
//
// This function should never return true on instructions after the
// GemmRewriter pass has finished.
bool IsMatrixMultiplication(const HloInstruction& dot);

// Returns true if `hlo` is a matrix multiplication that has been rewritten
// into a GEMM custom call. All matrix multiplications should be rewritten into
// such custom calls after the GemmRewriter lowering pass.
bool IsCublasGemm(const HloInstruction& hlo);
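
// For example (an illustrative sketch, where `dot` is a hypothetical dot HLO
// that GemmRewriter rewrites into a CustomCall `gemm`):
//   IsMatrixMultiplication(*dot)   // true before GemmRewriter runs
//   IsCublasGemm(*gemm)            // true after the rewrite
//   IsMatrixMultiplication(*gemm)  // false; `gemm` is no longer a dot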

// Number of threads per warp on NVIDIA GPUs.
constexpr int64 kWarpSize = 32;

// A call to cuBLAS general matrix multiplication API.
extern const char* const kGemmCallTarget;

// A call to cuDNN for batch normalization is represented as CustomCall HLO with
// a call target equal to one of these strings.
//
// The operands to and outputs of these calls are the same as those of the
// corresponding HLOs, except:
//
//  - epsilon and feature_index are proper operands, at the end of the operands
//    list.  They must be HLO constants.
//  - The cuDNN forward training call returns inv_stddev =
//    1/sqrt(variance + epsilon) in place of plain variance.
//  - Similarly, BatchNormGrad accepts inv_stddev in place of the variance
//    operand.
extern const char* const kCudnnBatchNormForwardInferenceCallTarget;
extern const char* const kCudnnBatchNormForwardTrainingCallTarget;
extern const char* const kCudnnBatchNormBackwardCallTarget;
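
// For example, a consumer of the forward-training outputs can recover the
// plain variance from inv_stddev (a sketch; `inv_stddev` and `epsilon` are
// hypothetical host-side floats):
//   float variance = 1.0f / (inv_stddev * inv_stddev) - epsilon;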

// Returns true if `hlo` will be implemented as a call to a cuDNN batch
// normalization routine.
//
// This returns true if `hlo` is a CustomCall HLO with a call target equal to
// one of the kCudnnBatchNormFoo constants above, but returns *false* for HLOs
// with one of the kBatchNorm opcodes, because these are lowered either to a
// sequence of generic HLOs or to a cuDNN CustomCall.
bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo);

// A call to cuDNN for convolution (forward, backward filter, or backward input)
// is represented as a CustomCall HLO with a call target equal to one of these
// strings.
//
// These CustomCalls have window() and convolution_dimension_numbers() set like
// regular convolution ops.  They have the same LHS and RHS operands, plus two
// additional constant operands: an int64 operand for the cudnn algorithm and
// a bool operand for whether tensor_ops is enabled. A value of -1 for the cudnn
// algorithm means that the implementation is free to choose the best algorithm
// it can.
//
// These calls output a tuple (conv_result, scratch_memory), where conv_result
// is the actual result of the convolution, and scratch_memory is temporary
// memory used by cudnn.  Callers shouldn't inspect scratch_memory, as its value
// is not well-defined.
//
// GpuConvRewriter lowers kConvolution HLOs to these custom calls.  When it
// does so, it chooses algorithm -1 and 0 bytes of scratch space.  Later on in
// the pipeline, GpuConvAlgorithmPicker chooses an explicit algorithm for each
// conv and sets the amount of scratch space needed.
//
// (Representing the scratch memory as an output may seem strange at first, but
// it's quite sensible, from a certain point of view.  The scratch buffer is a
// location in memory that the conv can write into, but which it can't legally
// read from, at least until it's written something first.  But that's exactly
// the definition of an output buffer.)
extern const char* const kCudnnConvForwardCallTarget;
extern const char* const kCudnnConvBackwardInputCallTarget;
extern const char* const kCudnnConvBackwardFilterCallTarget;
extern const char* const kCudnnConvBiasActivationForwardCallTarget;

// Returns true if `hlo` will be implemented as a call to a cuDNN convolution
// routine.
//
// This returns true if `hlo` is a CustomCall HLO with a call target equal to
// one of the kCudnnConvFoo constants above, but returns *false* for HLOs with a
// kConvolution opcode.
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo);
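
// Illustrative usage (a sketch; `instr` is a hypothetical HloInstruction*, and
// Cast comes from hlo_casting_utils.h):
//   if (IsCustomCallToDnnConvolution(*instr)) {
//     auto* conv = Cast<HloCustomCallInstruction>(instr);
//     TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(conv));
//     // Element 0 of conv's output tuple is the result; element 1 is cudnn
//     // scratch memory, whose contents must not be inspected.
//   }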

// Returns true if `hlo` will be implemented as a call to a cuSolver routine.
//
// This returns true if `hlo` is a CustomCall HLO with a call target equal to
// one of the kCusolver... constants, but returns *false* for HLOs with, say,
// a kCholesky opcode.
bool IsCustomCallToCusolver(const HloInstruction& hlo);

// Cholesky decomposition. Takes a (batched) matrix as input, and returns a
// tuple of (result, workspace, info), where result is the result of the
// Cholesky decomposition, workspace is scratch space for cuSolver, and info
// is a success/failure code per batch element.
extern const char* const kCusolverCholeskyCallTarget;

// Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm
// or cuDNN convolution.
bool ImplementedAsLibraryCall(const HloInstruction& hlo);

// Returns true if either the dimensions being reduced or the dimensions being
// kept are contiguous in the input of the reduce instruction.
bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce);
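
// For example, for a rank-4 input with a row-major layout:
//  - reducing dimensions {2, 3} keeps {0, 1}; both sets are contiguous, so
//    this returns true (a row-style reduction);
//  - reducing dimensions {0, 2} keeps {1, 3}; neither set is contiguous, so
//    this returns false.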

// Returns whether `unnested_hlo` is an input fusion whose root is either a
// slice or a tuple of slices. If `verify_no_strides` is true, returns false
// unless all ROOT slices have no strides.
bool IsInputFusibleSlices(const HloInstruction& unnested_hlo,
                          bool verify_no_strides = false);

struct ReductionDimensions {
  // Indicates whether the reduction is a row reduction or a column reduction.
  bool is_row_reduction;

  // Contains the size of the three contiguous components for
  // the reduction [depth, height, width] (major-to-minor ordering).
  //
  // For row reduction, we do: [D, H, W] -> [D, H].
  // For column reduction, we do: [D, H, W] -> [D, W].
  std::array<int64, 3> dimensions;
};

// Given the input shape and dimensions to reduce for a reduction, returns
// ReductionDimensions.
//
// Prerequisite: the reduction instruction passes the check
// IsReductionFromOrToContiguousDimensions, which guarantees either the
// dimensions to reduce or the dimensions to keep are consecutive.
ReductionDimensions GetReductionKindAndContiguousComponents(
    const HloInstruction& reduce);
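
// For example (a conceptual sketch following the [D, H, W] model above):
// reducing dimension 1 of a row-major f32[100, 200] input reduces the
// contiguous minor-most dimension, so this is a row reduction with
// dimensions [D, H, W] = [1, 100, 200], keeping [D, H] = [1, 100].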

// Returns the per-thread tiling for the given reduction, as [D, H, W] tile
// sizes (major-to-minor ordering).
std::array<int64, 3> GetReductionTiling(
    const ReductionDimensions& reduction_dimensions);

// Emits a call to "vprintf" with the given format and arguments.
llvm::Value* EmitPrintf(absl::string_view fmt,
                        absl::Span<llvm::Value* const> arguments,
                        llvm::IRBuilder<>* builder);
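
// Illustrative usage (a sketch; `thread_id` is a hypothetical i32 llvm::Value*
// and `b` an llvm::IRBuilder<>*):
//   EmitPrintf("thread %d starting\n", {thread_id}, b);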

// Emits code to shuffle data between threads of a warp. This has the same
// semantics as the PTX "shfl.sync.down" instruction but works for values that
// aren't 32 bits in size. The last operand of the emitted "shfl" is
// `kWarpSize - 1`.
//
// This function emits a "full-warp" shuffle, which all threads of a warp
// participate in.  *Do not use this function from a divergent context:* You
// can't correctly do so on both Volta and earlier GPUs.
//
// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
                                     llvm::IRBuilder<>* builder);
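
// Typical usage (an illustrative sketch): a full-warp sum reduction in which
// each thread starts with a float `partial` (an llvm::Value*) and lane 0 ends
// up with the warp-wide total:
//   for (int64 offset = kWarpSize / 2; offset > 0; offset /= 2) {
//     llvm::Value* peer = EmitFullWarpShuffleDown(
//         partial, builder->getInt32(offset), builder);
//     partial = builder->CreateFAdd(partial, peer);
//   }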

// Emits code that determines whether the current thread is thread 0 within
// block 0 of the kernel.
llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b);

// Returns whether the outputs of a fusion containing reductions are consistent
// with the first reduction `first_reduce`.
bool AreFusedReductionOutputsConsistent(
    absl::Span<const HloInstruction* const> output_instructions,
    const HloInstruction* first_reduce);

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMISSION_UTILS_H_