/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"

#include <stddef.h>

#include <unordered_map>
#include <vector>

// IWYU pragma: no_include "llvm/IR/Attributes.gen.inc"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace gpu {

using absl::StrAppend;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace {
// Returns whether operand is a floating-point literal with the given value.
bool IsFPLiteralWithValue(const HloInstruction* operand, float value) {
  if (operand->opcode() == HloOpcode::kConstant &&
      operand->literal().IsAllFloat(value)) {
    return true;
  }
  return operand->opcode() == HloOpcode::kBroadcast &&
         IsFPLiteralWithValue(operand->operand(0), value);
}
}  // namespace

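// Stores the HLO module config and the callback used to emit nested
// computations; the LLVM module and builder are forwarded to the base class.
// Nothing here is owned by the emitter.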
GpuElementalIrEmitter::GpuElementalIrEmitter(
    const HloModuleConfig& hlo_module_config, llvm::Module* module,
    llvm::IRBuilder<>* b, NestedComputer compute_nested)
    : ElementalIrEmitter(module, b),
      hlo_module_config_(hlo_module_config),
      compute_nested_(std::move(compute_nested)) {}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
    TargetDeviceFunctionID funcid, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::string_view name) {
  // The device math libraries don't have f16 variants, so convert f16
  // operands to f32 before calling the function and convert the result back
  // to f16 afterwards.
  bool cast_result_to_fp16 = false;
  std::vector<llvm::Value*> converted_operands(operands.begin(),
                                               operands.end());
  std::vector<PrimitiveType> converted_input_types(input_types.begin(),
                                                   input_types.end());
  switch (output_type) {
    case F16:
      cast_result_to_fp16 = true;
      for (int64 i = 0; i < operands.size(); ++i) {
        if (input_types[i] == F16) {
          converted_operands[i] =
              FPCast(converted_operands[i], b()->getFloatTy());
          converted_input_types[i] = F32;
        }
      }
      output_type = F32;
      TF_FALLTHROUGH_INTENDED;
    case F32:
      break;
    case F64:
      break;
    default:
      return Unimplemented("Bad type for device math call: %s",
                           PrimitiveType_Name(output_type));
  }
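  // For example, assuming an NVPTX target, funcid == kSqrt with output_type
  // F32 resolves to the libdevice function "__nv_sqrtf".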
  const string& munged_callee =
      ObtainDeviceFunctionName(funcid, output_type, b());
  llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
                                     converted_input_types, output_type, name)
                            .ValueOrDie();
  if (cast_result_to_fp16) {
    result = FPCast(result, b()->getHalfTy());
  }
  return result;
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
  // LLVM intrinsics differentiate between half/float/double functions via
  // the suffixes ".f16", ".f32" and ".f64".
  string munged_callee = callee_name;
  switch (output_type) {
    case F16:
      StrAppend(&munged_callee, ".f16");
      break;
    case F32:
      StrAppend(&munged_callee, ".f32");
      break;
    case F64:
      StrAppend(&munged_callee, ".f64");
      break;
    default:
      return Unimplemented("Bad type for llvm intrinsic math call: %s",
                           PrimitiveType_Name(output_type));
  }
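  // e.g. a callee_name of "llvm.sqrt" becomes "llvm.sqrt.f32" when
  // output_type is F32.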
  return EmitMathCall(munged_callee, operands, input_types, output_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::string_view name) {
  // Math functions are of type [T] -> T: every input type must match the
  // output type.
  for (PrimitiveType input_type : input_types) {
    if (output_type != input_type) {
      return Unimplemented("Input type != output type: %s != %s",
                           PrimitiveType_Name(input_type),
                           PrimitiveType_Name(output_type));
    }
  }

  return EmitDeviceFunctionCall(
      callee_name, operands, input_types, output_type,
      {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b(), name);
}

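// Lowers a floating-point binary op. The fast min/max path and the
// remainder/power special cases are handled here; everything else falls
// through to the generic ElementalIrEmitter implementation.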
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
  PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  HloOpcode opcode = op->opcode();

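  // With xla_gpu_enable_fast_min_max, lower min/max straight to the
  // llvm.minnum/llvm.maxnum intrinsics; this is faster but is not guaranteed
  // to honor HLO's NaN-propagation semantics.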
  if (hlo_module_config_.debug_options().xla_gpu_enable_fast_min_max() &&
      (opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum)) {
    return llvm_ir::EmitCallToIntrinsic(
        opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum
                                      : llvm::Intrinsic::minnum,
        {lhs_value, rhs_value}, {lhs_value->getType()}, b());
  }

  switch (op->opcode()) {
    case HloOpcode::kRemainder: {
      return EmitDeviceMathCall(TargetDeviceFunctionID::kFmod,
                                {lhs_value, rhs_value},
                                {lhs_input_type, rhs_input_type}, output_type);
    }
    case HloOpcode::kPower: {
      return EmitPowerOp(op, lhs_value, rhs_value);
    }
    default:
      return ElementalIrEmitter::EmitFloatBinaryOp(op, lhs_value, rhs_value);
  }
}

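// Lowers kPower to a call into the device math library's pow function.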
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  CHECK_EQ(op->opcode(), HloOpcode::kPower);
  PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
  PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  return EmitDeviceMathCall(TargetDeviceFunctionID::kPow,
                            {lhs_value, rhs_value},
                            {lhs_input_type, rhs_input_type}, output_type);
}

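// The emitters below are thin wrappers that forward each transcendental op to
// the matching device math library function (e.g. the __nv_* functions from
// libdevice on NVPTX targets).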
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kLog, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kLog1p, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kSin, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kCos, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(
    PrimitiveType prim_type, llvm::Value* value, absl::string_view /*name*/) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kExp, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kExpm1, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                      llvm::Value* lhs,
                                                      llvm::Value* rhs,
                                                      absl::string_view name) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kPow, {lhs, rhs},
                            {prim_type, prim_type}, prim_type, name);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSqrt(PrimitiveType prim_type,
                                                       llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kSqrt, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
                                                        llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kRsqrt, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
    PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs,
    absl::string_view name) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kAtan2, {lhs, rhs},
                            {prim_type, prim_type}, prim_type, name);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
                                                       llvm::Value* value) {
  // When F64 is requested, assume performance is less important and use the
  // more numerically precise device tanh function.
  if (prim_type == F64) {
    return EmitDeviceMathCall(TargetDeviceFunctionID::kTanh, {value},
                              {prim_type}, prim_type);
  }

  // Emit a fast approximation of tanh instead of calling __nv_tanh.
  // __nv_tanh is particularly bad because it contains branches, thus
  // preventing LLVM's load-store vectorizer from working its magic across a
  // function which contains tanh calls.
  //
  // This routine isn't numerically precise, but it's good enough for ML.
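  //
  // EmitFastTanh (llvm_ir/math_ops.h) emits a branch-free rational polynomial
  // approximation; the select below additionally pins very large inputs to
  // exactly +/-1.0.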

  // Upcast F16 to F32 if necessary.
  llvm::Type* type = prim_type == F16 ? b()->getFloatTy() : value->getType();
  llvm::Value* input = FPCast(value, type);

  // For |value| >= kMaxValue, tanh has saturated, so return +/-1.0 with the
  // sign of the input instead of the approximation.
  constexpr double kMaxValue = 20.0;
  auto max_value = llvm::ConstantFP::get(type, kMaxValue);
  llvm::Value* abs_value =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b());

  llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b(), input);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto one_with_sign = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
                                                    {one, input}, {type}, b());
  return FPCast(Select(FCmpULT(abs_value, max_value), fast_tanh, one_with_sign),
                value->getType(), "tanh");
}

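// |a + bi| = hypot(a, b). Calling the device hypot avoids the overflow and
// underflow a naive sqrt(a*a + b*b) would hit for large or tiny components.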
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexAbs(
    PrimitiveType prim_type, llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kHypot,
                            {EmitExtractReal(value), EmitExtractImag(value)},
                            {prim_type, prim_type}, prim_type);
}

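// Computes the linear thread id, block_id * threads_per_block +
// thread_id_in_block, in 128-bit integers so the multiply-add cannot
// overflow.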
llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
  llvm::Value* block_id = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "block.id");
  llvm::Value* thread_id_in_block = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "thread.id");
  llvm::Value* threads_per_block = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
  return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
}

}  // namespace gpu
}  // namespace xla