• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "llvm/Support/raw_ostream.h"
20 #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
21 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
22 
23 namespace xla {
24 namespace cpu {
// Builds a vector-code emission helper for `primitive_type` with a fixed
// `vector_size` lane count.  `name` is used to label the LLVM IR values this
// library emits.
VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type,
                                           int64_t vector_size,
                                           llvm::IRBuilder<>* b,
                                           std::string name)
    : vector_size_(vector_size),
      primitive_type_(primitive_type),
      b_(b),
      name_(std::move(name)) {
  // Resolve the LLVM type for the XLA primitive using the module that owns
  // the builder's current insertion point.  Must happen before the derived
  // pointer/vector types below.
  scalar_type_ = llvm_ir::PrimitiveTypeToIrType(
      primitive_type, b_->GetInsertBlock()->getModule());
  scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_);
  // Fixed-width (non-scalable) vector of `vector_size` scalars.
  vector_type_ = llvm::VectorType::get(scalar_type_, vector_size, false);
  vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_);
}
39 
TypeToString(llvm::Type * type)40 static std::string TypeToString(llvm::Type* type) {
41   std::string o;
42   llvm::raw_string_ostream ostream(o);
43   type->print(ostream);
44   return ostream.str();
45 }
46 
AssertCorrectTypes(std::initializer_list<llvm::Value * > values)47 void VectorSupportLibrary::AssertCorrectTypes(
48     std::initializer_list<llvm::Value*> values) {
49   for (llvm::Value* v : values) {
50     llvm::Type* type = v->getType();
51     if (type != scalar_type() && type != vector_type()) {
52       LOG(FATAL) << "Expected either " << TypeToString(scalar_type()) << " or "
53                  << TypeToString(vector_type()) << " but got "
54                  << TypeToString(type);
55     }
56   }
57 }
58 
Mul(llvm::Value * lhs,llvm::Value * rhs)59 llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) {
60   AssertCorrectTypes({lhs, rhs});
61   return MulInternal(lhs, rhs);
62 }
63 
MulInternal(llvm::Value * lhs,llvm::Value * rhs)64 llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs,
65                                                llvm::Value* rhs) {
66   if (scalar_type_->isFloatingPointTy()) {
67     return b()->CreateFMul(lhs, rhs, name());
68   } else {
69     return b()->CreateMul(lhs, rhs, name());
70   }
71 }
72 
Add(llvm::Value * lhs,llvm::Value * rhs)73 llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) {
74   AssertCorrectTypes({lhs, rhs});
75   return AddInternal(lhs, rhs);
76 }
77 
Sub(llvm::Value * lhs,llvm::Value * rhs)78 llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) {
79   AssertCorrectTypes({lhs, rhs});
80   return b()->CreateFSub(lhs, rhs);
81 }
82 
Max(llvm::Value * lhs,llvm::Value * rhs,bool enable_fast_min_max)83 llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs,
84                                        bool enable_fast_min_max) {
85   AssertCorrectTypes({lhs, rhs});
86   if (scalar_type_->isFloatingPointTy()) {
87     return llvm_ir::EmitFloatMax(lhs, rhs, b_, enable_fast_min_max);
88   } else {
89     LOG(FATAL) << "Max for integers is unimplemented";
90   }
91 }
92 
Floor(llvm::Value * a)93 llvm::Value* VectorSupportLibrary::Floor(llvm::Value* a) {
94   AssertCorrectTypes({a});
95   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {a},
96                                       {a->getType()}, b());
97 }
98 
Div(llvm::Value * lhs,llvm::Value * rhs)99 llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) {
100   AssertCorrectTypes({lhs, rhs});
101   if (scalar_type_->isFloatingPointTy()) {
102     return b()->CreateFDiv(lhs, rhs, name());
103   } else {
104     LOG(FATAL) << "Division for integers is unimplemented";
105   }
106 }
107 
// Clamps `a` (scalar or vector, floating-point element type only) to the
// closed interval [low, high].
//
// The comparisons use the *unordered* predicates UGE/ULE, which evaluate to
// true when an operand is NaN — so a NaN lane in `a` passes both selects
// unchanged and propagates through rather than being snapped to a bound.
llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a,
                                         const llvm::APFloat& low,
                                         const llvm::APFloat& high) {
  // The bounds themselves must be ordered, non-NaN, and strictly increasing.
  CHECK(!low.isNaN());
  CHECK(!high.isNaN());
  CHECK(low.compare(high) == llvm::APFloat::cmpLessThan);

  AssertCorrectTypes({a});
  llvm::Type* type = a->getType();
  CHECK(scalar_type_->isFloatingPointTy());

  // Materialize the bounds with the same (scalar or vector) type as `a`.
  llvm::Value* low_value = GetConstantFloat(type, low);
  llvm::Value* high_value = GetConstantFloat(type, high);
  a = b_->CreateSelect(b_->CreateFCmpUGE(a, low_value), a, low_value);
  a = b_->CreateSelect(b_->CreateFCmpULE(a, high_value), a, high_value);
  return a;
}
125 
FCmpEQMask(llvm::Value * lhs,llvm::Value * rhs)126 llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs,
127                                               llvm::Value* rhs) {
128   AssertCorrectTypes({lhs, rhs});
129   return I1ToFloat(b()->CreateFCmpOEQ(lhs, rhs, name()));
130 }
131 
FCmpOLTMask(llvm::Value * lhs,llvm::Value * rhs)132 llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs,
133                                                llvm::Value* rhs) {
134   AssertCorrectTypes({lhs, rhs});
135   return I1ToFloat(b()->CreateFCmpOLT(lhs, rhs, name()));
136 }
137 
FCmpULEMask(llvm::Value * lhs,llvm::Value * rhs)138 llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs,
139                                                llvm::Value* rhs) {
140   AssertCorrectTypes({lhs, rhs});
141   return I1ToFloat(b()->CreateFCmpULE(lhs, rhs, name()));
142 }
143 
I1ToFloat(llvm::Value * i1)144 llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) {
145   bool is_vector = llvm::isa<llvm::VectorType>(i1->getType());
146   llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector);
147   return b()->CreateBitCast(b()->CreateSExt(i1, integer_type, name()),
148                             is_vector ? vector_type() : scalar_type(), name());
149 }
150 
IntegerTypeForFloatSize(bool vector)151 llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) {
152   CHECK(scalar_type()->isFloatingPointTy());
153   const llvm::DataLayout& data_layout =
154       b()->GetInsertBlock()->getModule()->getDataLayout();
155   int64_t float_size_bits = data_layout.getTypeSizeInBits(scalar_type());
156   llvm::Type* scalar_int_type = b()->getIntNTy(float_size_bits);
157   if (vector) {
158     return llvm::VectorType::get(scalar_int_type, vector_size(), false);
159   } else {
160     return scalar_int_type;
161   }
162 }
163 
BroadcastScalar(llvm::Value * x)164 llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) {
165   CHECK_EQ(x->getType(), scalar_type());
166   return b()->CreateVectorSplat(vector_size(), x, name());
167 }
168 
FloatAnd(llvm::Value * lhs,llvm::Value * rhs)169 llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs,
170                                             llvm::Value* rhs) {
171   AssertCorrectTypes({lhs, rhs});
172   llvm::Type* int_type =
173       IntegerTypeForFloatSize(lhs->getType() == vector_type());
174   return b()->CreateBitCast(
175       b()->CreateAnd(b()->CreateBitCast(lhs, int_type, name()),
176                      b()->CreateBitCast(rhs, int_type, name()), name()),
177       vector_type());
178 }
179 
FloatNot(llvm::Value * lhs)180 llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) {
181   AssertCorrectTypes({lhs});
182   llvm::Type* int_type =
183       IntegerTypeForFloatSize(lhs->getType() == vector_type());
184   return b()->CreateBitCast(
185       b()->CreateNot(b()->CreateBitCast(lhs, int_type, name()), name()),
186       vector_type());
187 }
188 
FloatOr(llvm::Value * lhs,llvm::Value * rhs)189 llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) {
190   AssertCorrectTypes({lhs, rhs});
191   llvm::Type* int_type =
192       IntegerTypeForFloatSize(lhs->getType() == vector_type());
193   return b()->CreateBitCast(
194       b()->CreateOr(b()->CreateBitCast(lhs, int_type, name()),
195                     b()->CreateBitCast(rhs, int_type, name()), name()),
196       vector_type(), name());
197 }
198 
AddInternal(llvm::Value * lhs,llvm::Value * rhs)199 llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs,
200                                                llvm::Value* rhs) {
201   if (scalar_type_->isFloatingPointTy()) {
202     return b()->CreateFAdd(lhs, rhs, name());
203   } else {
204     return b()->CreateAdd(lhs, rhs, name());
205   }
206 }
207 
ComputeOffsetPointer(llvm::Value * base_pointer,llvm::Value * offset_elements)208 llvm::Value* VectorSupportLibrary::ComputeOffsetPointer(
209     llvm::Value* base_pointer, llvm::Value* offset_elements) {
210   if (base_pointer->getType() != scalar_pointer_type()) {
211     base_pointer =
212         b()->CreateBitCast(base_pointer, scalar_pointer_type(), name());
213   }
214   return b()->CreateInBoundsGEP(scalar_type(), base_pointer, offset_elements,
215                                 name());
216 }
217 
LoadVector(llvm::Value * pointer)218 llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) {
219   if (pointer->getType() != vector_pointer_type()) {
220     pointer = b()->CreateBitCast(pointer, vector_pointer_type(), name());
221   }
222   return b()->CreateAlignedLoad(
223       vector_type(), pointer,
224       llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)), name());
225 }
226 
LoadScalar(llvm::Value * pointer)227 llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) {
228   if (pointer->getType() != scalar_pointer_type()) {
229     pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
230   }
231   return b()->CreateAlignedLoad(
232       scalar_type(), pointer,
233       llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)), name());
234 }
235 
StoreVector(llvm::Value * value,llvm::Value * pointer)236 void VectorSupportLibrary::StoreVector(llvm::Value* value,
237                                        llvm::Value* pointer) {
238   AssertCorrectTypes({value});
239   if (pointer->getType() != vector_pointer_type()) {
240     pointer = b()->CreateBitCast(pointer, vector_pointer_type());
241   }
242   b()->CreateAlignedStore(
243       value, pointer,
244       llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)));
245 }
246 
StoreScalar(llvm::Value * value,llvm::Value * pointer)247 void VectorSupportLibrary::StoreScalar(llvm::Value* value,
248                                        llvm::Value* pointer) {
249   AssertCorrectTypes({value});
250   if (pointer->getType() != scalar_pointer_type()) {
251     pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
252   }
253   b()->CreateAlignedStore(
254       value, pointer,
255       llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)));
256 }
257 
LoadBroadcast(llvm::Value * pointer)258 llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) {
259   if (pointer->getType() != scalar_pointer_type()) {
260     pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
261   }
262   return b()->CreateVectorSplat(
263       vector_size(), b()->CreateLoad(scalar_type(), pointer), name());
264 }
265 
AddReduce(llvm::Value * vector)266 llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) {
267   llvm::SmallVector<llvm::Constant*, 32> mask(vector_size(), nullptr);
268   for (unsigned i = vector_size(); i != 1; i >>= 1) {
269     // On every iteration, we shuffle half of the remaining lanes to the top
270     // half of shuffle, and add two old and the new vector.
271 
272     for (unsigned j = 0; j < vector_size(); ++j) {
273       if (j < (i / 2)) {
274         mask[j] = b()->getInt32(i / 2 + j);
275       } else {
276         mask[j] = llvm::UndefValue::get(b()->getInt32Ty());
277       }
278     }
279 
280     llvm::Value* half_remaining_lanes =
281         b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
282                                  llvm::ConstantVector::get(mask), "");
283     vector = Add(vector, half_remaining_lanes);
284   }
285 
286   return b()->CreateExtractElement(vector, b()->getInt32(0), name());
287 }
288 
// Pairwise ("horizontal") add of lhs and rhs in the AVX haddps lane order:
// result lane k is the sum of an adjacent pair of lanes drawn from lhs or
// rhs as laid out below.
llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs,
                                                         llvm::Value* rhs) {
  CHECK_EQ(lhs->getType(), vector_type());
  CHECK_EQ(rhs->getType(), vector_type());
  CHECK_EQ(vector_size() % 2, 0);

  llvm::SmallVector<llvm::Constant*, 32> mask_a, mask_b;

  // Adding the values shuffled using mask_a and mask_b gives us the
  // AVX-style horizontal add we want.  The masks work as documented
  // in https://llvm.org/docs/LangRef.html#shufflevector-instruction
  //
  // Here are the masks for vector_size() == 8:
  //
  //    index: |0 |1 |2 | 3 |4 |5 | 6 | 7
  //   --------+--+--+--+---+--+--+---+---
  //   mask_a: |0 |2 |8 |10 |4 |6 |12 |14
  //   mask_b: |1 |3 |9 |11 |5 |7 |13 |15
  //
  // So, as an example, the value at lane 3 of the result vector is
  // the result of adding lane 10 and lane 11 in the combined lhs++rhs
  // vector, which are the lanes 2 and 3 in the rhs vector.
  for (int i = 0; i < vector_size(); i += 2) {
    int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2);
    mask_a.push_back(b()->getInt32(increment + i));
    mask_b.push_back(b()->getInt32(increment + i + 1));
  }
  for (int i = 0; i < vector_size(); i += 2) {
    int increment = i < vector_size() / 2 ? (vector_size() / 2) : vector_size();
    mask_a.push_back(b()->getInt32(increment + i));
    mask_b.push_back(b()->getInt32(increment + i + 1));
  }

  // shuffle_0 gathers the even-indexed pair members, shuffle_1 the odd ones;
  // their sum is the horizontal add.
  llvm::Value* shuffle_0 =
      b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_a));
  llvm::Value* shuffle_1 =
      b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_b));

  return Add(shuffle_0, shuffle_1);
}
329 
ExtractLowHalf(llvm::Value * vector)330 llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) {
331   llvm::SmallVector<llvm::Constant*, 32> mask;
332   for (int i = 0; i < vector_size() / 2; i++) {
333     mask.push_back(b()->getInt32(i));
334   }
335 
336   return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
337                                   llvm::ConstantVector::get(mask));
338 }
339 
ExtractHighHalf(llvm::Value * vector)340 llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) {
341   llvm::SmallVector<llvm::Constant*, 32> mask;
342   for (int i = 0; i < vector_size() / 2; i++) {
343     mask.push_back(b()->getInt32(i + vector_size() / 2));
344   }
345 
346   return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
347                                   llvm::ConstantVector::get(mask));
348 }
349 
ComputeHorizontalSums(std::vector<llvm::Value * > vectors,llvm::Value * init_values)350 std::vector<llvm::Value*> VectorSupportLibrary::ComputeHorizontalSums(
351     std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
352   const int x86_avx_vector_elements =
353       TargetMachineFeatures::kX86AvxVectorByteSize / scalar_byte_size();
354   if (vector_size() == x86_avx_vector_elements &&
355       vectors.size() == x86_avx_vector_elements) {
356     return ComputeAvxOptimizedHorizontalSums(std::move(vectors), init_values);
357   }
358 
359   std::vector<llvm::Value*> result;
360   std::transform(vectors.begin(), vectors.end(), std::back_inserter(result),
361                  [this](llvm::Value* vector) { return AddReduce(vector); });
362   if (init_values) {
363     for (int64_t i = 0, e = result.size(); i < e; i++) {
364       result[i] = Add(result[i],
365                       b()->CreateExtractElement(init_values, b()->getInt32(i)));
366     }
367   }
368   return result;
369 }
370 
// AVX-optimized path of ComputeHorizontalSums, used when the number of input
// vectors equals the lane count.  Pairs of vectors are repeatedly combined
// with AvxStyleHorizontalAdd until two vectors remain, then the final lane
// sums are recovered with half-vector extracts.
std::vector<llvm::Value*>
VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums(
    std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
  // vectors are N llvm vector values, each with N elements.
  int64_t lane_width = vectors.size();

  // Each round halves the vector count; relies on the caller passing an
  // even (AVX-width) number of vectors.
  while (vectors.size() != 2) {
    std::vector<llvm::Value*> new_vectors;
    new_vectors.reserve(vectors.size() / 2);
    for (int i = 0; i < vectors.size(); i += 2) {
      new_vectors.push_back(AvxStyleHorizontalAdd(vectors[i], vectors[i + 1]));
    }

    vectors = std::move(new_vectors);
  }

  // Collapse each surviving vector by adding its two halves; fold the
  // matching half of `init_values` in if provided.
  llvm::Value* low =
      AddInternal(ExtractLowHalf(vectors[0]), ExtractHighHalf(vectors[0]));
  if (init_values) {
    low = AddInternal(ExtractLowHalf(init_values), low);
  }
  llvm::Value* high =
      AddInternal(ExtractLowHalf(vectors[1]), ExtractHighHalf(vectors[1]));
  if (init_values) {
    high = AddInternal(ExtractHighHalf(init_values), high);
  }

  // `low` has the first `lane_width / 2` horizontal reductions, and `high` has
  // the next `lane_width / 2` horizontal reductions.

  std::vector<llvm::Value*> results;
  for (int i = 0; i < lane_width; i++) {
    // Result i lives in lane i % (lane_width / 2) of `low` (first half of i)
    // or `high` (second half).
    llvm::Value* scalar_result =
        b()->CreateExtractElement(i < (lane_width / 2) ? low : high,
                                  b()->getInt32(i % (lane_width / 2)), name());
    results.push_back(scalar_result);
  }

  return results;
}
411 
GetZeroVector()412 llvm::Value* VectorSupportLibrary::GetZeroVector() {
413   return llvm::Constant::getNullValue(vector_type());
414 }
415 
GetZeroScalar()416 llvm::Value* VectorSupportLibrary::GetZeroScalar() {
417   return llvm::Constant::getNullValue(scalar_type());
418 }
419 
// A mutable IR "variable": backed by an unnamed alloca emitted at the
// enclosing function's entry (via the llvm_ir helper), read/written through
// Get()/Set().
LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* b) : b_(b) {
  alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", b_);
}
423 
Get() const424 llvm::Value* LlvmVariable::Get() const {
425   return b_->CreateLoad(alloca_->getAllocatedType(), alloca_);
426 }
427 
Set(llvm::Value * new_value)428 void LlvmVariable::Set(llvm::Value* new_value) {
429   b_->CreateStore(new_value, alloca_);
430 }
431 
TileVariable(VectorSupportLibrary * vector_support,std::vector<llvm::Value * > initial_value)432 TileVariable::TileVariable(VectorSupportLibrary* vector_support,
433                            std::vector<llvm::Value*> initial_value) {
434   for (llvm::Value* initial_vector_value : initial_value) {
435     storage_.emplace_back(vector_support, initial_vector_value);
436   }
437 }
438 
Get() const439 std::vector<llvm::Value*> TileVariable::Get() const {
440   std::vector<llvm::Value*> result;
441   absl::c_transform(storage_, std::back_inserter(result),
442                     [&](VectorVariable vect_var) { return vect_var.Get(); });
443   return result;
444 }
445 
Set(absl::Span<llvm::Value * const> value)446 void TileVariable::Set(absl::Span<llvm::Value* const> value) {
447   CHECK_EQ(value.size(), storage_.size());
448   for (int64_t i = 0, e = value.size(); i < e; i++) {
449     storage_[i].Set(value[i]);
450   }
451 }
452 
453 }  // namespace cpu
454 }  // namespace xla
455