/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"

#include "absl/algorithm/container.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"

namespace xla {
namespace cpu {
VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type,
                                           int64 vector_size,
                                           llvm::IRBuilder<>* b,
                                           std::string name)
    : vector_size_(vector_size),
      primitive_type_(primitive_type),
      b_(b),
      name_(std::move(name)) {
  scalar_type_ = llvm_ir::PrimitiveTypeToIrType(
      primitive_type, b_->GetInsertBlock()->getModule());
  scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_);
  vector_type_ = llvm::VectorType::get(scalar_type_, vector_size);
  vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_);
}

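// Renders an LLVM type as a human-readable string for error messages.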
static string TypeToString(llvm::Type* type) {
  std::string o;
  llvm::raw_string_ostream ostream(o);
  type->print(ostream);
  return ostream.str();
}

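// Checks that each value is either of the scalar type or of the vector type
// this library was configured with, and LOG(FATAL)s otherwise.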
void VectorSupportLibrary::AssertCorrectTypes(
    std::initializer_list<llvm::Value*> values) {
  for (llvm::Value* v : values) {
    llvm::Type* type = v->getType();
    if (type != scalar_type() && type != vector_type()) {
      LOG(FATAL) << "Expected either " << TypeToString(scalar_type()) << " or "
                 << TypeToString(vector_type()) << " but got "
                 << TypeToString(type);
    }
  }
}

llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return MulInternal(lhs, rhs);
}

llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  if (scalar_type_->isFloatingPointTy()) {
    return b()->CreateFMul(lhs, rhs, name());
  } else {
    return b()->CreateMul(lhs, rhs, name());
  }
}

llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return AddInternal(lhs, rhs);
}

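// Note: unlike Mul and Add, Sub always emits a floating-point subtract, so it
// assumes the configured element type is a floating-point type.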
llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return b()->CreateFSub(lhs, rhs);
}

llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  if (scalar_type_->isFloatingPointTy()) {
    return llvm_ir::EmitFloatMax(lhs, rhs, b_);
  } else {
    LOG(FATAL) << "Max for integers is unimplemented";
  }
}

llvm::Value* VectorSupportLibrary::Floor(llvm::Value* a) {
  AssertCorrectTypes({a});
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {a},
                                      {a->getType()}, b());
}

llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  if (scalar_type_->isFloatingPointTy()) {
    return b()->CreateFDiv(lhs, rhs, name());
  } else {
    LOG(FATAL) << "Division for integers is unimplemented";
  }
}

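// Clamps `a` to the range [low, high] elementwise, via a max with `low`
// followed by a min with `high`. Only floating-point element types are
// supported, and `low` must compare strictly less than `high`.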
llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a,
                                         const llvm::APFloat& low,
                                         const llvm::APFloat& high) {
  AssertCorrectTypes({a});
  llvm::Type* type = a->getType();
  CHECK(low.compare(high) == llvm::APFloat::cmpLessThan);
  CHECK(scalar_type_->isFloatingPointTy());
  return llvm_ir::EmitFloatMin(
      llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), b_),
      GetConstantFloat(type, high), b_);
}

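// The FCmp*Mask functions compare lhs and rhs and return the result as a
// float-typed bit mask: all-ones bits where the predicate holds, all-zeros
// bits where it does not (see I1ToFloat below).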
llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs,
                                              llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return I1ToFloat(b()->CreateFCmpOEQ(lhs, rhs, name()));
}

llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return I1ToFloat(b()->CreateFCmpOLT(lhs, rhs, name()));
}

llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return I1ToFloat(b()->CreateFCmpULE(lhs, rhs, name()));
}

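// Converts an i1 (or vector of i1) into a float-typed bit mask:
// sign-extending i1 true yields -1, an all-ones bit pattern, which is then
// bitcast to the corresponding float type; false lanes become all zeros.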
llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) {
  bool is_vector = llvm::isa<llvm::VectorType>(i1->getType());
  llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector);
  return b()->CreateBitCast(b()->CreateSExt(i1, integer_type, name()),
                            is_vector ? vector_type() : scalar_type(), name());
}

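// Returns an integer type (or a vector of one) with the same bit width as the
// configured floating-point scalar type, for doing bitwise operations on
// float data via bitcasts.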
llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) {
  CHECK(scalar_type()->isFloatingPointTy());
  const llvm::DataLayout& data_layout =
      b()->GetInsertBlock()->getModule()->getDataLayout();
  int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type());
  llvm::Type* scalar_int_type = b()->getIntNTy(float_size_bits);
  if (vector) {
    return llvm::VectorType::get(scalar_int_type, vector_size());
  } else {
    return scalar_int_type;
  }
}

llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) {
  CHECK_EQ(x->getType(), scalar_type());
  return b()->CreateVectorSplat(vector_size(), x, name());
}

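// FloatAnd, FloatNot, and FloatOr implement bitwise operations on
// floating-point values by round-tripping through a same-width integer type.
// They are mainly useful for combining the masks produced by the FCmp*Mask
// functions above.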
llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs,
                                            llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  llvm::Type* int_type =
      IntegerTypeForFloatSize(lhs->getType() == vector_type());
  // Bitcast back to the operand type rather than unconditionally to
  // vector_type(), so that scalar operands produce a scalar result.
  return b()->CreateBitCast(
      b()->CreateAnd(b()->CreateBitCast(lhs, int_type, name()),
                     b()->CreateBitCast(rhs, int_type, name()), name()),
      lhs->getType());
}

llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) {
  AssertCorrectTypes({lhs});
  llvm::Type* int_type =
      IntegerTypeForFloatSize(lhs->getType() == vector_type());
  return b()->CreateBitCast(
      b()->CreateNot(b()->CreateBitCast(lhs, int_type, name()), name()),
      lhs->getType());
}

llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  llvm::Type* int_type =
      IntegerTypeForFloatSize(lhs->getType() == vector_type());
  return b()->CreateBitCast(
      b()->CreateOr(b()->CreateBitCast(lhs, int_type, name()),
                    b()->CreateBitCast(rhs, int_type, name()), name()),
      lhs->getType(), name());
}

llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  if (scalar_type_->isFloatingPointTy()) {
    return b()->CreateFAdd(lhs, rhs, name());
  } else {
    return b()->CreateAdd(lhs, rhs, name());
  }
}

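// Returns base_pointer advanced by offset_elements scalar elements, as an
// inbounds GEP over the scalar type.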
llvm::Value* VectorSupportLibrary::ComputeOffsetPointer(
    llvm::Value* base_pointer, llvm::Value* offset_elements) {
  if (base_pointer->getType() != scalar_pointer_type()) {
    base_pointer =
        b()->CreateBitCast(base_pointer, scalar_pointer_type(), name());
  }
  return b()->CreateInBoundsGEP(base_pointer, {offset_elements}, name());
}

llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) {
  if (pointer->getType() != vector_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, vector_pointer_type(), name());
  }
  return b()->CreateAlignedLoad(
      pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
}

llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) {
  if (pointer->getType() != scalar_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
  }
  return b()->CreateAlignedLoad(
      pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
}

void VectorSupportLibrary::StoreVector(llvm::Value* value,
                                       llvm::Value* pointer) {
  AssertCorrectTypes({value});
  if (pointer->getType() != vector_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, vector_pointer_type());
  }
  b()->CreateAlignedStore(value, pointer,
                          ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
}

void VectorSupportLibrary::StoreScalar(llvm::Value* value,
                                       llvm::Value* pointer) {
  AssertCorrectTypes({value});
  if (pointer->getType() != scalar_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
  }
  b()->CreateAlignedStore(value, pointer,
                          ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
}

llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) {
  if (pointer->getType() != scalar_pointer_type()) {
    pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
  }
  return b()->CreateVectorSplat(vector_size(), b()->CreateLoad(pointer),
                                name());
}

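// Reduces all lanes of `vector` to a single scalar sum using log2(N)
// shuffle-and-add steps; the reduction below is only correct for
// power-of-two vector sizes. For example, with vector_size() == 8:
//   step 1: v += shuffle(v, <4,5,6,7,u,u,u,u>)  // lanes 0..3 hold pair sums
//   step 2: v += shuffle(v, <2,3,u,u,u,u,u,u>)
//   step 3: v += shuffle(v, <1,u,u,u,u,u,u,u>)
// after which lane 0 holds the sum of all eight lanes.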
llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) {
  llvm::SmallVector<llvm::Constant*, 32> mask(vector_size(), nullptr);
  for (unsigned i = vector_size(); i != 1; i >>= 1) {
    // On every iteration, shuffle the top half of the remaining live lanes
    // down into the bottom lanes, then add the shuffled vector to the
    // original, halving the number of live lanes.
    for (unsigned j = 0; j < vector_size(); ++j) {
      if (j < (i / 2)) {
        mask[j] = b()->getInt32(i / 2 + j);
      } else {
        mask[j] = llvm::UndefValue::get(b()->getInt32Ty());
      }
    }

    llvm::Value* half_remaining_lanes =
        b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
                                 llvm::ConstantVector::get(mask), "");
    vector = Add(vector, half_remaining_lanes);
  }

  return b()->CreateExtractElement(vector, b()->getInt32(0), name());
}

llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs,
                                                         llvm::Value* rhs) {
  CHECK_EQ(lhs->getType(), vector_type());
  CHECK_EQ(rhs->getType(), vector_type());
  CHECK_EQ(vector_size() % 2, 0);

  llvm::SmallVector<llvm::Constant*, 32> mask_a, mask_b;

  // Adding the values shuffled using mask_a and mask_b gives us the
  // AVX-style horizontal add we want. The masks work as documented
  // in https://llvm.org/docs/LangRef.html#shufflevector-instruction
  //
  // Here are the masks for vector_size() == 8:
  //
  //   index: |0 |1 |2 | 3 |4 |5 | 6 | 7
  //  --------+--+--+--+---+--+--+---+---
  //  mask_a: |0 |2 |8 |10 |4 |6 |12 |14
  //  mask_b: |1 |3 |9 |11 |5 |7 |13 |15
  //
  // So, as an example, the value at lane 3 of the result vector is
  // the result of adding lane 10 and lane 11 in the combined lhs++rhs
  // vector, which are the lanes 2 and 3 in the rhs vector.
  for (int i = 0; i < vector_size(); i += 2) {
    int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2);
    mask_a.push_back(b()->getInt32(increment + i));
    mask_b.push_back(b()->getInt32(increment + i + 1));
  }
  for (int i = 0; i < vector_size(); i += 2) {
    int increment = i < vector_size() / 2 ? (vector_size() / 2) : vector_size();
    mask_a.push_back(b()->getInt32(increment + i));
    mask_b.push_back(b()->getInt32(increment + i + 1));
  }

  llvm::Value* shuffle_0 =
      b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_a));
  llvm::Value* shuffle_1 =
      b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_b));

  return Add(shuffle_0, shuffle_1);
}

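// ExtractLowHalf and ExtractHighHalf return the first and last
// vector_size() / 2 lanes of `vector`, respectively, as half-width vectors.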
llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) {
  llvm::SmallVector<llvm::Constant*, 32> mask;
  for (int i = 0; i < vector_size() / 2; i++) {
    mask.push_back(b()->getInt32(i));
  }

  return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
                                  llvm::ConstantVector::get(mask));
}

llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) {
  llvm::SmallVector<llvm::Constant*, 32> mask;
  for (int i = 0; i < vector_size() / 2; i++) {
    mask.push_back(b()->getInt32(i + vector_size() / 2));
  }

  return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
                                  llvm::ConstantVector::get(mask));
}

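// Horizontally sums each vector in `vectors` down to a single scalar,
// optionally adding the matching lane of `init_values` to each sum. When the
// vector width and the number of vectors both equal the AVX vector element
// count, a shuffle-based AVX-friendly path is taken; otherwise each vector is
// reduced independently with AddReduce.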
std::vector<llvm::Value*> VectorSupportLibrary::ComputeHorizontalSums(
    std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
  const int x86_avx_vector_elements =
      TargetMachineFeatures::kX86AvxVectorByteSize / scalar_byte_size();
  if (vector_size() == x86_avx_vector_elements &&
      vectors.size() == x86_avx_vector_elements) {
    return ComputeAvxOptimizedHorizontalSums(std::move(vectors), init_values);
  }

  std::vector<llvm::Value*> result;
  std::transform(vectors.begin(), vectors.end(), std::back_inserter(result),
                 [this](llvm::Value* vector) { return AddReduce(vector); });
  if (init_values) {
    for (int64 i = 0, e = result.size(); i < e; i++) {
      result[i] = Add(result[i],
                      b()->CreateExtractElement(init_values, b()->getInt32(i)));
    }
  }
  return result;
}

std::vector<llvm::Value*>
VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums(
    std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
  // `vectors` holds N LLVM vector values, each N lanes wide.
  int64 lane_width = vectors.size();

  // Pairwise horizontal adds until only two vectors remain.
  while (vectors.size() != 2) {
    std::vector<llvm::Value*> new_vectors;
    for (int i = 0; i < vectors.size(); i += 2) {
      new_vectors.push_back(AvxStyleHorizontalAdd(vectors[i], vectors[i + 1]));
    }

    vectors = std::move(new_vectors);
  }

  llvm::Value* low =
      AddInternal(ExtractLowHalf(vectors[0]), ExtractHighHalf(vectors[0]));
  if (init_values) {
    low = AddInternal(ExtractLowHalf(init_values), low);
  }
  llvm::Value* high =
      AddInternal(ExtractLowHalf(vectors[1]), ExtractHighHalf(vectors[1]));
  if (init_values) {
    high = AddInternal(ExtractHighHalf(init_values), high);
  }

  // `low` holds the first `lane_width / 2` horizontal reductions, and `high`
  // holds the next `lane_width / 2` horizontal reductions.
  std::vector<llvm::Value*> results;
  for (int i = 0; i < lane_width; i++) {
    llvm::Value* scalar_result =
        b()->CreateExtractElement(i < (lane_width / 2) ? low : high,
                                  b()->getInt32(i % (lane_width / 2)), name());
    results.push_back(scalar_result);
  }

  return results;
}

llvm::Value* VectorSupportLibrary::GetZeroVector() {
  return llvm::Constant::getNullValue(vector_type());
}

llvm::Value* VectorSupportLibrary::GetZeroScalar() {
  return llvm::Constant::getNullValue(scalar_type());
}

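// Backs the variable with an alloca in the function's entry block, where
// LLVM's mem2reg pass can promote it to an SSA register.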
LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* b) : b_(b) {
  alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", b_);
}

llvm::Value* LlvmVariable::Get() const { return b_->CreateLoad(alloca_); }

void LlvmVariable::Set(llvm::Value* new_value) {
  b_->CreateStore(new_value, alloca_);
}

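// A TileVariable holds a tile of vector values as one VectorVariable per
// vector in `initial_value`.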
TileVariable::TileVariable(VectorSupportLibrary* vector_support,
                           std::vector<llvm::Value*> initial_value) {
  for (llvm::Value* initial_vector_value : initial_value) {
    storage_.emplace_back(vector_support, initial_vector_value);
  }
}

std::vector<llvm::Value*> TileVariable::Get() const {
  std::vector<llvm::Value*> result;
  absl::c_transform(storage_, std::back_inserter(result),
                    [&](VectorVariable vect_var) { return vect_var.Get(); });
  return result;
}

void TileVariable::Set(absl::Span<llvm::Value* const> value) {
  CHECK_EQ(value.size(), storage_.size());
  for (int64 i = 0, e = value.size(); i < e; i++) {
    storage_[i].Set(value[i]);
  }
}

}  // namespace cpu
}  // namespace xla