/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"

#include <stddef.h>

#include <unordered_map>
#include <vector>

#include "llvm/IR/DerivedTypes.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
// IWYU pragma: no_include "llvm/IR/Attributes.gen.inc"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

using absl::StrAppend;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace {
// Returns whether `operand` is a floating-point constant whose elements all
// equal `value`, either directly or through a chain of broadcasts of such a
// constant.
bool IsFPLiteralWithValue(const HloInstruction* operand, float value) {
  if (operand->opcode() == HloOpcode::kConstant &&
      operand->literal().IsAllFloat(value)) {
    return true;
  }
  return operand->opcode() == HloOpcode::kBroadcast &&
         IsFPLiteralWithValue(operand->operand(0), value);
}
}  // namespace

GpuElementalIrEmitter::GpuElementalIrEmitter(
    const HloModuleConfig& hlo_module_config, llvm::Module* module,
    llvm::IRBuilder<>* b, NestedComputer compute_nested)
    : ElementalIrEmitter(module, b),
      hlo_module_config_(hlo_module_config),
      compute_nested_(std::move(compute_nested)) {}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
    TargetDeviceFunctionID funcid, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::string_view name) {
  // Device functions don't have f16 math functions, so we convert the
  // operands to f32 before calling the function and then convert the result
  // back to f16.
  bool cast_result_to_fp16 = false;
  std::vector<llvm::Value*> converted_operands(operands.begin(),
                                               operands.end());
  std::vector<PrimitiveType> converted_input_types(input_types.begin(),
                                                   input_types.end());
  switch (output_type) {
    case F16:
      cast_result_to_fp16 = true;
      for (int64 i = 0; i < operands.size(); ++i) {
        if (input_types[i] == F16) {
          converted_operands[i] =
              FPCast(converted_operands[i], b()->getFloatTy());
          converted_input_types[i] = F32;
        }
      }
      output_type = F32;
      TF_FALLTHROUGH_INTENDED;
    case F32:
      break;
    case F64:
      break;
    default:
      return Unimplemented("Bad type for device math call: %s",
                           PrimitiveType_Name(output_type));
  }
  const string& munged_callee =
      ObtainDeviceFunctionName(funcid, output_type, b());
  llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
                                     converted_input_types, output_type, name)
                            .ValueOrDie();
  if (cast_result_to_fp16) {
    result = FPCast(result, b()->getHalfTy());
  }
  return result;
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
  // LLVM intrinsics differentiate between half/float/double functions via
  // the suffixes ".f16", ".f32" and ".f64".
  string munged_callee = callee_name;
  switch (output_type) {
    case F16:
      StrAppend(&munged_callee, ".f16");
      break;
    case F32:
      StrAppend(&munged_callee, ".f32");
      break;
    case F64:
      StrAppend(&munged_callee, ".f64");
      break;
    default:
      return Unimplemented("Bad type for llvm intrinsic math call: %s",
                           PrimitiveType_Name(output_type));
  }
  return EmitMathCall(munged_callee, operands, input_types, output_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::string_view name) {
  // Math functions are of type [T] -> T: every input type must match the
  // output type.
  for (PrimitiveType input_type : input_types) {
    if (output_type != input_type) {
      return Unimplemented("Input type != output type: %s != %s",
                           PrimitiveType_Name(input_type),
                           PrimitiveType_Name(output_type));
    }
  }

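  // The callee is marked ReadNone (it does not access memory) and NoUnwind,
  // which lets LLVM freely CSE, hoist, and reorder the call.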
  return EmitDeviceFunctionCall(
      callee_name, operands, input_types, output_type,
      {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b(), name);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
  PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  HloOpcode opcode = op->opcode();

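  // Under xla_gpu_enable_fast_min_max, lower kMaximum/kMinimum directly to
  // the llvm.maxnum/llvm.minnum intrinsics, whose NaN handling is relaxed
  // relative to the default lowering; hence the opt-in flag.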
  if (hlo_module_config_.debug_options().xla_gpu_enable_fast_min_max() &&
      (opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum)) {
    return llvm_ir::EmitCallToIntrinsic(
        opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum
                                      : llvm::Intrinsic::minnum,
        {lhs_value, rhs_value}, {lhs_value->getType()}, b());
  }

  switch (op->opcode()) {
    case HloOpcode::kRemainder: {
      return EmitDeviceMathCall(TargetDeviceFunctionID::kFmod,
                                {lhs_value, rhs_value},
                                {lhs_input_type, rhs_input_type}, output_type);
    }
    case HloOpcode::kPower: {
      return EmitPowerOp(op, lhs_value, rhs_value);
    }
    default:
      return ElementalIrEmitter::EmitFloatBinaryOp(op, lhs_value, rhs_value);
  }
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  CHECK_EQ(op->opcode(), HloOpcode::kPower);
  PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
  PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  return EmitDeviceMathCall(TargetDeviceFunctionID::kPow,
                            {lhs_value, rhs_value},
                            {lhs_input_type, rhs_input_type}, output_type);
}

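// Each emitter below lowers its HLO op to a single device math call;
// ObtainDeviceFunctionName resolves the TargetDeviceFunctionID to the
// platform-specific routine (e.g. __nv_log / __nv_logf on NVPTX).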
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kLog, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kLog1p, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kSin, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kCos, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(
    PrimitiveType prim_type, llvm::Value* value, absl::string_view /*name*/) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kExp, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kExpm1, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                      llvm::Value* lhs,
                                                      llvm::Value* rhs,
                                                      absl::string_view name) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kPow, {lhs, rhs},
                            {prim_type, prim_type}, prim_type, name);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSqrt(PrimitiveType prim_type,
                                                       llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kSqrt, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kRsqrt, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
    PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs,
    absl::string_view name) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kAtan2, {lhs, rhs},
                            {prim_type, prim_type}, prim_type, name);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
                                                       llvm::Value* value) {
  // When F64 is being requested, assume performance is less important and use
  // the more numerically precise tanh function.
  if (prim_type == F64) {
    return EmitDeviceMathCall(TargetDeviceFunctionID::kTanh, {value},
                              {prim_type}, prim_type);
  }

  // Emit a fast approximation of tanh instead of calling __nv_tanh.
  // __nv_tanh is particularly bad because it contains branches, thus
  // preventing LLVM's load-store vectorizer from working its magic across a
  // function which contains tanh calls.
  //
  // This routine isn't numerically precise, but it's good enough for ML.

  // Upcast F16 to F32 if necessary.
  llvm::Type* type = prim_type == F16 ? b()->getFloatTy() : value->getType();
  llvm::Value* input = FPCast(value, type);

  // If |input| >= kMaxValue, clamp the result to -1.0 or 1.0, matching the
  // sign of the input.
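  // (tanh(20) differs from 1.0 by less than 1e-17, far below single-precision
  // resolution, so the clamp is exact for the F16/F32 inputs handled here.)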
  constexpr double kMaxValue = 20.0;
  auto max_value = llvm::ConstantFP::get(type, kMaxValue);
  llvm::Value* abs_value =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b());

  llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b(), input);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto one_with_sign = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
                                                    {one, input}, {type}, b());
  return FPCast(Select(FCmpULT(abs_value, max_value), fast_tanh, one_with_sign),
                value->getType(), "tanh");
}

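// |re + im*i| == hypot(re, im). The device hypot routine avoids the
// intermediate overflow/underflow of the naive sqrt(re*re + im*im).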
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexAbs(
    PrimitiveType prim_type, llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kHypot,
                            {EmitExtractReal(value), EmitExtractImag(value)},
                            {prim_type, prim_type}, prim_type);
}

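// Returns the flattened thread index within the launch, i.e.
// blockIdx.x * blockDim.x + threadIdx.x, with the operands widened to i128
// so the no-signed-wrap multiply-add cannot overflow.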
llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
  llvm::Value* block_id = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "block.id");
  llvm::Value* thread_id_in_block = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "thread.id");
  llvm::Value* threads_per_block = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
  return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
}

}  // namespace gpu
}  // namespace xla