• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
18 
19 #include <map>
20 #include <vector>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/strings/string_view.h"
24 #include "absl/types/span.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/Value.h"
27 #include "tensorflow/compiler/xla/map_util.h"
28 #include "tensorflow/compiler/xla/shape.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/types.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/platform/logging.h"
33 
34 namespace xla {
35 namespace llvm_ir {
36 
37 // IrArray represents an XLA array at the LLVM IR level. This class
38 // encapsulates a base pointer to the buffer holding the array (as an LLVM
39 // Value) and the shape of the array. The class includes methods for emitting
40 // LLVM IR sequences which access elements of the array at a multidimensional
41 // index (eg, [x, y, z] in a 3-dimensional array). Arbitrary shape and layouts
42 // are supported.
43 class IrArray {
44  public:
45   // A multidimensional index into an IrArray. All the runtime indices
46   // (multidim) and dimensions (Shape::dimensions(), absl::Span<const int64_t>)
47   // are major-first.
48   //
49   // This may also keep a linear index and the layout and dimensions it was
50   // emitted for; if the shape where this `Index` is used matches, the linear
51   // index may be used, potentially sparing the cost of computing the
52   // multidimensional index, which LLVM DCE can delete.
53   class Index {
54    public:
55     // Constructs an index for a scalar shape.
Index(llvm::Type * index_ty)56     explicit Index(llvm::Type* index_ty) : index_type_(index_ty) {
57       CHECK(index_ty->isIntegerTy());
58     }
59 
60     // Constructs an index from linear index "linear" and computes the
61     // multi-dimensional index from "linear" and "shape". "b" is the IR
62     // builder to emit the index of each dimension in the multi-dimensional
63     // index.
64     //
65     // Precondition: "shape" has a layout.
66     Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b);
67 
68     // As before, but also take a multidim to reuse.  multidim.size()
69     // == shape.rank() must be true.  If some of the multidim element
70     // are null we will use the value that would be used if
71     // deliearized from linear.
72     Index(llvm::Value* linear, absl::Span<llvm::Value* const> multidim,
73           const Shape& shape, llvm::IRBuilder<>* b);
74 
75     // Similar to the above constructor except using "dynamic_dims" instead of
76     // shape's static dimension to constructs the index.
77     Index(llvm::Value* linear, const Shape& shape,
78           absl::Span<llvm::Value*> dynamic_dims, llvm::IRBuilder<>* b);
79 
80     // Constructs an index from a multi-dimensional index. 'shape' is the shape
81     // for which the multi-dimensional index is used. 'index_type' is the type
82     // of the index.
83     //
84     // Precondition: "shape" has a layout.
85     Index(absl::Span<llvm::Value* const> multidim, const Shape& shape,
86           llvm::Type* index_type);
87 
88     // Same as above, but only the dimensions of the shape without layout is
89     // passed. The layout is assumed to be the default (descending
90     // minor-to-major) layout.
91     Index(absl::Span<llvm::Value* const> multidim,
92           absl::Span<int64_t const> dimensions, llvm::Type* index_type);
93 
94     // Returns an index that adds `addend` to the given `dim` of the object.
AddOffsetToDim(llvm::Value * addend,int64_t dim,llvm::IRBuilder<> * b)95     Index AddOffsetToDim(llvm::Value* addend, int64_t dim,
96                          llvm::IRBuilder<>* b) const {
97       Index with_offset = *this;
98       with_offset.linear_ = nullptr;
99       with_offset.multidim_[dim] =
100           b->CreateAdd(with_offset.multidim_[dim], addend);
101       return with_offset;
102     }
103 
multidim()104     const std::vector<llvm::Value*>& multidim() const { return multidim_; }
dims()105     const std::vector<int64_t>& dims() const { return dims_; }
linear()106     llvm::Value* linear() const { return linear_; }
107 
size()108     size_t size() const { return multidim().size(); }
109 
110     llvm::Value* operator[](size_t i) const { return multidim()[i]; }
111 
112     using const_iterator = std::vector<llvm::Value*>::const_iterator;
113 
begin()114     const_iterator begin() const { return multidim().begin(); }
end()115     const_iterator end() const { return multidim().end(); }
116 
117     bool LinearValidOnShape(const Shape& a) const;
118 
119     static bool ShapeIsCompatible(const Shape& a, const Shape& b);
120 
ShapeIsCompatible(const Shape & a)121     bool ShapeIsCompatible(const Shape& a) const {
122       return ShapeIsCompatible(a, AsShapeWithType(a.element_type()));
123     }
124 
AsShapeWithType(PrimitiveType element_type)125     Shape AsShapeWithType(PrimitiveType element_type) const {
126       return ShapeUtil::MakeShapeWithLayout(element_type, dims_,
127                                             layout_.minor_to_major());
128     }
129 
130     // Given that "this" is the target index of a reshape from `input_shape`
131     // to `output_shape`, returns the source index.
132     Index SourceIndexOfReshape(const Shape& output_shape,
133                                const Shape& input_shape,
134                                llvm::IRBuilder<>* builder) const;
135 
136     // Returns the index into the source operand from which a slice operation
137     // selects a value to be placed into index "this". The slice is described
138     // by starting indices `starts` and stride values `strides`.
139     //
140     // Precondition: "this" is an index into a slice whose operand shape is
141     // `operand_shape`.
142     Index SourceIndexOfSlice(const Shape& operand_shape,
143                              absl::Span<const int64_t> starts,
144                              absl::Span<const int64_t> strides,
145                              llvm::IRBuilder<>* builder) const;
146 
147     // Given that "this" is the target index of a transpose from `operand_shape`
148     // to `shape` with the given dimension mapping, returns the source index.
149     Index SourceIndexOfTranspose(
150         const Shape& shape, const Shape& operand_shape,
151         absl::Span<const int64_t> dimension_mapping) const;
152 
153     // Given that "this" is the target index of a bitcast from `operand_shape`
154     // to `shape`, returns the source index.
155     Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape,
156                                llvm::IRBuilder<>* builder) const;
157 
158     // Given that "this" is the target index of a broadcast from `operand_shape`
159     // to `shape` with the given dimension mapping, returns the source index.
160     Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape,
161                                  absl::Span<const int64_t> dimension_mapping,
162                                  llvm::IRBuilder<>* builder) const;
163 
164     // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and
165     // returns the index into the sole dimension 0 of the new shape.
166     llvm::Value* Linearize(absl::Span<const int64_t> dimensions,
167                            llvm::IRBuilder<>* builder) const;
168 
169     // Linearizes the index into the given dynamic dimensions.
170     llvm::Value* Linearize(const std::vector<llvm::Value*>& dynamic_dims,
171                            llvm::IRBuilder<>* builder) const;
172 
GetType()173     llvm::Type* GetType() const { return index_type_; }
174 
GetConstantWithIndexType(int64_t c)175     llvm::Constant* GetConstantWithIndexType(int64_t c) const {
176       // The LLVM function makes sure that the value can be represented by the
177       // specified type, see ConstantInt::ConstantInt(IntegerType *Ty, const
178       // APInt &V).
179       return llvm::ConstantInt::get(index_type_, c);
180     }
181 
182    private:
183     // Constructs an index from both a multi-dimensional index and a linear
184     // index. 'shape' is the shape on which the index is used. 'index_type' is
185     // the type of the index.
186     //
187     // Precondition: "shape" has a layout.
188     Index(absl::Span<llvm::Value* const> multidim, llvm::Value* linear,
189           const Shape& shape, llvm::Type* index_type);
190 
191     void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
192                      const Shape& shape, llvm::IRBuilder<>* b) const;
193 
194     // Delinearize the linear index with the dynamic dimensions.
195     void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
196                      const Shape& shape, absl::Span<llvm::Value*> dynamic_dims,
197                      llvm::IRBuilder<>* b) const;
198 
199     std::vector<llvm::Value*> multidim_;
200 
201     // These values are purely for efficiency; `multidim_` is enough to find the
202     // element at a given `Index`, but if a loop is emitted with a linear index
203     // space, that linear index can be saved in `linear_`, and the layout and
204     // dimensions of the shape the loop was emitted for in `layout_` and
205     // `dims_`, and if the `Index` is used in another array, and its layout and
206     // dimensions match, the linear index can be used, sparing the cost of
207     // computing `multidim_`, which LLVM DCE could potentially so delete.
208     // Modifying `multidim_` after construction nullifies `linear_`, lest it
209     // be used wrongly, as it would be valid no more.
210     // If a loop is emitted with a multidimensional index space, `linear_` would
211     // be null and `layout_` and `dims_` would be ignored.
212     llvm::Value* linear_ = nullptr;
213     Layout layout_;
214     std::vector<int64_t> dims_;
215 
216     llvm::Type* index_type_;
217   };
218 
219   // Default constructor. Constructs an IrArray in a null status.
IrArray()220   IrArray() : base_ptr_(nullptr) {}
221 
222   // Construct an IrArray with the given base pointer, pointee type, and shape.
223   // base_ptr is a pointer type pointing to the first element(lowest address)
224   // of the array.
225   IrArray(llvm::Value* base_ptr, llvm::Type* pointee_type, Shape shape);
226 
227   // Default implementations of copying and moving.
228   IrArray(IrArray&& other) = default;
229   IrArray(const IrArray& other) = default;
230   IrArray& operator=(IrArray&& other) = default;
231   IrArray& operator=(const IrArray& other) = default;
232 
GetBasePointer()233   llvm::Value* GetBasePointer() const { return base_ptr_; }
GetBasePointeeType()234   llvm::Type* GetBasePointeeType() const { return pointee_type_; }
GetElementLlvmType()235   llvm::Type* GetElementLlvmType() const { return element_type_; }
236 
GetShape()237   const Shape& GetShape() const { return shape_; }
238 
239   // Emit a sequence of instructions to compute the address of the element in
240   // the given array at the given index. Returns the address of the element as
241   // an LLVM Value.
242   //
243   // The optional name is useful for debugging when looking at
244   // the emitted LLVM IR.
245   llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b,
246                                        absl::string_view name = "",
247                                        bool use_linear_index = true) const;
248 
249   // Attach metadata this IrArray instance knows about to "instruction".
250   void AnnotateLoadStoreInstructionWithMetadata(
251       llvm::Instruction* instruction) const;
252 
253   // Emit IR to read an array element at the given index. Returns the read
254   // result (effectively, a Value loaded from memory). This method seamlessly
255   // handles scalar shapes by broadcasting their value to all indices (index is
256   // ignored).
257   //
258   // The optional name is useful for debugging when looking at
259   // the emitted LLVM IR.
260   // 'use_linear_index' can be used to specify whether the linear index (if
261   // available) or the multi-dimensional index should be used.
262   llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b,
263                                     absl::string_view name = "",
264                                     bool use_linear_index = true) const;
265 
266   // Emit IR to write the given value to the array element at the given index.
267   // 'use_linear_index' can be used to specify whether the linear index (if
268   // available) or the multi-dimensional index should be used.
269   void EmitWriteArrayElement(const Index& index, llvm::Value* value,
270                              llvm::IRBuilder<>* b,
271                              bool use_linear_index = true) const;
272 
273   // Returns a new IrArray whose shape is "new_shape" and base pointer is a
274   // bitcast of the base pointer of "this" IrArray.
275   // 'use_linear_index' can be used to specify whether the linear index (if
276   // available) or the multi-dimensional index should be used.
277   IrArray CastToShape(const Shape& new_shape, llvm::IRBuilder<>* b) const;
278 
AddAliasScopeMetadata(llvm::MDNode * alias_scope)279   void AddAliasScopeMetadata(llvm::MDNode* alias_scope) {
280     CHECK_NE(alias_scope, nullptr);
281     AddMetadata(llvm::LLVMContext::MD_alias_scope, alias_scope);
282   }
283 
AddNoaliasMetadata(llvm::MDNode * noalias)284   void AddNoaliasMetadata(llvm::MDNode* noalias) {
285     CHECK_NE(noalias, nullptr);
286     AddMetadata(llvm::LLVMContext::MD_noalias, noalias);
287   }
288 
289   // Promises LLVM that the data pointed to by this IrArray never changes after
290   // it's first loaded.
291   //
292   // The temporal scope of this promise is the "whole program" from LLVM's point
293   // of view, but how this translates to HLOs differs between backends.
294   //
295   // In the single-threaded CPU backend, we emit one function that
296   // runs all the HLOs in sequence, so the whole program is the whole HLO
297   // module.
298   //
299   // In the GPU backend, we emit one GPU kernel per top-level HLO (i.e. per HLO
300   // in the entry computation).  From LLVM's perspective, launching a new kernel
301   // is like launching a new program, and so the whole program is one top-level
302   // HLO.  Since the scope of the promise is smaller than in the CPU backend, we
303   // can mark more things as invariant in the GPU backend.
304   //
305   // Marking loads as invariant is particularly helpful on GPUs because
306   // invariant loads can be lowered to PTX ld.global.nc (equivalent to CUDA's
307   // __ldg intrinsic).  These loads use a special cache, and can be
308   // significantly faster than regular loads.
MarkInvariantOverWholeProgram(llvm::LLVMContext * context)309   void MarkInvariantOverWholeProgram(llvm::LLVMContext* context) {
310     if (is_invariant_) {
311       return;
312     }
313     is_invariant_ = true;
314     AddMetadata(llvm::LLVMContext::MD_invariant_load,
315                 llvm::MDNode::get(*context, {}));
316   }
317 
metadata()318   const std::map<int, llvm::MDNode*>& metadata() const { return metadata_; }
319 
320  private:
321   // Add the specified LLVM IR metadata to loads/stores associated with this
322   // IrArray.
AddMetadata(int kind,llvm::MDNode * md)323   void AddMetadata(int kind, llvm::MDNode* md) {
324     InsertOrDie(&metadata_, kind, md);
325   }
326 
327   // Address of the base of the array as an LLVM Value.
328   llvm::Value* base_ptr_;
329 
330   // The pointee type of base_ptr_;
331   llvm::Type* pointee_type_;
332 
333   // The LLVM type of the elements in the array.
334   llvm::Type* element_type_;
335 
336   // Shape of the XLA array.
337   Shape shape_;
338 
339   // The list of key/value pairs used when attaching metadata to emitted
340   // loads/stores for this array.  They keys are the metadata kinds and the
341   // values are the metadata nodes.
342   std::map<int, llvm::MDNode*> metadata_;
343 
344   bool is_invariant_ = false;
345 };
346 
347 }  // namespace llvm_ir
348 }  // namespace xla
349 
350 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
351