• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
18 
19 #include <string>
20 
21 #include "absl/types/span.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/Value.h"
24 #include "tensorflow/compiler/xla/primitive_util.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 
28 namespace xla {
29 namespace cpu {
30 
31 // Simple wrappers around llvm::APFloat::APFloat to make the calling code more
32 // obvious.
33 
GetIeeeF32(float f)34 inline llvm::APFloat GetIeeeF32(float f) { return llvm::APFloat(f); }
GetIeeeF32FromBitwiseRep(int32 bitwise_value)35 inline llvm::APFloat GetIeeeF32FromBitwiseRep(int32 bitwise_value) {
36   return llvm::APFloat(llvm::APFloat::IEEEsingle(),
37                        llvm::APInt(/*numBits=*/32, /*val=*/bitwise_value));
38 }
39 
40 // A thin wrapper around llvm_util.h to make code generating vector math flow
41 // more readable.
42 class VectorSupportLibrary {
43  public:
44   // This VectorSupportLibrary instance remembers `primitive_type` and
45   // `vector_size`, and these are implicitly used by the methods on this
46   // instance (i.e. LoadVector will load a vector of type <`vector_size` x
47   // `primitive_type`>).
48   VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size,
49                        llvm::IRBuilder<>* b, std::string name);
50 
51   llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs);
Mul(int64 lhs,llvm::Value * rhs)52   llvm::Value* Mul(int64 lhs, llvm::Value* rhs) {
53     return Mul(b()->getInt64(lhs), rhs);
54   }
Mul(const llvm::APFloat & lhs,llvm::Value * rhs)55   llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) {
56     return Mul(GetConstantFloat(rhs->getType(), lhs), rhs);
57   }
58 
59   // If your call resolved to these then you probably wanted the versions taking
60   // APFloat.
61   llvm::Value* Mul(double lhs, llvm::Value* rhs) = delete;
62   llvm::Value* Mul(float lhs, llvm::Value* rhs) = delete;
63 
64   llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
Add(int64 lhs,llvm::Value * rhs)65   llvm::Value* Add(int64 lhs, llvm::Value* rhs) {
66     return Add(b()->getInt64(lhs), rhs);
67   }
Add(const llvm::APFloat & lhs,llvm::Value * rhs)68   llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) {
69     return Add(GetConstantFloat(rhs->getType(), lhs), rhs);
70   }
71 
72   // If your call resolved to these then you probably wanted the versions taking
73   // APFloat.
74   llvm::Value* Add(double lhs, llvm::Value* rhs) = delete;
75   llvm::Value* Add(float lhs, llvm::Value* rhs) = delete;
76 
77   llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs);
Sub(llvm::Value * lhs,const llvm::APFloat & rhs)78   llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) {
79     return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
80   }
81   llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs);
Max(const llvm::APFloat & lhs,llvm::Value * rhs)82   llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs) {
83     return Max(GetConstantFloat(rhs->getType(), lhs), rhs);
84   }
85   llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);
86 
MulAdd(llvm::Value * a,llvm::Value * b,llvm::Value * c)87   llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) {
88     return Add(c, Mul(a, b));
89   }
90 
MulAdd(llvm::Value * a,llvm::Value * b,const llvm::APFloat & c)91   llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, const llvm::APFloat& c) {
92     return Add(GetConstantFloat(vector_type(), c), Mul(a, b));
93   }
94 
MulAdd(llvm::Value * a,const llvm::APFloat & b,const llvm::APFloat & c)95   llvm::Value* MulAdd(llvm::Value* a, const llvm::APFloat& b,
96                       const llvm::APFloat& c) {
97     return Add(GetConstantFloat(a->getType(), c),
98                Mul(a, GetConstantFloat(a->getType(), b)));
99   }
100 
101   llvm::Value* Floor(llvm::Value* a);
102 
103   llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low,
104                      const llvm::APFloat& high);
SplatFloat(const llvm::APFloat & d)105   llvm::Value* SplatFloat(const llvm::APFloat& d) {
106     return GetConstantFloat(vector_type(), d);
107   }
108 
109   // These compare instructions return a floating point typed mask instead of an
110   // i1.  For instance, on a vector typed input, lanes where the predicate is
111   // true get a float with all ones and other lanes get a float with all zeros.
112   // This is slightly odd from the perspective of LLVM's type system, but it
113   // makes kernel IR generation code written using VectorSupportLibrary (its
114   // raison d'etre) less cluttered.
115 
116   llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs);
FCmpEQMask(llvm::Value * lhs,const llvm::APFloat & rhs)117   llvm::Value* FCmpEQMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
118     return FCmpEQMask(lhs, GetConstantFloat(lhs->getType(), rhs));
119   }
120   llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs);
121   llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs);
FCmpOLTMask(llvm::Value * lhs,const llvm::APFloat & rhs)122   llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
123     return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs));
124   }
125 
126   // These boolean operations operate on the bitwise values of the floating
127   // point inputs.  They return a (vector of) float(s) but like in the mask
128   // generating predicates above this type system oddity makes the kernel IR
129   // generation code less cluttered.
130   llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs);
FloatAnd(llvm::Value * lhs,const llvm::APFloat & rhs)131   llvm::Value* FloatAnd(llvm::Value* lhs, const llvm::APFloat& rhs) {
132     return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs));
133   }
134   llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs);
FloatOr(llvm::Value * lhs,const llvm::APFloat & rhs)135   llvm::Value* FloatOr(llvm::Value* lhs, const llvm::APFloat& rhs) {
136     return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs));
137   }
138   llvm::Value* FloatNot(llvm::Value* lhs);
FloatAndNot(llvm::Value * lhs,llvm::Value * rhs)139   llvm::Value* FloatAndNot(llvm::Value* lhs, llvm::Value* rhs) {
140     return FloatAnd(FloatNot(lhs), rhs);
141   }
142 
143   llvm::Value* BroadcastScalar(llvm::Value* x);
BroadcastScalar(const llvm::APFloat & d)144   llvm::Value* BroadcastScalar(const llvm::APFloat& d) {
145     return BroadcastScalar(GetConstantFloat(scalar_type(), d));
146   }
147 
148   llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
149                                     llvm::Value* offset_elements);
ComputeOffsetPointer(llvm::Value * base_pointer,llvm::Value * offset_elements,int64 scale)150   llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
151                                     llvm::Value* offset_elements, int64 scale) {
152     return ComputeOffsetPointer(
153         base_pointer, b_->CreateMul(b_->getInt64(scale), offset_elements));
154   }
ComputeOffsetPointer(llvm::Value * base_pointer,int64 offset_elements)155   llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
156                                     int64 offset_elements) {
157     return ComputeOffsetPointer(base_pointer, b()->getInt64(offset_elements));
158   }
159 
160   llvm::Value* LoadVector(llvm::Value* pointer);
161 
LoadVector(llvm::Value * base_pointer,llvm::Value * offset_elements)162   llvm::Value* LoadVector(llvm::Value* base_pointer,
163                           llvm::Value* offset_elements) {
164     return LoadVector(ComputeOffsetPointer(base_pointer, offset_elements));
165   }
166 
LoadVector(llvm::Value * base_pointer,int64 offset_elements)167   llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) {
168     return LoadVector(base_pointer, b()->getInt64(offset_elements));
169   }
170 
171   llvm::Value* LoadScalar(llvm::Value* pointer);
172 
LoadScalar(llvm::Value * base_pointer,llvm::Value * offset_elements)173   llvm::Value* LoadScalar(llvm::Value* base_pointer,
174                           llvm::Value* offset_elements) {
175     return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements));
176   }
177 
LoadScalar(llvm::Value * base_pointer,int64 offset_elements)178   llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) {
179     return LoadScalar(base_pointer, b()->getInt64(offset_elements));
180   }
181 
182   void StoreVector(llvm::Value* value, llvm::Value* pointer);
183 
StoreVector(llvm::Value * value,llvm::Value * base_pointer,llvm::Value * offset_elements)184   void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
185                    llvm::Value* offset_elements) {
186     StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements));
187   }
188 
StoreVector(llvm::Value * value,llvm::Value * base_pointer,int64 offset_elements)189   void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
190                    int64 offset_elements) {
191     StoreVector(value, base_pointer, b()->getInt64(offset_elements));
192   }
193 
194   void StoreScalar(llvm::Value* value, llvm::Value* pointer);
StoreScalar(llvm::Value * value,llvm::Value * base_pointer,llvm::Value * offset_elements)195   void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
196                    llvm::Value* offset_elements) {
197     StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements));
198   }
199 
StoreScalar(llvm::Value * value,llvm::Value * base_pointer,int64 offset_elements)200   void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
201                    int64 offset_elements) {
202     StoreScalar(base_pointer, b()->getInt64(offset_elements));
203   }
204 
205   llvm::Value* LoadBroadcast(llvm::Value* pointer);
LoadBroadcast(llvm::Value * base_pointer,llvm::Value * offset_elements)206   llvm::Value* LoadBroadcast(llvm::Value* base_pointer,
207                              llvm::Value* offset_elements) {
208     return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements));
209   }
LoadBroadcast(llvm::Value * base_pointer,int64 offset_elements)210   llvm::Value* LoadBroadcast(llvm::Value* base_pointer, int64 offset_elements) {
211     return LoadBroadcast(base_pointer, b()->getInt64(offset_elements));
212   }
213 
214   // Compute the horizontal sum of each vector in `vectors`.  The i'th element
215   // in the result vector is the (scalar) horizontal sum of the i'th vector in
216   // `vectors`.  If `init_values` is not nullptr then the value in the i'th lane
217   // in `init_values` is added to the i'th horizontal sum.
218   std::vector<llvm::Value*> ComputeHorizontalSums(
219       std::vector<llvm::Value*> vectors, llvm::Value* init_values = nullptr);
220 
221   llvm::Value* GetZeroVector();
222   llvm::Value* GetZeroScalar();
223 
b()224   llvm::IRBuilder<>* b() const { return b_; }
vector_size()225   int64 vector_size() const { return vector_size_; }
vector_type()226   llvm::Type* vector_type() const { return vector_type_; }
vector_pointer_type()227   llvm::Type* vector_pointer_type() const { return vector_pointer_type_; }
scalar_type()228   llvm::Type* scalar_type() const { return scalar_type_; }
scalar_pointer_type()229   llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; }
scalar_byte_size()230   int64 scalar_byte_size() const {
231     return primitive_util::BitWidth(primitive_type_) / 8;
232   }
233 
name()234   const std::string& name() const { return name_; }
235 
236  private:
237   llvm::Value* ExtractLowHalf(llvm::Value*);
238   llvm::Value* ExtractHighHalf(llvm::Value*);
239 
240   llvm::Value* MulInternal(llvm::Value* lhs, llvm::Value* rhs);
241   llvm::Value* AddInternal(llvm::Value* lhs, llvm::Value* rhs);
242 
243   llvm::Value* AddReduce(llvm::Value* vector);
244 
245   // Checks that each value in `values` is either of type scalar_type() or
246   // vector_type().  This LOG(FATAL)'s so it should only be called in cases
247   // where a mismatching type is a programmer bug.
248   void AssertCorrectTypes(std::initializer_list<llvm::Value*> values);
249 
250   // Perform an X86 AVX style horizontal add between `lhs` and `rhs`.  The
251   // resulting IR for an 8-float wide vector is expected to lower to a single
252   // vhaddps instruction on a CPU that supports vhaddps, and not be too bad in
253   // other cases.
254   //
255   // For a vector width of 8, the result vector is computed as:
256   //   Result[0] = Lhs[0] + Lhs[1]
257   //   Result[1] = Lhs[2] + Lhs[3]
258   //   Result[2] = Rhs[0] + Rhs[1]
259   //   Result[3] = Rhs[2] + Rhs[3]
260   //   Result[4] = Lhs[4] + Lhs[5]
261   //   Result[5] = Lhs[6] + Lhs[7]
262   //   Result[6] = Rhs[4] + Rhs[5]
263   //   Result[7] = Rhs[6] + Rhs[7]
264   llvm::Value* AvxStyleHorizontalAdd(llvm::Value* lhs, llvm::Value* rhs);
265 
266   std::vector<llvm::Value*> ComputeAvxOptimizedHorizontalSums(
267       std::vector<llvm::Value*> vectors, llvm::Value* init_values);
268 
269   llvm::Type* IntegerTypeForFloatSize(bool vector);
270   llvm::Value* I1ToFloat(llvm::Value* i1);
GetConstantFloat(llvm::Type * type,const llvm::APFloat & f)271   llvm::Value* GetConstantFloat(llvm::Type* type, const llvm::APFloat& f) {
272     llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f);
273     if (llvm::isa<llvm::VectorType>(type)) {
274       return llvm::ConstantVector::getSplat(vector_size(), scalar_value);
275     }
276     return scalar_value;
277   }
278 
279   int64 vector_size_;
280   PrimitiveType primitive_type_;
281   llvm::IRBuilder<>* b_;
282   llvm::Type* vector_type_;
283   llvm::Type* vector_pointer_type_;
284   llvm::Type* scalar_type_;
285   llvm::Type* scalar_pointer_type_;
286   std::string name_;
287 };
288 
289 // This wraps an alloca-backed stack variable which LLVM's SSA construction pass
290 // can later convert to a SSA value.
291 class LlvmVariable {
292  public:
293   LlvmVariable(llvm::Type*, llvm::IRBuilder<>* b);
294 
295   llvm::Value* Get() const;
296   void Set(llvm::Value* new_value);
297 
298  private:
299   llvm::AllocaInst* alloca_;
300   llvm::IRBuilder<>* b_;
301 };
302 
303 class VectorVariable : public LlvmVariable {
304  public:
VectorVariable(VectorSupportLibrary * vector_support,llvm::Value * initial_value)305   VectorVariable(VectorSupportLibrary* vector_support,
306                  llvm::Value* initial_value)
307       : LlvmVariable(vector_support->vector_type(), vector_support->b()) {
308     Set(initial_value);
309   }
310 };
311 
312 class ScalarVariable : public LlvmVariable {
313  public:
ScalarVariable(VectorSupportLibrary * vector_support,llvm::Value * initial_value)314   ScalarVariable(VectorSupportLibrary* vector_support,
315                  llvm::Value* initial_value)
316       : LlvmVariable(vector_support->scalar_type(), vector_support->b()) {
317     Set(initial_value);
318   }
319 };
320 
321 // This wraps a set of alloca-backed stack variables that can, as a whole, store
322 // a tile.  A "tile" is a sequence of vectors that is typically used as a 2D
323 // grid of scalar values (e.g. for tiled GEMMs).
324 class TileVariable {
325  public:
326   TileVariable(VectorSupportLibrary* vector_support,
327                std::vector<llvm::Value*> initial_value);
328 
329   std::vector<llvm::Value*> Get() const;
330   void Set(absl::Span<llvm::Value* const> value);
331 
332  private:
333   std::vector<VectorVariable> storage_;
334 };
335 }  // namespace cpu
336 }  // namespace xla
337 
338 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
339