• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "llvm/Support/raw_ostream.h"
20 #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
21 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
22 
23 namespace xla {
24 namespace cpu {
VectorSupportLibrary(PrimitiveType primitive_type,int64 vector_size,llvm::IRBuilder<> * b,std::string name)25 VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type,
26                                            int64 vector_size,
27                                            llvm::IRBuilder<>* b,
28                                            std::string name)
29     : vector_size_(vector_size),
30       primitive_type_(primitive_type),
31       b_(b),
32       name_(std::move(name)) {
33   scalar_type_ = llvm_ir::PrimitiveTypeToIrType(
34       primitive_type, b_->GetInsertBlock()->getModule());
35   scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_);
36   vector_type_ = llvm::VectorType::get(scalar_type_, vector_size);
37   vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_);
38 }
39 
TypeToString(llvm::Type * type)40 static string TypeToString(llvm::Type* type) {
41   std::string o;
42   llvm::raw_string_ostream ostream(o);
43   type->print(ostream);
44   return ostream.str();
45 }
46 
AssertCorrectTypes(std::initializer_list<llvm::Value * > values)47 void VectorSupportLibrary::AssertCorrectTypes(
48     std::initializer_list<llvm::Value*> values) {
49   for (llvm::Value* v : values) {
50     llvm::Type* type = v->getType();
51     if (type != scalar_type() && type != vector_type()) {
52       LOG(FATAL) << "Expected either " << TypeToString(scalar_type()) << " or "
53                  << TypeToString(vector_type()) << " but got "
54                  << TypeToString(type);
55     }
56   }
57 }
58 
Mul(llvm::Value * lhs,llvm::Value * rhs)59 llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) {
60   AssertCorrectTypes({lhs, rhs});
61   return MulInternal(lhs, rhs);
62 }
63 
MulInternal(llvm::Value * lhs,llvm::Value * rhs)64 llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs,
65                                                llvm::Value* rhs) {
66   if (scalar_type_->isFloatingPointTy()) {
67     return b()->CreateFMul(lhs, rhs, name());
68   } else {
69     return b()->CreateMul(lhs, rhs, name());
70   }
71 }
72 
Add(llvm::Value * lhs,llvm::Value * rhs)73 llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) {
74   AssertCorrectTypes({lhs, rhs});
75   return AddInternal(lhs, rhs);
76 }
77 
Sub(llvm::Value * lhs,llvm::Value * rhs)78 llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) {
79   AssertCorrectTypes({lhs, rhs});
80   return b()->CreateFSub(lhs, rhs);
81 }
82 
Max(llvm::Value * lhs,llvm::Value * rhs)83 llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) {
84   AssertCorrectTypes({lhs, rhs});
85   if (scalar_type_->isFloatingPointTy()) {
86     return llvm_ir::EmitFloatMax(lhs, rhs, b_);
87   } else {
88     LOG(FATAL) << "Max for integers is unimplemented";
89   }
90 }
91 
Floor(llvm::Value * a)92 llvm::Value* VectorSupportLibrary::Floor(llvm::Value* a) {
93   AssertCorrectTypes({a});
94   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {a},
95                                       {a->getType()}, b());
96 }
97 
Div(llvm::Value * lhs,llvm::Value * rhs)98 llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) {
99   AssertCorrectTypes({lhs, rhs});
100   if (scalar_type_->isFloatingPointTy()) {
101     return b()->CreateFDiv(lhs, rhs, name());
102   } else {
103     LOG(FATAL) << "Division for integers is unimplemented";
104   }
105 }
106 
Clamp(llvm::Value * a,const llvm::APFloat & low,const llvm::APFloat & high)107 llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a,
108                                          const llvm::APFloat& low,
109                                          const llvm::APFloat& high) {
110   AssertCorrectTypes({a});
111   llvm::Type* type = a->getType();
112   CHECK(low.compare(high) == llvm::APFloat::cmpLessThan);
113   CHECK(scalar_type_->isFloatingPointTy());
114   return llvm_ir::EmitFloatMin(
115       llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), b_),
116       GetConstantFloat(type, high), b_);
117 }
118 
FCmpEQMask(llvm::Value * lhs,llvm::Value * rhs)119 llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs,
120                                               llvm::Value* rhs) {
121   AssertCorrectTypes({lhs, rhs});
122   return I1ToFloat(b()->CreateFCmpOEQ(lhs, rhs, name()));
123 }
124 
FCmpOLTMask(llvm::Value * lhs,llvm::Value * rhs)125 llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs,
126                                                llvm::Value* rhs) {
127   AssertCorrectTypes({lhs, rhs});
128   return I1ToFloat(b()->CreateFCmpOLT(lhs, rhs, name()));
129 }
130 
FCmpULEMask(llvm::Value * lhs,llvm::Value * rhs)131 llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs,
132                                                llvm::Value* rhs) {
133   AssertCorrectTypes({lhs, rhs});
134   return I1ToFloat(b()->CreateFCmpULE(lhs, rhs, name()));
135 }
136 
I1ToFloat(llvm::Value * i1)137 llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) {
138   bool is_vector = llvm::isa<llvm::VectorType>(i1->getType());
139   llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector);
140   return b()->CreateBitCast(b()->CreateSExt(i1, integer_type, name()),
141                             is_vector ? vector_type() : scalar_type(), name());
142 }
143 
IntegerTypeForFloatSize(bool vector)144 llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) {
145   CHECK(scalar_type()->isFloatingPointTy());
146   const llvm::DataLayout& data_layout =
147       b()->GetInsertBlock()->getModule()->getDataLayout();
148   int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type());
149   llvm::Type* scalar_int_type = b()->getIntNTy(float_size_bits);
150   if (vector) {
151     return llvm::VectorType::get(scalar_int_type, vector_size());
152   } else {
153     return scalar_int_type;
154   }
155 }
156 
BroadcastScalar(llvm::Value * x)157 llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) {
158   CHECK_EQ(x->getType(), scalar_type());
159   return b()->CreateVectorSplat(vector_size(), x, name());
160 }
161 
FloatAnd(llvm::Value * lhs,llvm::Value * rhs)162 llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs,
163                                             llvm::Value* rhs) {
164   AssertCorrectTypes({lhs, rhs});
165   llvm::Type* int_type =
166       IntegerTypeForFloatSize(lhs->getType() == vector_type());
167   return b()->CreateBitCast(
168       b()->CreateAnd(b()->CreateBitCast(lhs, int_type, name()),
169                      b()->CreateBitCast(rhs, int_type, name()), name()),
170       vector_type());
171 }
172 
FloatNot(llvm::Value * lhs)173 llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) {
174   AssertCorrectTypes({lhs});
175   llvm::Type* int_type =
176       IntegerTypeForFloatSize(lhs->getType() == vector_type());
177   return b()->CreateBitCast(
178       b()->CreateNot(b()->CreateBitCast(lhs, int_type, name()), name()),
179       vector_type());
180 }
181 
FloatOr(llvm::Value * lhs,llvm::Value * rhs)182 llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) {
183   AssertCorrectTypes({lhs, rhs});
184   llvm::Type* int_type =
185       IntegerTypeForFloatSize(lhs->getType() == vector_type());
186   return b()->CreateBitCast(
187       b()->CreateOr(b()->CreateBitCast(lhs, int_type, name()),
188                     b()->CreateBitCast(rhs, int_type, name()), name()),
189       vector_type(), name());
190 }
191 
AddInternal(llvm::Value * lhs,llvm::Value * rhs)192 llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs,
193                                                llvm::Value* rhs) {
194   if (scalar_type_->isFloatingPointTy()) {
195     return b()->CreateFAdd(lhs, rhs, name());
196   } else {
197     return b()->CreateAdd(lhs, rhs, name());
198   }
199 }
200 
ComputeOffsetPointer(llvm::Value * base_pointer,llvm::Value * offset_elements)201 llvm::Value* VectorSupportLibrary::ComputeOffsetPointer(
202     llvm::Value* base_pointer, llvm::Value* offset_elements) {
203   if (base_pointer->getType() != scalar_pointer_type()) {
204     base_pointer =
205         b()->CreateBitCast(base_pointer, scalar_pointer_type(), name());
206   }
207   return b()->CreateInBoundsGEP(base_pointer, {offset_elements}, name());
208 }
209 
LoadVector(llvm::Value * pointer)210 llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) {
211   if (pointer->getType() != vector_pointer_type()) {
212     pointer = b()->CreateBitCast(pointer, vector_pointer_type(), name());
213   }
214   return b()->CreateAlignedLoad(
215       pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
216 }
217 
LoadScalar(llvm::Value * pointer)218 llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) {
219   if (pointer->getType() != scalar_pointer_type()) {
220     pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
221   }
222   return b()->CreateAlignedLoad(
223       pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
224 }
225 
StoreVector(llvm::Value * value,llvm::Value * pointer)226 void VectorSupportLibrary::StoreVector(llvm::Value* value,
227                                        llvm::Value* pointer) {
228   AssertCorrectTypes({value});
229   if (pointer->getType() != vector_pointer_type()) {
230     pointer = b()->CreateBitCast(pointer, vector_pointer_type());
231   }
232   b()->CreateAlignedStore(value, pointer,
233                           ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
234 }
235 
StoreScalar(llvm::Value * value,llvm::Value * pointer)236 void VectorSupportLibrary::StoreScalar(llvm::Value* value,
237                                        llvm::Value* pointer) {
238   AssertCorrectTypes({value});
239   if (pointer->getType() != scalar_pointer_type()) {
240     pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
241   }
242   b()->CreateAlignedStore(value, pointer,
243                           ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
244 }
245 
LoadBroadcast(llvm::Value * pointer)246 llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) {
247   if (pointer->getType() != scalar_pointer_type()) {
248     pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name());
249   }
250   return b()->CreateVectorSplat(vector_size(), b()->CreateLoad(pointer),
251                                 name());
252 }
253 
AddReduce(llvm::Value * vector)254 llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) {
255   llvm::SmallVector<llvm::Constant*, 32> mask(vector_size(), nullptr);
256   for (unsigned i = vector_size(); i != 1; i >>= 1) {
257     // On every iteration, we shuffle half of the remaining lanes to the top
258     // half of shuffle, and add two old and the new vector.
259 
260     for (unsigned j = 0; j < vector_size(); ++j) {
261       if (j < (i / 2)) {
262         mask[j] = b()->getInt32(i / 2 + j);
263       } else {
264         mask[j] = llvm::UndefValue::get(b()->getInt32Ty());
265       }
266     }
267 
268     llvm::Value* half_remaining_lanes =
269         b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
270                                  llvm::ConstantVector::get(mask), "");
271     vector = Add(vector, half_remaining_lanes);
272   }
273 
274   return b()->CreateExtractElement(vector, b()->getInt32(0), name());
275 }
276 
AvxStyleHorizontalAdd(llvm::Value * lhs,llvm::Value * rhs)277 llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs,
278                                                          llvm::Value* rhs) {
279   CHECK_EQ(lhs->getType(), vector_type());
280   CHECK_EQ(rhs->getType(), vector_type());
281   CHECK_EQ(vector_size() % 2, 0);
282 
283   llvm::SmallVector<llvm::Constant*, 32> mask_a, mask_b;
284 
285   // Adding the values shuffled using mask_a and mask_b gives us the
286   // AVX-style horizontal add we want.  The masks work as documented
287   // in https://llvm.org/docs/LangRef.html#shufflevector-instruction
288   //
289   // Here are the masks for vector_width() == 8:
290   //
291   //    index: |0 |1 |2 | 3 |4 |5 | 6 | 7
292   //   --------+--+--+--+---+--+--+---+---
293   //   mask_a: |0 |2 |8 |10 |4 |6 |12 |14
294   //   mask_b: |1 |3 |9 |11 |5 |7 |13 |16
295   //
296   // So, as an example, the value at lane 3 of the result vector is
297   // the result of adding lane 10 and lane 11 in the combined lhs++rhs
298   // vector, which are the lanes 2 and 3 in the rhs vector.
299   for (int i = 0; i < vector_size(); i += 2) {
300     int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2);
301     mask_a.push_back(b()->getInt32(increment + i));
302     mask_b.push_back(b()->getInt32(increment + i + 1));
303   }
304   for (int i = 0; i < vector_size(); i += 2) {
305     int increment = i < vector_size() / 2 ? (vector_size() / 2) : vector_size();
306     mask_a.push_back(b()->getInt32(increment + i));
307     mask_b.push_back(b()->getInt32(increment + i + 1));
308   }
309 
310   llvm::Value* shuffle_0 =
311       b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_a));
312   llvm::Value* shuffle_1 =
313       b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_b));
314 
315   return Add(shuffle_0, shuffle_1);
316 }
317 
ExtractLowHalf(llvm::Value * vector)318 llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) {
319   llvm::SmallVector<llvm::Constant*, 32> mask;
320   for (int i = 0; i < vector_size() / 2; i++) {
321     mask.push_back(b()->getInt32(i));
322   }
323 
324   return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
325                                   llvm::ConstantVector::get(mask));
326 }
327 
ExtractHighHalf(llvm::Value * vector)328 llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) {
329   llvm::SmallVector<llvm::Constant*, 32> mask;
330   for (int i = 0; i < vector_size() / 2; i++) {
331     mask.push_back(b()->getInt32(i + vector_size() / 2));
332   }
333 
334   return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()),
335                                   llvm::ConstantVector::get(mask));
336 }
337 
ComputeHorizontalSums(std::vector<llvm::Value * > vectors,llvm::Value * init_values)338 std::vector<llvm::Value*> VectorSupportLibrary::ComputeHorizontalSums(
339     std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
340   const int x86_avx_vector_elements =
341       TargetMachineFeatures::kX86AvxVectorByteSize / scalar_byte_size();
342   if (vector_size() == x86_avx_vector_elements &&
343       vectors.size() == x86_avx_vector_elements) {
344     return ComputeAvxOptimizedHorizontalSums(std::move(vectors), init_values);
345   }
346 
347   std::vector<llvm::Value*> result;
348   std::transform(vectors.begin(), vectors.end(), std::back_inserter(result),
349                  [this](llvm::Value* vector) { return AddReduce(vector); });
350   if (init_values) {
351     for (int64 i = 0, e = result.size(); i < e; i++) {
352       result[i] = Add(result[i],
353                       b()->CreateExtractElement(init_values, b()->getInt32(i)));
354     }
355   }
356   return result;
357 }
358 
359 std::vector<llvm::Value*>
ComputeAvxOptimizedHorizontalSums(std::vector<llvm::Value * > vectors,llvm::Value * init_values)360 VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums(
361     std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
362   // vectors are N llvm vector values, each with N elements.
363   int64 lane_width = vectors.size();
364 
365   while (vectors.size() != 2) {
366     std::vector<llvm::Value*> new_vectors;
367     for (int i = 0; i < vectors.size(); i += 2) {
368       new_vectors.push_back(AvxStyleHorizontalAdd(vectors[i], vectors[i + 1]));
369     }
370 
371     vectors = std::move(new_vectors);
372   }
373 
374   llvm::Value* low =
375       AddInternal(ExtractLowHalf(vectors[0]), ExtractHighHalf(vectors[0]));
376   if (init_values) {
377     low = AddInternal(ExtractLowHalf(init_values), low);
378   }
379   llvm::Value* high =
380       AddInternal(ExtractLowHalf(vectors[1]), ExtractHighHalf(vectors[1]));
381   if (init_values) {
382     high = AddInternal(ExtractHighHalf(init_values), high);
383   }
384 
385   // `low` has the first `lane_width / 2` horizontal reductions, and `high` has
386   // the next `lane_width / 2` horizontal reductions.
387 
388   std::vector<llvm::Value*> results;
389   for (int i = 0; i < lane_width; i++) {
390     llvm::Value* scalar_result =
391         b()->CreateExtractElement(i < (lane_width / 2) ? low : high,
392                                   b()->getInt32(i % (lane_width / 2)), name());
393     results.push_back(scalar_result);
394   }
395 
396   return results;
397 }
398 
GetZeroVector()399 llvm::Value* VectorSupportLibrary::GetZeroVector() {
400   return llvm::Constant::getNullValue(vector_type());
401 }
402 
GetZeroScalar()403 llvm::Value* VectorSupportLibrary::GetZeroScalar() {
404   return llvm::Constant::getNullValue(scalar_type());
405 }
406 
LlvmVariable(llvm::Type * type,llvm::IRBuilder<> * b)407 LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* b) : b_(b) {
408   alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", b_);
409 }
410 
Get() const411 llvm::Value* LlvmVariable::Get() const { return b_->CreateLoad(alloca_); }
412 
Set(llvm::Value * new_value)413 void LlvmVariable::Set(llvm::Value* new_value) {
414   b_->CreateStore(new_value, alloca_);
415 }
416 
TileVariable(VectorSupportLibrary * vector_support,std::vector<llvm::Value * > initial_value)417 TileVariable::TileVariable(VectorSupportLibrary* vector_support,
418                            std::vector<llvm::Value*> initial_value) {
419   for (llvm::Value* initial_vector_value : initial_value) {
420     storage_.emplace_back(vector_support, initial_vector_value);
421   }
422 }
423 
Get() const424 std::vector<llvm::Value*> TileVariable::Get() const {
425   std::vector<llvm::Value*> result;
426   absl::c_transform(storage_, std::back_inserter(result),
427                     [&](VectorVariable vect_var) { return vect_var.Get(); });
428   return result;
429 }
430 
Set(absl::Span<llvm::Value * const> value)431 void TileVariable::Set(absl::Span<llvm::Value* const> value) {
432   CHECK_EQ(value.size(), storage_.size());
433   for (int64 i = 0, e = value.size(); i < e; i++) {
434     storage_[i].Set(value[i]);
435   }
436 }
437 
438 }  // namespace cpu
439 }  // namespace xla
440