/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"

#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace cpu {
namespace runtime {

const char* const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32";
const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32";
const char* const kExpV4F32SymbolName = "__xla_cpu_runtime_ExpV4F32";
const char* const kExpV8F32SymbolName = "__xla_cpu_runtime_ExpV8F32";
const char* const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32AVX";
const char* const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32AVX";

namespace {

// Replaces calls to the function `fn_name` with the code generated by
// fn_body_generator.
//
// We assume that fn_name accepts either a scalar f32 or a vector of
// vector_width f32s, and that fn_body_generator generates a function body with
// the same inputs/outputs as fn_name.
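//
// For vector_width == 1 the scalar argument is wrapped in a one-element vector
// before fn_body_generator is invoked and the result is unwrapped afterwards,
// so generators only ever see vector values.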
void RewriteCalls(
    llvm::Module* module, const char* fn_name,
    std::function<llvm::Value*(llvm::IRBuilder<>* b, llvm::Value* input,
                               int32 vector_width)>
        fn_body_generator,
    int32 vector_width, bool enable_fast_math) {
  llvm::Function* fn = module->getFunction(fn_name);
  if (fn == nullptr) {
    // If the function declaration is not present in the module, there can't be
    // any calls to resolve. Don't emit the function in this case.
    return;
  }

  // Our task is to generate a function body for `fn`, but we can't generate a
  // function body for an LLVM intrinsic. So if fn is an intrinsic, replace it
  // with a new function.
  if (fn->isIntrinsic()) {
    llvm::Function* new_fn = llvm::Function::Create(
        fn->getFunctionType(), llvm::GlobalValue::InternalLinkage,
        llvm::Twine("xla_impl.") + fn_name, module);
    fn->replaceAllUsesWith(new_fn);
    fn->eraseFromParent();
    fn = new_fn;
  }

  llvm::LLVMContext* context = &module->getContext();

  llvm::BasicBlock* fn_body = llvm::BasicBlock::Create(*context, "body", fn);
  llvm::IRBuilder<> b(fn_body);
  llvm::FastMathFlags fast_math_flags;
  fast_math_flags.setFast(enable_fast_math);
  b.setFastMathFlags(fast_math_flags);

  llvm::Value* input = &*fn->arg_begin();

  // Upcast to vector type if input is a scalar.
  if (vector_width == 1) {
    llvm::Type* v1_type = llvm::VectorType::get(input->getType(), 1);
    input = b.CreateInsertElement(llvm::UndefValue::get(v1_type), input,
                                  uint64_t{0});
  }

  // Generate the vectorized code.
  CHECK_EQ(vector_width, input->getType()->getVectorNumElements());
  llvm::Value* result = fn_body_generator(&b, input, vector_width);

  // Downcast result to scalar type if necessary.
  if (vector_width == 1) {
    result = b.CreateExtractElement(result, uint64_t{0});
  }
  b.CreateRet(result);
  DCHECK(!llvm::verifyFunction(*fn));

  // Force-inline `fn` into all of its callers and then delete `fn`.
  //
  // TODO(b/73081976): Should we avoid inlining these in some cases?
  std::vector<llvm::CallInst*> calls_to_inline;
  for (auto* user : fn->users()) {
    calls_to_inline.push_back(llvm::cast<llvm::CallInst>(user));
  }
  for (auto* call_to_inline : calls_to_inline) {
    llvm::InlineFunctionInfo inline_function_info;
    CHECK(llvm::InlineFunction(call_to_inline, inline_function_info));
  }
  fn->eraseFromParent();
}

llvm::Value* GenerateVF32Tanh(llvm::IRBuilder<>* b, llvm::Value* input,
                              int32 /*vector_width*/) {
  return llvm_ir::EmitFastTanh(b, input);
}

llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input,
                             int32 vector_width) {
  VectorSupportLibrary vsl(F32, vector_width, b, "exp_f32");

  // This implements the same polynomial approximation as implemented in Eigen3.
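  //
  // In outline (this follows the Cephes/Eigen expf kernel):
  //   exp(x) = 2^m * exp(r),  m = floor(x / ln(2) + 0.5),  r = x - m * ln(2).
  // The input is first clamped to [exp_lo, exp_hi] so that m stays within the
  // range of the exponent-bit trick below, ln(2) is split into cephes_exp_C1 +
  // cephes_exp_C2 so that r can be computed with extra precision, exp(r) is
  // approximated by a polynomial in r (the cephes_exp_p* coefficients), and
  // 2^m is reconstructed directly from its IEEE-754 bit pattern.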

  const llvm::APFloat half = GetIeeeF32(0.5);
  const llvm::APFloat one = GetIeeeF32(1.0);

  const llvm::APFloat exp_hi = GetIeeeF32(88.3762626647950);
  const llvm::APFloat exp_lo = GetIeeeF32(-88.3762626647949);

  const llvm::APFloat cephes_LOG2EF = GetIeeeF32(1.44269504088896341);
  const llvm::APFloat cephes_exp_C1 = GetIeeeF32(0.693359375);
  const llvm::APFloat cephes_exp_C2 = GetIeeeF32(-2.12194440e-4);

  const llvm::APFloat cephes_exp_p0 = GetIeeeF32(1.9875691500E-4);
  const llvm::APFloat cephes_exp_p1 = GetIeeeF32(1.3981999507E-3);
  const llvm::APFloat cephes_exp_p2 = GetIeeeF32(8.3334519073E-3);
  const llvm::APFloat cephes_exp_p3 = GetIeeeF32(4.1665795894E-2);
  const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1);
  const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1);

  llvm::Value* input_clamped =
      vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi);
  llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, half));
  llvm::Value* tmp = vsl.Mul(cephes_exp_C1, fx);
  llvm::Value* z = vsl.Mul(cephes_exp_C2, fx);
  llvm::Value* x = vsl.Sub(input_clamped, tmp);
  x = vsl.Sub(x, z);
  z = vsl.Mul(x, x);

  llvm::Value* y = vsl.MulAdd(x, cephes_exp_p0, cephes_exp_p1);
  y = vsl.MulAdd(y, x, cephes_exp_p2);
  y = vsl.MulAdd(y, x, cephes_exp_p3);
  y = vsl.MulAdd(y, x, cephes_exp_p4);
  y = vsl.MulAdd(y, x, cephes_exp_p5);
  y = vsl.MulAdd(y, z, x);
  y = vsl.Add(one, y);

  // VectorSupportLibrary (intentionally) can't juggle more than one type at a
  // time so drop down to IRBuilder for this bit.
  llvm::Value* vector_constant_0x7f =
      b->CreateVectorSplat(vector_width, b->getInt32(0x7f));
  llvm::Value* vector_constant_23 =
      b->CreateVectorSplat(vector_width, b->getInt32(23));
  llvm::Type* i32_vector_type =
      llvm::VectorType::get(b->getInt32Ty(), vector_width);
  // fx is clamped so we don't have to worry about it being out of range for
  // i32.
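  //
  // Build 2^fx as an f32: convert fx to i32, add the IEEE-754 exponent bias
  // (0x7f), shift the biased exponent into the exponent field (bits [30:23]),
  // and bitcast the resulting integer pattern back to float.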
  llvm::Value* emm0 = b->CreateFPToSI(fx, i32_vector_type);
  emm0 = b->CreateAdd(emm0, vector_constant_0x7f);
  emm0 = b->CreateShl(emm0, vector_constant_23);
  llvm::Value* emm0_f32 = b->CreateBitCast(emm0, vsl.vector_type());

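  // y * 2^fx approximates exp of the clamped input. Taking the max with the
  // original input lets a +inf input still produce +inf, which the clamp to
  // exp_hi above would otherwise turn into a large finite value.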
  return vsl.Max(vsl.Mul(y, emm0_f32), input);
}

llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input,
                             int32 vector_width) {
  VectorSupportLibrary vsl(F32, vector_width, b, "log_f32");

  const llvm::APFloat half = GetIeeeF32(0.5);
  const llvm::APFloat one = GetIeeeF32(1.0);

  // This implements the same polynomial approximation as implemented in Eigen3.
  // Returns NaN for x < 0, -INF for x = 0
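  //
  // In outline (again following the Cephes/Eigen logf kernel): decompose
  //   x = m * 2^e  with m in [0.5, 1),
  // shift m into [sqrt(0.5), sqrt(2)) while adjusting e, and then use
  //   log(x) = e * ln(2) + log(m),
  // where log(m) is approximated by a polynomial in (m - 1) (the cephes_log_p*
  // coefficients) and ln(2) is split into cephes_log_q1 + cephes_log_q2 for
  // extra precision. Inputs that are zero, negative, +inf or NaN are patched
  // up with the masks computed below.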
  const llvm::APFloat cephes_SQRTHF = GetIeeeF32(0.707106781186547524);
  const llvm::APFloat cephes_log_p0 = GetIeeeF32(7.0376836292E-2);
  const llvm::APFloat cephes_log_p1 = GetIeeeF32(-1.1514610310E-1);
  const llvm::APFloat cephes_log_p2 = GetIeeeF32(1.1676998740E-1);
  const llvm::APFloat cephes_log_p3 = GetIeeeF32(-1.2420140846E-1);
  const llvm::APFloat cephes_log_p4 = GetIeeeF32(+1.4249322787E-1);
  const llvm::APFloat cephes_log_p5 = GetIeeeF32(-1.6668057665E-1);
  const llvm::APFloat cephes_log_p6 = GetIeeeF32(+2.0000714765E-1);
  const llvm::APFloat cephes_log_p7 = GetIeeeF32(-2.4999993993E-1);
  const llvm::APFloat cephes_log_p8 = GetIeeeF32(+3.3333331174E-1);
  const llvm::APFloat cephes_log_q1 = GetIeeeF32(-2.12194440e-4);
  const llvm::APFloat cephes_log_q2 = GetIeeeF32(0.693359375);

  // The smallest positive normalized float number.
  const llvm::APFloat min_norm_pos = GetIeeeF32FromBitwiseRep(0x00800000);
  const llvm::APFloat minus_inf = GetIeeeF32FromBitwiseRep(0xff800000);
  const llvm::APFloat pos_inf = GetIeeeF32FromBitwiseRep(0x7f800000);
  const llvm::APFloat inv_mant_mask = GetIeeeF32FromBitwiseRep(~0x7f800000);

  // invalid_mask is set if x is negative or NaN (and therefore output
  // must be NaN).
  llvm::Value* invalid_mask = vsl.FCmpULEMask(input, vsl.GetZeroVector());
  llvm::Value* is_zero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector());
  llvm::Value* is_pos_inf_mask = vsl.FCmpEQMask(input, pos_inf);

  // Cut off denormalized stuff.
  llvm::Value* tmp0 = vsl.Max(min_norm_pos, input);

  // VectorSupportLibrary (intentionally) can't juggle more than one type at a
  // time so drop down to IRBuilder for this bit.
  llvm::Value* vector_constant_0x7f =
      b->CreateVectorSplat(vector_width, b->getInt32(0x7f));
  llvm::Value* vector_constant_23 =
      b->CreateVectorSplat(vector_width, b->getInt32(23));
  llvm::Type* i32_vector_type =
      llvm::VectorType::get(b->getInt32Ty(), vector_width);

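  // Extract the biased exponent of each element: bitcast to i32 and shift the
  // exponent field (bits [30:23]) down into the low bits.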
  llvm::Value* emm0 = b->CreateLShr(b->CreateBitCast(tmp0, i32_vector_type),
                                    vector_constant_23);

  // Keep only the fractional part.
  tmp0 = vsl.FloatAnd(tmp0, inv_mant_mask);
  tmp0 = vsl.FloatOr(tmp0, half);

  emm0 = b->CreateSub(emm0, vector_constant_0x7f);
  llvm::Value* e = vsl.Add(one, b->CreateSIToFP(emm0, vsl.vector_type()));

  // part2:
  //   if( x < SQRTHF ) {
  //     e -= 1;
  //     x = x + x - 1.0;
  //   } else { x = x - 1.0; }
  llvm::Value* mask = vsl.FCmpOLTMask(tmp0, cephes_SQRTHF);
  llvm::Value* tmp1 = vsl.FloatAnd(tmp0, mask);
  tmp0 = vsl.Sub(tmp0, one);
  e = vsl.Sub(e, vsl.FloatAnd(mask, one));
  tmp0 = vsl.Add(tmp0, tmp1);
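  // The masked arithmetic above is a branch-free version of the pseudocode:
  // where mask is all ones, tmp1 == tmp0, so tmp0 ends up as (x + x - 1.0) and
  // e is decremented; where mask is zero, tmp1 is zero and e is untouched.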

  llvm::Value* x2 = vsl.Mul(tmp0, tmp0);
  llvm::Value* x3 = vsl.Mul(x2, tmp0);

  llvm::Value *y, *y1, *y2;
  y = vsl.MulAdd(tmp0, cephes_log_p0, cephes_log_p1);
  y1 = vsl.MulAdd(tmp0, cephes_log_p3, cephes_log_p4);
  y2 = vsl.MulAdd(tmp0, cephes_log_p6, cephes_log_p7);
  y = vsl.MulAdd(y, tmp0, cephes_log_p2);
  y1 = vsl.MulAdd(y1, tmp0, cephes_log_p5);
  y2 = vsl.MulAdd(y2, tmp0, cephes_log_p8);
  y = vsl.MulAdd(y, x3, y1);
  y = vsl.MulAdd(y, x3, y2);
  y = vsl.Mul(y, x3);

  y1 = vsl.Mul(cephes_log_q1, e);
  llvm::Value* tmp2 = vsl.Mul(half, x2);
  y = vsl.Add(y, y1);
  tmp0 = vsl.Sub(tmp0, tmp2);
  y2 = vsl.Mul(cephes_log_q2, e);
  tmp0 = vsl.Add(tmp0, y);
  tmp0 = vsl.Add(tmp0, y2);

  // Contains +/-inf where +/-inf is the correct answer, otherwise 0.
  llvm::Value* result_inf = vsl.FloatOr(vsl.FloatAnd(is_zero_mask, minus_inf),
                                        vsl.FloatAnd(is_pos_inf_mask, pos_inf));

  // Contains a finite result or NaN. This is the correct answer only where
  // result_inf above is zero.
  //
  // (This implementation works because 0xffffffff is a NaN.)
  llvm::Value* result_finite_or_nan = vsl.FloatOr(tmp0, invalid_mask);

  // Combine the above into a final result.
  return vsl.FloatOr(result_inf,
                     vsl.FloatAndNot(vsl.FloatOr(is_zero_mask, is_pos_inf_mask),
                                     result_finite_or_nan));
}
}  // namespace

void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) {
  // Curry some params to RewriteCalls.
  auto rewrite_calls =
      std::bind(RewriteCalls, module, std::placeholders::_1,
                std::placeholders::_2, std::placeholders::_3, enable_fast_math);

  rewrite_calls("tanhf", GenerateVF32Tanh, /*vector_width=*/1);
  rewrite_calls("llvm.tanh.f32", GenerateVF32Tanh, /*vector_width=*/1);
  rewrite_calls(kTanhV4F32SymbolName, GenerateVF32Tanh, /*vector_width=*/4);
  rewrite_calls(kTanhV8F32SymbolName, GenerateVF32Tanh, /*vector_width=*/8);

  rewrite_calls("expf", GenerateVF32Exp, /*vector_width=*/1);
  rewrite_calls("llvm.exp.f32", GenerateVF32Exp, /*vector_width=*/1);
  rewrite_calls(kExpV4F32SymbolName, GenerateVF32Exp, /*vector_width=*/4);
  rewrite_calls(kExpV8F32SymbolName, GenerateVF32Exp, /*vector_width=*/8);

  rewrite_calls("logf", GenerateVF32Log, /*vector_width=*/1);
  rewrite_calls("llvm.log.f32", GenerateVF32Log, /*vector_width=*/1);
  rewrite_calls(kLogV4F32SymbolName, GenerateVF32Log, /*vector_width=*/4);
  rewrite_calls(kLogV8F32SymbolName, GenerateVF32Log, /*vector_width=*/8);
}

}  // namespace runtime
}  // namespace cpu
}  // namespace xla