/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_

#include <string>

#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace cpu {

// Simple wrappers around llvm::APFloat::APFloat to make the calling code more
// obvious.

inline llvm::APFloat GetIeeeF32(float f) { return llvm::APFloat(f); }
inline llvm::APFloat GetIeeeF32FromBitwiseRep(int32 bitwise_value) {
  return llvm::APFloat(llvm::APFloat::IEEEsingle(),
                       llvm::APInt(/*numBits=*/32, /*val=*/bitwise_value));
}
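// For example, GetIeeeF32FromBitwiseRep(0x3f800000) yields 1.0f, since
// 0x3f800000 is the IEEE-754 single-precision bit pattern for 1.0.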

// A thin wrapper around llvm_util.h to make code that generates vector math
// more readable.
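//
// For instance, a minimal sketch of emitting `out[i] = a[i] * 2.0f + 0.5f`
// over one vector of floats (here `builder`, `a_ptr` and `out_ptr` are
// hypothetical values owned by the caller):
//
//   VectorSupportLibrary vsl(F32, /*vector_size=*/8, builder, "example");
//   llvm::Value* a = vsl.LoadVector(a_ptr);
//   llvm::Value* out = vsl.MulAdd(a, GetIeeeF32(2.0f), GetIeeeF32(0.5f));
//   vsl.StoreVector(out, out_ptr);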
class VectorSupportLibrary {
 public:
  // This VectorSupportLibrary instance remembers `primitive_type` and
  // `vector_size`, and these are implicitly used by the methods on this
  // instance (i.e. LoadVector will load a vector of type <`vector_size` x
  // `primitive_type`>).
  VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size,
                       llvm::IRBuilder<>* b, std::string name);

  llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Mul(int64 lhs, llvm::Value* rhs) {
    return Mul(b()->getInt64(lhs), rhs);
  }
  llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) {
    return Mul(GetConstantFloat(rhs->getType(), lhs), rhs);
  }

  // If your call resolved to these then you probably wanted the versions
  // taking APFloat.
  llvm::Value* Mul(double lhs, llvm::Value* rhs) = delete;
  llvm::Value* Mul(float lhs, llvm::Value* rhs) = delete;

  llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Add(int64 lhs, llvm::Value* rhs) {
    return Add(b()->getInt64(lhs), rhs);
  }
  llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) {
    return Add(GetConstantFloat(rhs->getType(), lhs), rhs);
  }

  // If your call resolved to these then you probably wanted the versions
  // taking APFloat.
  llvm::Value* Add(double lhs, llvm::Value* rhs) = delete;
  llvm::Value* Add(float lhs, llvm::Value* rhs) = delete;

  llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs) {
    return Max(GetConstantFloat(rhs->getType(), lhs), rhs);
  }
  llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);

  llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) {
    return Add(c, Mul(a, b));
  }

  llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, const llvm::APFloat& c) {
    return Add(GetConstantFloat(vector_type(), c), Mul(a, b));
  }

  llvm::Value* MulAdd(llvm::Value* a, const llvm::APFloat& b,
                      const llvm::APFloat& c) {
    return Add(GetConstantFloat(a->getType(), c),
               Mul(a, GetConstantFloat(a->getType(), b)));
  }

  llvm::Value* Floor(llvm::Value* a);

  llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low,
                     const llvm::APFloat& high);
  llvm::Value* SplatFloat(const llvm::APFloat& d) {
    return GetConstantFloat(vector_type(), d);
  }

  // These compare instructions return a floating point typed mask instead of
  // an i1. For instance, on a vector typed input, lanes where the predicate is
  // true get a float with all ones and other lanes get a float with all zeros.
  // This is slightly odd from the perspective of LLVM's type system, but it
  // makes kernel IR generation code written using VectorSupportLibrary (its
  // raison d'etre) less cluttered.

  llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FCmpEQMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FCmpEQMask(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs));
  }

  // These boolean operations operate on the bitwise values of the floating
  // point inputs. They return a (vector of) float(s), but as with the mask
  // generating predicates above, this type system oddity makes the kernel IR
  // generation code less cluttered.
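  //
  // For example, a branch-free select `x < 0.0f ? a : b` can be sketched
  // (assuming `x`, `a`, and `b` are hypothetical values of vector_type()):
  //
  //   llvm::Value* mask = vsl.FCmpOLTMask(x, GetIeeeF32(0.0f));
  //   llvm::Value* sel =
  //       vsl.FloatOr(vsl.FloatAnd(mask, a), vsl.FloatAndNot(mask, b));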
  llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FloatAnd(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* FloatOr(llvm::Value* lhs, const llvm::APFloat& rhs) {
    return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs));
  }
  llvm::Value* FloatNot(llvm::Value* lhs);
  llvm::Value* FloatAndNot(llvm::Value* lhs, llvm::Value* rhs) {
    return FloatAnd(FloatNot(lhs), rhs);
  }

  llvm::Value* BroadcastScalar(llvm::Value* x);
  llvm::Value* BroadcastScalar(const llvm::APFloat& d) {
    return BroadcastScalar(GetConstantFloat(scalar_type(), d));
  }

  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
                                    llvm::Value* offset_elements);
  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
                                    llvm::Value* offset_elements, int64 scale) {
    return ComputeOffsetPointer(
        base_pointer, b_->CreateMul(b_->getInt64(scale), offset_elements));
  }
  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
                                    int64 offset_elements) {
    return ComputeOffsetPointer(base_pointer, b()->getInt64(offset_elements));
  }

  llvm::Value* LoadVector(llvm::Value* pointer);

  llvm::Value* LoadVector(llvm::Value* base_pointer,
                          llvm::Value* offset_elements) {
    return LoadVector(ComputeOffsetPointer(base_pointer, offset_elements));
  }

  llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) {
    return LoadVector(base_pointer, b()->getInt64(offset_elements));
  }

  llvm::Value* LoadScalar(llvm::Value* pointer);

  llvm::Value* LoadScalar(llvm::Value* base_pointer,
                          llvm::Value* offset_elements) {
    return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements));
  }

  llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) {
    return LoadScalar(base_pointer, b()->getInt64(offset_elements));
  }

  void StoreVector(llvm::Value* value, llvm::Value* pointer);

  void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
                   llvm::Value* offset_elements) {
    StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements));
  }

  void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
                   int64 offset_elements) {
    StoreVector(value, base_pointer, b()->getInt64(offset_elements));
  }

  void StoreScalar(llvm::Value* value, llvm::Value* pointer);
  void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
                   llvm::Value* offset_elements) {
    StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements));
  }

  void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
                   int64 offset_elements) {
    StoreScalar(value, base_pointer, b()->getInt64(offset_elements));
  }

  llvm::Value* LoadBroadcast(llvm::Value* pointer);
  llvm::Value* LoadBroadcast(llvm::Value* base_pointer,
                             llvm::Value* offset_elements) {
    return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements));
  }
  llvm::Value* LoadBroadcast(llvm::Value* base_pointer, int64 offset_elements) {
    return LoadBroadcast(base_pointer, b()->getInt64(offset_elements));
  }

  // Compute the horizontal sum of each vector in `vectors`. The i'th element
  // in the result vector is the (scalar) horizontal sum of the i'th vector in
  // `vectors`. If `init_values` is not nullptr then the value in the i'th lane
  // in `init_values` is added to the i'th horizontal sum.
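  //
  // For example, a hypothetical sketch that reduces two accumulator vectors
  // `acc0` and `acc1` to their scalar sums:
  //
  //   std::vector<llvm::Value*> sums =
  //       vsl.ComputeHorizontalSums({acc0, acc1});
  //   // sums[0] and sums[1] hold the horizontal sums of acc0 and acc1.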
  std::vector<llvm::Value*> ComputeHorizontalSums(
      std::vector<llvm::Value*> vectors, llvm::Value* init_values = nullptr);

  llvm::Value* GetZeroVector();
  llvm::Value* GetZeroScalar();

  llvm::IRBuilder<>* b() const { return b_; }
  int64 vector_size() const { return vector_size_; }
  llvm::Type* vector_type() const { return vector_type_; }
  llvm::Type* vector_pointer_type() const { return vector_pointer_type_; }
  llvm::Type* scalar_type() const { return scalar_type_; }
  llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; }
  int64 scalar_byte_size() const {
    return primitive_util::BitWidth(primitive_type_) / 8;
  }

  const std::string& name() const { return name_; }

 private:
  llvm::Value* ExtractLowHalf(llvm::Value*);
  llvm::Value* ExtractHighHalf(llvm::Value*);

  llvm::Value* MulInternal(llvm::Value* lhs, llvm::Value* rhs);
  llvm::Value* AddInternal(llvm::Value* lhs, llvm::Value* rhs);

  llvm::Value* AddReduce(llvm::Value* vector);

  // Checks that each value in `values` is either of type scalar_type() or
  // vector_type(). This LOG(FATAL)'s on a mismatch, so it should only be
  // called in cases where a mismatching type is a programmer bug.
  void AssertCorrectTypes(std::initializer_list<llvm::Value*> values);

  // Perform an X86 AVX style horizontal add between `lhs` and `rhs`. The
  // resulting IR for an 8-float wide vector is expected to lower to a single
  // vhaddps instruction on a CPU that supports vhaddps, and not be too bad in
  // other cases.
  //
  // For a vector width of 8, the result vector is computed as:
  //   Result[0] = Lhs[0] + Lhs[1]
  //   Result[1] = Lhs[2] + Lhs[3]
  //   Result[2] = Rhs[0] + Rhs[1]
  //   Result[3] = Rhs[2] + Rhs[3]
  //   Result[4] = Lhs[4] + Lhs[5]
  //   Result[5] = Lhs[6] + Lhs[7]
  //   Result[6] = Rhs[4] + Rhs[5]
  //   Result[7] = Rhs[6] + Rhs[7]
  llvm::Value* AvxStyleHorizontalAdd(llvm::Value* lhs, llvm::Value* rhs);

  std::vector<llvm::Value*> ComputeAvxOptimizedHorizontalSums(
      std::vector<llvm::Value*> vectors, llvm::Value* init_values);

  llvm::Type* IntegerTypeForFloatSize(bool vector);
  llvm::Value* I1ToFloat(llvm::Value* i1);
  llvm::Value* GetConstantFloat(llvm::Type* type, const llvm::APFloat& f) {
    llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f);
    if (llvm::isa<llvm::VectorType>(type)) {
      return llvm::ConstantVector::getSplat(vector_size(), scalar_value);
    }
    return scalar_value;
  }

  int64 vector_size_;
  PrimitiveType primitive_type_;
  llvm::IRBuilder<>* b_;
  llvm::Type* vector_type_;
  llvm::Type* vector_pointer_type_;
  llvm::Type* scalar_type_;
  llvm::Type* scalar_pointer_type_;
  std::string name_;
};

// This wraps an alloca-backed stack variable which LLVM's SSA construction
// pass can later convert to an SSA value.
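//
// For example, a hypothetical sketch of a scalar accumulator for a reduction
// loop (`vsl` is a VectorSupportLibrary; `ptr` and `i` are caller-provided):
//
//   ScalarVariable acc(&vsl, vsl.GetZeroScalar());
//   // Inside the loop body:
//   acc.Set(vsl.Add(acc.Get(), vsl.LoadScalar(ptr, i)));
//   // After the loop, read the final value:
//   llvm::Value* total = acc.Get();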
class LlvmVariable {
 public:
  LlvmVariable(llvm::Type*, llvm::IRBuilder<>* b);

  llvm::Value* Get() const;
  void Set(llvm::Value* new_value);

 private:
  llvm::AllocaInst* alloca_;
  llvm::IRBuilder<>* b_;
};

class VectorVariable : public LlvmVariable {
 public:
  VectorVariable(VectorSupportLibrary* vector_support,
                 llvm::Value* initial_value)
      : LlvmVariable(vector_support->vector_type(), vector_support->b()) {
    Set(initial_value);
  }
};

class ScalarVariable : public LlvmVariable {
 public:
  ScalarVariable(VectorSupportLibrary* vector_support,
                 llvm::Value* initial_value)
      : LlvmVariable(vector_support->scalar_type(), vector_support->b()) {
    Set(initial_value);
  }
};

// This wraps a set of alloca-backed stack variables that can, as a whole,
// store a tile. A "tile" is a sequence of vectors that is typically used as a
// 2D grid of scalar values (e.g. for tiled GEMMs).
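//
// For example, a hypothetical sketch of an accumulator tile for a GEMM inner
// loop (`vsl` is a VectorSupportLibrary; `tile_height` is caller-chosen):
//
//   std::vector<llvm::Value*> zeros(tile_height, vsl.GetZeroVector());
//   TileVariable tile(&vsl, std::move(zeros));
//   // In the loop: read with tile.Get(), update with tile.Set(new_values).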
class TileVariable {
 public:
  TileVariable(VectorSupportLibrary* vector_support,
               std::vector<llvm::Value*> initial_value);

  std::vector<llvm::Value*> Get() const;
  void Set(absl::Span<llvm::Value* const> value);

 private:
  std::vector<VectorVariable> storage_;
};
}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_