/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"

#include "absl/algorithm/container.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"

namespace xla {
namespace cpu {
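// VectorSupportLibrary emits LLVM IR for element-wise math on a fixed-width
// vector of the given primitive type. A minimal usage sketch (assumes an
// active IRBuilder positioned inside a function; names are illustrative):
//
//   VectorSupportLibrary vsl(F32, /*vector_size=*/8, &builder, "dot");
//   llvm::Value* v = vsl.LoadVector(ptr);    // loads a <8 x float>
//   llvm::Value* sum = vsl.AddReduce(v);     // horizontal sum -> float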
VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type,
                                           int64_t vector_size,
                                           llvm::IRBuilder<>* b,
                                           std::string name)
    : vector_size_(vector_size),
      primitive_type_(primitive_type),
      b_(b),
      name_(std::move(name)) {
  scalar_type_ = llvm_ir::PrimitiveTypeToIrType(
      primitive_type, b_->GetInsertBlock()->getModule());
  scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_);
  vector_type_ = llvm::VectorType::get(scalar_type_, vector_size, false);
  vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_);
}

static std::string TypeToString(llvm::Type* type) {
  std::string o;
  llvm::raw_string_ostream ostream(o);
  type->print(ostream);
  return ostream.str();
}

void VectorSupportLibrary::AssertCorrectTypes(
    std::initializer_list<llvm::Value*> values) {
  for (llvm::Value* v : values) {
    llvm::Type* type = v->getType();
    if (type != scalar_type() && type != vector_type()) {
      LOG(FATAL) << "Expected either " << TypeToString(scalar_type()) << " or "
                 << TypeToString(vector_type()) << " but got "
                 << TypeToString(type);
    }
  }
}

llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return MulInternal(lhs, rhs);
}

llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  if (scalar_type_->isFloatingPointTy()) {
    return b()->CreateFMul(lhs, rhs, name());
  } else {
    return b()->CreateMul(lhs, rhs, name());
  }
}

llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return AddInternal(lhs, rhs);
}

llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return b()->CreateFSub(lhs, rhs);
}

llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs,
                                       bool enable_fast_min_max) {
  AssertCorrectTypes({lhs, rhs});
  if (scalar_type_->isFloatingPointTy()) {
    return llvm_ir::EmitFloatMax(lhs, rhs, b_, enable_fast_min_max);
  } else {
    LOG(FATAL) << "Max for integers is unimplemented";
  }
}

llvm::Value* VectorSupportLibrary::Floor(llvm::Value* a) {
  AssertCorrectTypes({a});
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {a},
                                      {a->getType()}, b());
}

llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  if (scalar_type_->isFloatingPointTy()) {
    return b()->CreateFDiv(lhs, rhs, name());
  } else {
    LOG(FATAL) << "Division for integers is unimplemented";
  }
}

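// Clamps `a` to [low, high]. The selects below use unordered comparisons
// (FCmpUGE/FCmpULE), so a NaN input compares true at both steps and is
// passed through unchanged rather than replaced by a bound.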
llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a,
                                         const llvm::APFloat& low,
                                         const llvm::APFloat& high) {
  CHECK(!low.isNaN());
  CHECK(!high.isNaN());
  CHECK(low.compare(high) == llvm::APFloat::cmpLessThan);

  AssertCorrectTypes({a});
  llvm::Type* type = a->getType();
  CHECK(scalar_type_->isFloatingPointTy());

  llvm::Value* low_value = GetConstantFloat(type, low);
  llvm::Value* high_value = GetConstantFloat(type, high);
  a = b_->CreateSelect(b_->CreateFCmpUGE(a, low_value), a, low_value);
  a = b_->CreateSelect(b_->CreateFCmpULE(a, high_value), a, high_value);
  return a;
}

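// The FCmp*Mask helpers compare lane-wise and return the result as a float
// (or float vector) whose bits are all-ones where the predicate holds and
// all-zeros elsewhere, suitable for FloatAnd/FloatOr/FloatNot masking below.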
llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs,
                                              llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return I1ToFloat(b()->CreateFCmpOEQ(lhs, rhs, name()));
}

llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return I1ToFloat(b()->CreateFCmpOLT(lhs, rhs, name()));
}

llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return I1ToFloat(b()->CreateFCmpULE(lhs, rhs, name()));
}

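// Converts an i1 (or <N x i1>) predicate into a float-typed mask by
// sign-extending to an integer of the float's bit width and bitcasting:
// true becomes all-one bits, false becomes all-zero bits.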
llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) {
  bool is_vector = llvm::isa<llvm::VectorType>(i1->getType());
  llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector);
  return b()->CreateBitCast(b()->CreateSExt(i1, integer_type, name()),
                            is_vector ? vector_type() : scalar_type(), name());
}

llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) {
  CHECK(scalar_type()->isFloatingPointTy());
  const llvm::DataLayout& data_layout =
      b()->GetInsertBlock()->getModule()->getDataLayout();
  int64_t float_size_bits = data_layout.getTypeSizeInBits(scalar_type());
  llvm::Type* scalar_int_type = b()->getIntNTy(float_size_bits);
  if (vector) {
    return llvm::VectorType::get(scalar_int_type, vector_size(), false);
  } else {
    return scalar_int_type;
  }
}

llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) {
  CHECK_EQ(x->getType(), scalar_type());
  return b()->CreateVectorSplat(vector_size(), x, name());
}

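// Bitwise operations are not defined on floating-point types in LLVM IR, so
// FloatAnd/FloatNot/FloatOr round-trip through a same-width integer type:
// bitcast to iN, apply the bitwise op, bitcast back. For example, ANDing a
// value with an FCmpOLTMask result keeps the lanes where the compare held
// and zeroes the rest.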
llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs,
                                            llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  llvm::Type* int_type =
      IntegerTypeForFloatSize(lhs->getType() == vector_type());
  return b()->CreateBitCast(
      b()->CreateAnd(b()->CreateBitCast(lhs, int_type, name()),
                     b()->CreateBitCast(rhs, int_type, name()), name()),
      vector_type());
}

llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) {
  AssertCorrectTypes({lhs});
  llvm::Type* int_type =
      IntegerTypeForFloatSize(lhs->getType() == vector_type());
  return b()->CreateBitCast(
      b()->CreateNot(b()->CreateBitCast(lhs, int_type, name()), name()),
      vector_type());
}

llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  llvm::Type* int_type =
      IntegerTypeForFloatSize(lhs->getType() == vector_type());
  return b()->CreateBitCast(
      b()->CreateOr(b()->CreateBitCast(lhs, int_type, name()),
                    b()->CreateBitCast(rhs, int_type, name()), name()),
      vector_type(), name());
}

llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  if (scalar_type_->isFloatingPointTy()) {
    return b()->CreateFAdd(lhs, rhs, name());
  } else {
    return b()->CreateAdd(lhs, rhs, name());
  }
}

llvm::Value* VectorSupportLibrary::ComputeOffsetPointer(
    llvm::Value* base_pointer, llvm::Value* offset_elements) {
  if (base_pointer->getType() != scalar_pointer_type()) {
    base_pointer =
        b()->CreateBitCast(base_pointer, scalar_pointer_type(), name());
  }
  return b()->CreateInBoundsGEP(scalar_type(), base_pointer, offset_elements,
                                name());
}

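// The loads and stores below are annotated with the alignment of a single
// element, not of the full vector, so they stay correct for vector accesses
// to addresses that are only element-aligned.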
llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) {
  if (pointer->getType() != vector_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, vector_pointer_type(), name());
  }
  return b()->CreateAlignedLoad(
      vector_type(), pointer,
      llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)), name());
}

llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) {
  if (pointer->getType() != scalar_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
  }
  return b()->CreateAlignedLoad(
      scalar_type(), pointer,
      llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)), name());
}

void VectorSupportLibrary::StoreVector(llvm::Value* value,
                                       llvm::Value* pointer) {
  AssertCorrectTypes({value});
  if (pointer->getType() != vector_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, vector_pointer_type());
  }
  b()->CreateAlignedStore(
      value, pointer,
      llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)));
}

void VectorSupportLibrary::StoreScalar(llvm::Value* value,
                                       llvm::Value* pointer) {
  AssertCorrectTypes({value});
  if (pointer->getType() != scalar_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
  }
  b()->CreateAlignedStore(
      value, pointer,
      llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)));
}

llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) {
  if (pointer->getType() != scalar_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
  }
  return b()->CreateVectorSplat(
      vector_size(), b()->CreateLoad(scalar_type(), pointer), name());
}

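// Reduces all lanes of `vector` to one scalar with log2(vector_size())
// shuffle-and-add steps. For vector_size() == 8 the lanes carrying partial
// sums shrink as 8 -> 4 -> 2 -> 1, and the total is read out of lane 0.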
llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) {
  llvm::SmallVector<llvm::Constant*, 32> mask(vector_size(), nullptr);
  for (unsigned i = vector_size(); i != 1; i >>= 1) {
    // On every iteration we shuffle the upper half of the lanes that still
    // hold partial sums down over the lower half, then add the shuffled
    // vector to the original one, halving the number of live lanes.
    for (unsigned j = 0; j < vector_size(); ++j) {
      if (j < (i / 2)) {
        mask[j] = b()->getInt32(i / 2 + j);
      } else {
        mask[j] = llvm::UndefValue::get(b()->getInt32Ty());
      }
    }

    llvm::Value* half_remaining_lanes =
        b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
                                 llvm::ConstantVector::get(mask), "");
    vector = Add(vector, half_remaining_lanes);
  }

  return b()->CreateExtractElement(vector, b()->getInt32(0), name());
}

llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs,
                                                         llvm::Value* rhs) {
  CHECK_EQ(lhs->getType(), vector_type());
  CHECK_EQ(rhs->getType(), vector_type());
  CHECK_EQ(vector_size() % 2, 0);

  llvm::SmallVector<llvm::Constant*, 32> mask_a, mask_b;

  // Adding the values shuffled using mask_a and mask_b gives us the
  // AVX-style horizontal add we want. The masks work as documented
  // in https://llvm.org/docs/LangRef.html#shufflevector-instruction
  //
  // Here are the masks for vector_size() == 8:
  //
  //   index:  |0 |1 |2 |3  |4 |5 |6  |7
  //   --------+--+--+--+---+--+--+---+---
  //   mask_a: |0 |2 |8 |10 |4 |6 |12 |14
  //   mask_b: |1 |3 |9 |11 |5 |7 |13 |15
  //
  // So, as an example, the value at lane 3 of the result vector is
  // the result of adding lane 10 and lane 11 in the combined lhs++rhs
  // vector, which are the lanes 2 and 3 in the rhs vector.
  for (int i = 0; i < vector_size(); i += 2) {
    int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2);
    mask_a.push_back(b()->getInt32(increment + i));
    mask_b.push_back(b()->getInt32(increment + i + 1));
  }
  for (int i = 0; i < vector_size(); i += 2) {
    int increment = i < vector_size() / 2 ? (vector_size() / 2) : vector_size();
    mask_a.push_back(b()->getInt32(increment + i));
    mask_b.push_back(b()->getInt32(increment + i + 1));
  }

  llvm::Value* shuffle_0 =
      b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_a));
  llvm::Value* shuffle_1 =
      b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_b));

  return Add(shuffle_0, shuffle_1);
}

llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) {
  llvm::SmallVector<llvm::Constant*, 32> mask;
  for (int i = 0; i < vector_size() / 2; i++) {
    mask.push_back(b()->getInt32(i));
  }

  return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
                                  llvm::ConstantVector::get(mask));
}

llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) {
  llvm::SmallVector<llvm::Constant*, 32> mask;
  for (int i = 0; i < vector_size() / 2; i++) {
    mask.push_back(b()->getInt32(i + vector_size() / 2));
  }

  return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
                                  llvm::ConstantVector::get(mask));
}

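// Horizontally sums each input vector down to one scalar. When there are
// exactly N vectors of N lanes each, with N matching the x86 AVX register
// width, the shuffle-based path below computes all N sums with fewer
// shuffles than reducing each vector independently via AddReduce.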
std::vector<llvm::Value*> VectorSupportLibrary::ComputeHorizontalSums(
    std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
  const int x86_avx_vector_elements =
      TargetMachineFeatures::kX86AvxVectorByteSize / scalar_byte_size();
  if (vector_size() == x86_avx_vector_elements &&
      vectors.size() == x86_avx_vector_elements) {
    return ComputeAvxOptimizedHorizontalSums(std::move(vectors), init_values);
  }

  std::vector<llvm::Value*> result;
  std::transform(vectors.begin(), vectors.end(), std::back_inserter(result),
                 [this](llvm::Value* vector) { return AddReduce(vector); });
  if (init_values) {
    for (int64_t i = 0, e = result.size(); i < e; i++) {
      result[i] = Add(result[i],
                      b()->CreateExtractElement(init_values, b()->getInt32(i)));
    }
  }
  return result;
}

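// Pairwise AvxStyleHorizontalAdd steps shrink the N input vectors down to 2
// vectors whose lanes hold partial sums; the low/high half extraction below
// then finishes each reduction. For N == 8 this takes two rounds of pairing
// (8 vectors -> 4 -> 2) before the final half-adds.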
std::vector<llvm::Value*>
VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums(
    std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
  // vectors are N llvm vector values, each with N elements.
  int64_t lane_width = vectors.size();

  while (vectors.size() != 2) {
    std::vector<llvm::Value*> new_vectors;
    new_vectors.reserve(vectors.size() / 2);
    for (int i = 0; i < vectors.size(); i += 2) {
      new_vectors.push_back(AvxStyleHorizontalAdd(vectors[i], vectors[i + 1]));
    }

    vectors = std::move(new_vectors);
  }

  llvm::Value* low =
      AddInternal(ExtractLowHalf(vectors[0]), ExtractHighHalf(vectors[0]));
  if (init_values) {
    low = AddInternal(ExtractLowHalf(init_values), low);
  }
  llvm::Value* high =
      AddInternal(ExtractLowHalf(vectors[1]), ExtractHighHalf(vectors[1]));
  if (init_values) {
    high = AddInternal(ExtractHighHalf(init_values), high);
  }

  // `low` has the first `lane_width / 2` horizontal reductions, and `high` has
  // the next `lane_width / 2` horizontal reductions.

  std::vector<llvm::Value*> results;
  for (int i = 0; i < lane_width; i++) {
    llvm::Value* scalar_result =
        b()->CreateExtractElement(i < (lane_width / 2) ? low : high,
                                  b()->getInt32(i % (lane_width / 2)), name());
    results.push_back(scalar_result);
  }

  return results;
}

llvm::Value* VectorSupportLibrary::GetZeroVector() {
  return llvm::Constant::getNullValue(vector_type());
}

llvm::Value* VectorSupportLibrary::GetZeroScalar() {
  return llvm::Constant::getNullValue(scalar_type());
}

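// LlvmVariable backs a mutable value with a stack slot emitted in the
// function's entry block; placing the alloca at entry is what lets LLVM's
// mem2reg pass promote it back into SSA registers.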
LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* b) : b_(b) {
  alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", b_);
}

llvm::Value* LlvmVariable::Get() const {
  return b_->CreateLoad(alloca_->getAllocatedType(), alloca_);
}

void LlvmVariable::Set(llvm::Value* new_value) {
  b_->CreateStore(new_value, alloca_);
}

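// TileVariable is a thin convenience wrapper: one VectorVariable per vector
// in a register tile, so the whole tile can be read or written at once.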
TileVariable::TileVariable(VectorSupportLibrary* vector_support,
                           std::vector<llvm::Value*> initial_value) {
  for (llvm::Value* initial_vector_value : initial_value) {
    storage_.emplace_back(vector_support, initial_vector_value);
  }
}

std::vector<llvm::Value*> TileVariable::Get() const {
  std::vector<llvm::Value*> result;
  absl::c_transform(storage_, std::back_inserter(result),
                    [&](VectorVariable vect_var) { return vect_var.Get(); });
  return result;
}

void TileVariable::Set(absl::Span<llvm::Value* const> value) {
  CHECK_EQ(value.size(), storage_.size());
  for (int64_t i = 0, e = value.size(); i < e; i++) {
    storage_[i].Set(value[i]);
  }
}

}  // namespace cpu
}  // namespace xla