• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
18 
19 #include <map>
20 #include <vector>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/strings/string_view.h"
24 #include "absl/types/span.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/Value.h"
27 #include "tensorflow/compiler/xla/map_util.h"
28 #include "tensorflow/compiler/xla/shape.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/types.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace xla {
36 namespace llvm_ir {
37 
38 // IrArray represents an XLA array at the LLVM IR level. This class
39 // encapsulates a base pointer to the buffer holding the array (as an LLVM
40 // Value) and the shape of the array. The class includes methods for emitting
41 // LLVM IR sequences which access elements of the array at a multidimensional
42 // index (eg, [x, y, z] in a 3-dimensional array). Arbitrary shape and layouts
43 // are supported.
44 class IrArray {
45  public:
46   // A multidimensional index into an IrArray. All the runtime indices
47   // (multidim) and dimensions (Shape::dimensions(), absl::Span<const int64>)
48   // are major-first.
49   //
50   // This may also keep a linear index and the layout and dimensions it was
51   // emitted for; if the shape where this `Index` is used matches, the linear
52   // index may be used, potentially sparing the cost of computing the
53   // multidimensional index, which LLVM DCE can delete.
54   class Index {
55    public:
56     // Constructs an index for a scalar shape.
Index(llvm::Type * index_ty)57     explicit Index(llvm::Type* index_ty) : index_type_(index_ty) {
58       CHECK(index_ty->isIntegerTy());
59     }
60 
61     // Constructs an index from linear index "linear" and computes the
62     // multi-dimensional index from "linear" and "shape". "b" is the IR
63     // builder to emit the index of each dimension in the multi-dimensional
64     // index.
65     //
66     // Precondition: "shape" has a layout.
67     Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b);
68 
69     // Similar to the above constructor except using "dynamic_dims" instead of
70     // shape's static dimension to constructs the index.
71     Index(llvm::Value* linear, const Shape& shape,
72           absl::Span<llvm::Value*> dynamic_dims, llvm::IRBuilder<>* b);
73 
74     // Constructs an index from a multi-dimensional index. 'shape' is the shape
75     // for which the multi-dimensional index is used. 'index_type' is the type
76     // of the index.
77     //
78     // Precondition: "shape" has a layout.
79     Index(absl::Span<llvm::Value* const> multidim, const Shape& shape,
80           llvm::Type* index_type);
81 
82     // Same as above, but only the dimensions of the shape without layout is
83     // passed. The layout is assumed to be the default (descending
84     // minor-to-major) layout.
85     Index(absl::Span<llvm::Value* const> multidim,
86           absl::Span<int64 const> dimensions, llvm::Type* index_type);
87 
88     // Returns an index that adds `addend` to the given `dim` of the object.
AddOffsetToDim(llvm::Value * addend,int64 dim,llvm::IRBuilder<> * b)89     Index AddOffsetToDim(llvm::Value* addend, int64 dim,
90                          llvm::IRBuilder<>* b) const {
91       Index with_offset = *this;
92       with_offset.linear_ = nullptr;
93       with_offset.multidim_[dim] =
94           b->CreateAdd(with_offset.multidim_[dim], addend);
95       return with_offset;
96     }
97 
multidim()98     const std::vector<llvm::Value*>& multidim() const { return multidim_; }
dims()99     const std::vector<int64>& dims() const { return dims_; }
linear()100     llvm::Value* linear() const { return linear_; }
101 
size()102     size_t size() const { return multidim().size(); }
103 
104     llvm::Value* operator[](size_t i) const { return multidim()[i]; }
105 
106     using const_iterator = std::vector<llvm::Value*>::const_iterator;
107 
begin()108     const_iterator begin() const { return multidim().begin(); }
end()109     const_iterator end() const { return multidim().end(); }
110 
111     bool LinearValidOnShape(const Shape& a) const;
112 
113     static bool ShapeIsCompatible(const Shape& a, const Shape& b);
114 
ShapeIsCompatible(const Shape & a)115     bool ShapeIsCompatible(const Shape& a) const {
116       return ShapeIsCompatible(
117           a, ShapeUtil::MakeShapeWithLayout(a.element_type(), dims_,
118                                             layout_.minor_to_major()));
119     }
120 
121     // Given that "this" is the target index of a reshape from `input_shape`
122     // to `output_shape`, returns the source index.
123     Index SourceIndexOfReshape(const Shape& output_shape,
124                                const Shape& input_shape,
125                                llvm::IRBuilder<>* builder) const;
126 
127     // Returns the index into the source operand from which a slice operation
128     // selects a value to be placed into index "this". The slice is described
129     // by starting indices `starts` and stride values `strides`.
130     //
131     // Precondition: "this" is an index into a slice whose operand shape is
132     // `operand_shape`.
133     Index SourceIndexOfSlice(const Shape& operand_shape,
134                              absl::Span<const int64> starts,
135                              absl::Span<const int64> strides,
136                              llvm::IRBuilder<>* builder) const;
137 
138     // Given that "this" is the target index of a transpose from `operand_shape`
139     // to `shape` with the given dimension mapping, returns the source index.
140     Index SourceIndexOfTranspose(
141         const Shape& shape, const Shape& operand_shape,
142         absl::Span<const int64> dimension_mapping) const;
143 
144     // Given that "this" is the target index of a bitcast from `operand_shape`
145     // to `shape`, returns the source index.
146     Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape,
147                                llvm::IRBuilder<>* builder) const;
148 
149     // Given that "this" is the target index of a broadcast from `operand_shape`
150     // to `shape` with the given dimension mapping, returns the source index.
151     Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape,
152                                  absl::Span<const int64> dimension_mapping,
153                                  llvm::IRBuilder<>* builder) const;
154 
155     // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and
156     // returns the index into the sole dimension 0 of the new shape.
157     llvm::Value* Linearize(absl::Span<const int64> dimensions,
158                            llvm::IRBuilder<>* builder) const;
159 
160     // Linearizes the index into the given dynamic dimensions.
161     llvm::Value* Linearize(const std::vector<llvm::Value*>& dynamic_dims,
162                            llvm::IRBuilder<>* builder) const;
163 
GetType()164     llvm::Type* GetType() const { return index_type_; }
165 
GetConstantWithIndexType(int64 c)166     llvm::Constant* GetConstantWithIndexType(int64 c) const {
167       // The LLVM function makes sure that the value can be represented by the
168       // specified type, see ConstantInt::ConstantInt(IntegerType *Ty, const
169       // APInt &V).
170       return llvm::ConstantInt::get(index_type_, c);
171     }
172 
173    private:
174     // Constructs an index from both a multi-dimensional index and a linear
175     // index. 'shape' is the shape on which the index is used. 'index_type' is
176     // the type of the index.
177     //
178     // Precondition: "shape" has a layout.
179     Index(absl::Span<llvm::Value* const> multidim, llvm::Value* linear,
180           const Shape& shape, llvm::Type* index_type);
181 
182     void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
183                      const Shape& shape, llvm::IRBuilder<>* b) const;
184 
185     // Delinearize the linear index with the dynamic dimensions.
186     void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
187                      const Shape& shape, absl::Span<llvm::Value*> dynamic_dims,
188                      llvm::IRBuilder<>* b) const;
189 
190     std::vector<llvm::Value*> multidim_;
191 
192     // These values are purely for efficiency; `multidim_` is enough to find the
193     // element at a given `Index`, but if a loop is emitted with a linear index
194     // space, that linear index can be saved in `linear_`, and the layout and
195     // dimensions of the shape the loop was emitted for in `layout_` and
196     // `dims_`, and if the `Index` is used in another array, and its layout and
197     // dimensions match, the linear index can be used, sparing the cost of
198     // computing `multidim_`, which LLVM DCE could potentially so delete.
199     // Modifying `multidim_` after construction nullifies `linear_`, lest it
200     // be used wrongly, as it would be valid no more.
201     // If a loop is emitted with a multidimensional index space, `linear_` would
202     // be null and `layout_` and `dims_` would be ignored.
203     llvm::Value* linear_ = nullptr;
204     Layout layout_;
205     std::vector<int64> dims_;
206 
207     llvm::Type* index_type_;
208   };
209 
210   // Default constructor. Constructs an IrArray in a null status.
IrArray()211   IrArray() : base_ptr_(nullptr) {}
212 
213   // Construct an IrArray with the given base pointer and shape. base_ptr is a
214   // pointer type pointing to the first element(lowest address) of the array.
215   IrArray(llvm::Value* base_ptr, Shape shape);
216 
217   // Default implementations of copying and moving.
218   IrArray(IrArray&& other) = default;
219   IrArray(const IrArray& other) = default;
220   IrArray& operator=(IrArray&& other) = default;
221   IrArray& operator=(const IrArray& other) = default;
222 
GetBasePointer()223   llvm::Value* GetBasePointer() const { return base_ptr_; }
GetElementLlvmType()224   llvm::Type* GetElementLlvmType() const { return element_type_; }
225 
GetShape()226   const Shape& GetShape() const { return shape_; }
227 
228   // Emit a sequence of instructions to compute the address of the element in
229   // the given array at the given index. Returns the address of the element as
230   // an LLVM Value.
231   //
232   // The optional name is useful for debugging when looking at
233   // the emitted LLVM IR.
234   llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b,
235                                        absl::string_view name = "",
236                                        bool use_linear_index = true) const;
237 
238   // Attach metadata this IrArray instance knows about to "instruction".
239   void AnnotateLoadStoreInstructionWithMetadata(
240       llvm::Instruction* instruction) const;
241 
242   // Emit IR to read an array element at the given index. Returns the read
243   // result (effectively, a Value loaded from memory). This method seamlessly
244   // handles scalar shapes by broadcasting their value to all indices (index is
245   // ignored).
246   //
247   // The optional name is useful for debugging when looking at
248   // the emitted LLVM IR.
249   // 'use_linear_index' can be used to specify whether the linear index (if
250   // available) or the multi-dimensional index should be used.
251   llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b,
252                                     absl::string_view name = "",
253                                     bool use_linear_index = true) const;
254 
255   // Emit IR to write the given value to the array element at the given index.
256   // 'use_linear_index' can be used to specify whether the linear index (if
257   // available) or the multi-dimensional index should be used.
258   void EmitWriteArrayElement(const Index& index, llvm::Value* value,
259                              llvm::IRBuilder<>* b,
260                              bool use_linear_index = true) const;
261 
262   // Returns a new IrArray whose shape is "new_shape" and base pointer is a
263   // bitcast of the base pointer of "this" IrArray.
264   // 'use_linear_index' can be used to specify whether the linear index (if
265   // available) or the multi-dimensional index should be used.
266   IrArray CastToShape(const Shape& new_shape, llvm::IRBuilder<>* b) const;
267 
AddAliasScopeMetadata(llvm::MDNode * alias_scope)268   void AddAliasScopeMetadata(llvm::MDNode* alias_scope) {
269     CHECK_NE(alias_scope, nullptr);
270     AddMetadata(llvm::LLVMContext::MD_alias_scope, alias_scope);
271   }
272 
AddNoaliasMetadata(llvm::MDNode * noalias)273   void AddNoaliasMetadata(llvm::MDNode* noalias) {
274     CHECK_NE(noalias, nullptr);
275     AddMetadata(llvm::LLVMContext::MD_noalias, noalias);
276   }
277 
278   // Promises LLVM that the data pointed to by this IrArray never changes after
279   // it's first loaded.
280   //
281   // The temporal scope of this promise is the "whole program" from LLVM's point
282   // of view, but how this translates to HLOs differs between backends.
283   //
284   // In the single-threaded CPU backend, we emit one function that
285   // runs all the HLOs in sequence, so the whole program is the whole HLO
286   // module.
287   //
288   // In the GPU backend, we emit one GPU kernel per top-level HLO (i.e. per HLO
289   // in the entry computation).  From LLVM's perspective, launching a new kernel
290   // is like launching a new program, and so the whole program is one top-level
291   // HLO.  Since the scope of the promise is smaller than in the CPU backend, we
292   // can mark more things as invariant in the GPU backend.
293   //
294   // Marking loads as invariant is particularly helpful on GPUs because
295   // invariant loads can be lowered to PTX ld.global.nc (equivalent to CUDA's
296   // __ldg intrinsic).  These loads use a special cache, and can be
297   // significantly faster than regular loads.
MarkInvariantOverWholeProgram(llvm::LLVMContext * context)298   void MarkInvariantOverWholeProgram(llvm::LLVMContext* context) {
299     if (is_invariant_) {
300       return;
301     }
302     is_invariant_ = true;
303     AddMetadata(llvm::LLVMContext::MD_invariant_load,
304                 llvm::MDNode::get(*context, {}));
305   }
306 
metadata()307   const std::map<int, llvm::MDNode*>& metadata() const { return metadata_; }
308 
309  private:
310   // Add the specified LLVM IR metadata to loads/stores associated with this
311   // IrArray.
AddMetadata(int kind,llvm::MDNode * md)312   void AddMetadata(int kind, llvm::MDNode* md) {
313     InsertOrDie(&metadata_, kind, md);
314   }
315 
316   // Address of the base of the array as an LLVM Value.
317   llvm::Value* base_ptr_;
318 
319   // The LLVM type of the elements in the array.
320   llvm::Type* element_type_;
321 
322   // Shape of the XLA array.
323   Shape shape_;
324 
325   // The list of key/value pairs used when attaching metadata to emitted
326   // loads/stores for this array.  They keys are the metadata kinds and the
327   // values are the metadata nodes.
328   std::map<int, llvm::MDNode*> metadata_;
329 
330   bool is_invariant_ = false;
331 };
332 
333 }  // namespace llvm_ir
334 }  // namespace xla
335 
336 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
337