/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"

#include <stddef.h>

#include <unordered_map>
#include <vector>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
// IWYU pragma: no_include "llvm/IR/Attributes.gen.inc"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

using absl::StrAppend;
using llvm_ir::IrArray;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;

namespace {
// Returns whether operand is a floating-point literal with the given value.
bool IsFPLiteralWithValue(const HloInstruction* operand, float value) {
  if (operand->opcode() == HloOpcode::kConstant &&
      operand->literal().IsAllFloat(value)) {
    return true;
  }
  return operand->opcode() == HloOpcode::kBroadcast &&
         IsFPLiteralWithValue(operand->operand(0), value);
}
}  // namespace

GpuElementalIrEmitter::GpuElementalIrEmitter(
    const HloModuleConfig& hlo_module_config, llvm::Module* module,
    llvm::IRBuilder<>* b, NestedComputer compute_nested)
    : ElementalIrEmitter(module, b),
      hlo_module_config_(hlo_module_config),
      compute_nested_(std::move(compute_nested)) {}

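// Emits a call to a math function from the target's device library (the
// callee name comes from ObtainDeviceFunctionName, e.g. an NVPTX libdevice
// function). F16 is handled by upcasting f16 operands to f32, calling the f32
// variant, and casting the result back, since the device libraries don't
// provide half-precision math functions.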
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
    TargetDeviceFunctionID funcid, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::string_view name) {
  // Device functions don't have f16 math functions, so we convert the
  // operands to f32 before calling the function and then convert the result
  // back to f16.
  bool cast_result_to_fp16 = false;
  std::vector<llvm::Value*> converted_operands(operands.begin(),
                                               operands.end());
  std::vector<PrimitiveType> converted_input_types(input_types.begin(),
                                                   input_types.end());
  switch (output_type) {
    case F16:
      cast_result_to_fp16 = true;
      for (int64_t i = 0; i < operands.size(); ++i) {
        if (input_types[i] == F16) {
          converted_operands[i] =
              FPCast(converted_operands[i], b()->getFloatTy());
          converted_input_types[i] = F32;
        }
      }
      output_type = F32;
      TF_FALLTHROUGH_INTENDED;
    case F32:
      break;
    case F64:
      break;
    default:
      return Unimplemented("Bad type for device math call: %s",
                           PrimitiveType_Name(output_type));
  }
  const string& munged_callee =
      ObtainDeviceFunctionName(funcid, output_type, b());
  llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
                                     converted_input_types, output_type, name)
                            .ValueOrDie();
  if (cast_result_to_fp16) {
    result = FPCast(result, b()->getHalfTy());
  }
  return result;
}

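// Emits a call to an LLVM math intrinsic, selecting the overload by appending
// the type suffix (".f16", ".f32" or ".f64") that matches the output type.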
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
  // LLVM intrinsics differentiate between half/float/double functions via
  // the suffixes ".f16", ".f32" and ".f64".
  string munged_callee = callee_name;
  switch (output_type) {
    case F16:
      StrAppend(&munged_callee, ".f16");
      break;
    case F32:
      StrAppend(&munged_callee, ".f32");
      break;
    case F64:
      StrAppend(&munged_callee, ".f64");
      break;
    default:
      return Unimplemented("Bad type for llvm intrinsic math call: %s",
                           PrimitiveType_Name(output_type));
  }
  return EmitMathCall(munged_callee, operands, input_types, output_type);
}

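// Shared helper for the emitters above: requires every input type to match
// the output type, then emits a plain device function call marked readnone
// and nounwind.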
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
    const string& callee_name, absl::Span<llvm::Value* const> operands,
    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
    absl::string_view name) {
  // Math functions are of type [T] -> T: every input must have the same type
  // as the output.
  for (PrimitiveType input_type : input_types) {
    if (output_type != input_type) {
      return Unimplemented("Input type != output type: %s != %s",
                           PrimitiveType_Name(input_type),
                           PrimitiveType_Name(output_type));
    }
  }

  return EmitDeviceFunctionCall(
      callee_name, operands, input_types, output_type,
      {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b(), name);
}

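// Maps an index into a bitcast's result back to the corresponding index into
// its operand, using the source and result layouts recorded in the
// instruction's BitcastBackendConfig.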
llvm_ir::IrArray::Index GpuElementalIrEmitter::GetSourceIndexOfBitcast(
    const llvm_ir::IrArray::Index& index, const HloInstruction* hlo) {
  Shape shape = hlo->shape();
  Shape operand_shape = hlo->operand(0)->shape();

  // Decode the layouts of the shapes from the protobuf attached to
  // backend_config_.
  BitcastBackendConfig bitcast_config;
  CHECK(bitcast_config.ParseFromString(hlo->raw_backend_config_string()));

  *shape.mutable_layout() =
      xla::Layout::CreateFromProto(bitcast_config.result_layout());
  *operand_shape.mutable_layout() =
      xla::Layout::CreateFromProto(bitcast_config.source_layout());
  return index.SourceIndexOfBitcast(shape, operand_shape, b());
}

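// For minimum/maximum with xla_gpu_enable_fast_min_max set, emit the LLVM
// minnum/maxnum intrinsics (faster, but NaN handling may differ from the
// default lowering); remainder and power go to the device library, and all
// other binary ops fall back to the base ElementalIrEmitter.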
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
  PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  HloOpcode opcode = op->opcode();

  if (hlo_module_config_.debug_options().xla_gpu_enable_fast_min_max() &&
      (opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum)) {
    return llvm_ir::EmitCallToIntrinsic(
        opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum
                                      : llvm::Intrinsic::minnum,
        {lhs_value, rhs_value}, {lhs_value->getType()}, b());
  }

  switch (op->opcode()) {
    case HloOpcode::kRemainder: {
      return EmitDeviceMathCall(TargetDeviceFunctionID::kFmod,
                                {lhs_value, rhs_value},
                                {lhs_input_type, rhs_input_type}, output_type);
    }
    case HloOpcode::kPower: {
      return EmitPowerOp(op, lhs_value, rhs_value);
    }
    default:
      return ElementalIrEmitter::EmitFloatBinaryOp(op, lhs_value, rhs_value);
  }
}

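// Lowers kPower directly to the device library's pow function.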
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
  CHECK_EQ(op->opcode(), HloOpcode::kPower);
  PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
  PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
  PrimitiveType output_type = op->shape().element_type();
  return EmitDeviceMathCall(TargetDeviceFunctionID::kPow,
                            {lhs_value, rhs_value},
                            {lhs_input_type, rhs_input_type}, output_type);
}

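// The transcendental ops below (log, log1p, sin, cos, exp, expm1, pow, sqrt,
// rsqrt, atan2, and tanh for F64) are each lowered to the matching device
// library call via EmitDeviceMathCall.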
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kLog, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog1p(
    PrimitiveType prim_type, llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kLog1p, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kSin, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type,
                                                      llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kCos, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(
    PrimitiveType prim_type, llvm::Value* value, absl::string_view /*name*/) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kExp, {value}, {prim_type},
                            prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(
    PrimitiveType prim_type, llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kExpm1, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                      llvm::Value* lhs,
                                                      llvm::Value* rhs,
                                                      absl::string_view name) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kPow, {lhs, rhs},
                            {prim_type, prim_type}, prim_type, name);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSqrt(PrimitiveType prim_type,
                                                       llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kSqrt, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitRsqrt(
    PrimitiveType prim_type, llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kRsqrt, {value},
                            {prim_type}, prim_type);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
    PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs,
    absl::string_view name) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kAtan2, {lhs, rhs},
                            {prim_type, prim_type}, prim_type, name);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
                                                       llvm::Value* value) {
  // When F64 is being requested, assume performance is less important and use
  // the more numerically precise tanh function.
  if (prim_type == F64) {
    return EmitDeviceMathCall(TargetDeviceFunctionID::kTanh, {value},
                              {prim_type}, prim_type);
  }

  // Emit a fast approximation of tanh instead of calling __nv_tanh.
  // __nv_tanh is particularly bad because it contains branches, thus
  // preventing LLVM's load-store vectorizer from working its magic across a
  // function which contains tanh calls.
  //
  // This routine isn't numerically precise, but it's good enough for ML.

  // Upcast F16 to F32 if necessary.
  llvm::Type* type = prim_type == F16 ? b()->getFloatTy() : value->getType();
  llvm::Value* input = FPCast(value, type);

  // If |value| >= kMaxValue, tanh() is set to -1.0 or 1.0.
  constexpr double kMaxValue = 20.0;
  auto max_value = llvm::ConstantFP::get(type, kMaxValue);
  llvm::Value* abs_value =
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b());

  llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b(), input);
  auto one = llvm::ConstantFP::get(type, 1.0);
  auto one_with_sign = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
                                                    {one, input}, {type}, b());
  return FPCast(
      Select(FCmpULT(abs_value, max_value), fast_tanh, one_with_sign),
      value->getType(), "tanh");
}

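// |re + im*i| == hypot(re, im); calling the device hypot avoids the
// intermediate overflow/underflow of the naive sqrt(re*re + im*im).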
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexAbs(
    PrimitiveType prim_type, llvm::Value* value) {
  return EmitDeviceMathCall(TargetDeviceFunctionID::kHypot,
                            {EmitExtractReal(value), EmitExtractImag(value)},
                            {prim_type, prim_type}, prim_type);
}

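// Returns the flattened thread id, block_id * threads_per_block +
// thread_id_in_block. The operands are widened to 128-bit integers so the
// nsw multiply-add cannot overflow.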
llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
  llvm::Value* block_id = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "block.id");
  llvm::Value* thread_id_in_block = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "thread.id");
  llvm::Value* threads_per_block = IntCast(
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b()),
      b()->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
  return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
}

}  // namespace gpu
}  // namespace xla