• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
18 
19 #include <map>
20 #include <vector>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/strings/string_view.h"
24 #include "absl/types/span.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/Value.h"
27 #include "tensorflow/compiler/xla/map_util.h"
28 #include "tensorflow/compiler/xla/shape.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/types.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace xla {
36 namespace llvm_ir {
37 
38 // IrArray represents an XLA array at the LLVM IR level. This class
39 // encapsulates a base pointer to the buffer holding the array (as an LLVM
40 // Value) and the shape of the array. The class includes methods for emitting
41 // LLVM IR sequences which access elements of the array at a multidimensional
42 // index (eg, [x, y, z] in a 3-dimensional array). Arbitrary shape and layouts
43 // are supported.
44 class IrArray {
45  public:
46   // A multidimensional index into an IrArray. All the runtime indices
47   // (multidim) and dimensions (Shape::dimensions(), absl::Span<const int64>)
48   // are major-first.
49   //
50   // This may also keep a linear index and the layout and dimensions it was
51   // emitted for; if the shape where this `Index` is used matches, the linear
52   // index may be used, potentially sparing the cost of computing the
53   // multidimensional index, which LLVM DCE can delete.
54   class Index {
55    public:
56     // Constructs an index for a scalar shape.
Index(llvm::Type * index_ty)57     explicit Index(llvm::Type* index_ty) : index_type_(index_ty) {
58       CHECK(index_ty->isIntegerTy());
59     }
60 
61     // Constructs an index from linear index "linear" and computes the
62     // multi-dimensional index from "linear" and "shape". "b" is the IR
63     // builder to emit the index of each dimension in the multi-dimensional
64     // index.
65     //
66     // Precondition: "shape" has a layout.
67     Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b);
68 
69     // As before, but also take a multidim to reuse.  multidim.size()
70     // == shape.rank() must be true.  If some of the multidim element
71     // are null we will use the value that would be used if
72     // deliearized from linear.
73     Index(llvm::Value* linear, absl::Span<llvm::Value* const> multidim,
74           const Shape& shape, llvm::IRBuilder<>* b);
75 
76     // Similar to the above constructor except using "dynamic_dims" instead of
77     // shape's static dimension to constructs the index.
78     Index(llvm::Value* linear, const Shape& shape,
79           absl::Span<llvm::Value*> dynamic_dims, llvm::IRBuilder<>* b);
80 
81     // Constructs an index from a multi-dimensional index. 'shape' is the shape
82     // for which the multi-dimensional index is used. 'index_type' is the type
83     // of the index.
84     //
85     // Precondition: "shape" has a layout.
86     Index(absl::Span<llvm::Value* const> multidim, const Shape& shape,
87           llvm::Type* index_type);
88 
89     // Same as above, but only the dimensions of the shape without layout is
90     // passed. The layout is assumed to be the default (descending
91     // minor-to-major) layout.
92     Index(absl::Span<llvm::Value* const> multidim,
93           absl::Span<int64 const> dimensions, llvm::Type* index_type);
94 
95     // Returns an index that adds `addend` to the given `dim` of the object.
AddOffsetToDim(llvm::Value * addend,int64_t dim,llvm::IRBuilder<> * b)96     Index AddOffsetToDim(llvm::Value* addend, int64_t dim,
97                          llvm::IRBuilder<>* b) const {
98       Index with_offset = *this;
99       with_offset.linear_ = nullptr;
100       with_offset.multidim_[dim] =
101           b->CreateAdd(with_offset.multidim_[dim], addend);
102       return with_offset;
103     }
104 
multidim()105     const std::vector<llvm::Value*>& multidim() const { return multidim_; }
dims()106     const std::vector<int64>& dims() const { return dims_; }
linear()107     llvm::Value* linear() const { return linear_; }
108 
size()109     size_t size() const { return multidim().size(); }
110 
111     llvm::Value* operator[](size_t i) const { return multidim()[i]; }
112 
113     using const_iterator = std::vector<llvm::Value*>::const_iterator;
114 
begin()115     const_iterator begin() const { return multidim().begin(); }
end()116     const_iterator end() const { return multidim().end(); }
117 
118     bool LinearValidOnShape(const Shape& a) const;
119 
120     static bool ShapeIsCompatible(const Shape& a, const Shape& b);
121 
ShapeIsCompatible(const Shape & a)122     bool ShapeIsCompatible(const Shape& a) const {
123       return ShapeIsCompatible(
124           a, ShapeUtil::MakeShapeWithLayout(a.element_type(), dims_,
125                                             layout_.minor_to_major()));
126     }
127 
128     // Given that "this" is the target index of a reshape from `input_shape`
129     // to `output_shape`, returns the source index.
130     Index SourceIndexOfReshape(const Shape& output_shape,
131                                const Shape& input_shape,
132                                llvm::IRBuilder<>* builder) const;
133 
134     // Returns the index into the source operand from which a slice operation
135     // selects a value to be placed into index "this". The slice is described
136     // by starting indices `starts` and stride values `strides`.
137     //
138     // Precondition: "this" is an index into a slice whose operand shape is
139     // `operand_shape`.
140     Index SourceIndexOfSlice(const Shape& operand_shape,
141                              absl::Span<const int64> starts,
142                              absl::Span<const int64> strides,
143                              llvm::IRBuilder<>* builder) const;
144 
145     // Given that "this" is the target index of a transpose from `operand_shape`
146     // to `shape` with the given dimension mapping, returns the source index.
147     Index SourceIndexOfTranspose(
148         const Shape& shape, const Shape& operand_shape,
149         absl::Span<const int64> dimension_mapping) const;
150 
151     // Given that "this" is the target index of a bitcast from `operand_shape`
152     // to `shape`, returns the source index.
153     Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape,
154                                llvm::IRBuilder<>* builder) const;
155 
156     // Given that "this" is the target index of a broadcast from `operand_shape`
157     // to `shape` with the given dimension mapping, returns the source index.
158     Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape,
159                                  absl::Span<const int64> dimension_mapping,
160                                  llvm::IRBuilder<>* builder) const;
161 
162     // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and
163     // returns the index into the sole dimension 0 of the new shape.
164     llvm::Value* Linearize(absl::Span<const int64> dimensions,
165                            llvm::IRBuilder<>* builder) const;
166 
167     // Linearizes the index into the given dynamic dimensions.
168     llvm::Value* Linearize(const std::vector<llvm::Value*>& dynamic_dims,
169                            llvm::IRBuilder<>* builder) const;
170 
GetType()171     llvm::Type* GetType() const { return index_type_; }
172 
GetConstantWithIndexType(int64_t c)173     llvm::Constant* GetConstantWithIndexType(int64_t c) const {
174       // The LLVM function makes sure that the value can be represented by the
175       // specified type, see ConstantInt::ConstantInt(IntegerType *Ty, const
176       // APInt &V).
177       return llvm::ConstantInt::get(index_type_, c);
178     }
179 
180    private:
181     // Constructs an index from both a multi-dimensional index and a linear
182     // index. 'shape' is the shape on which the index is used. 'index_type' is
183     // the type of the index.
184     //
185     // Precondition: "shape" has a layout.
186     Index(absl::Span<llvm::Value* const> multidim, llvm::Value* linear,
187           const Shape& shape, llvm::Type* index_type);
188 
189     void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
190                      const Shape& shape, llvm::IRBuilder<>* b) const;
191 
192     // Delinearize the linear index with the dynamic dimensions.
193     void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
194                      const Shape& shape, absl::Span<llvm::Value*> dynamic_dims,
195                      llvm::IRBuilder<>* b) const;
196 
197     std::vector<llvm::Value*> multidim_;
198 
199     // These values are purely for efficiency; `multidim_` is enough to find the
200     // element at a given `Index`, but if a loop is emitted with a linear index
201     // space, that linear index can be saved in `linear_`, and the layout and
202     // dimensions of the shape the loop was emitted for in `layout_` and
203     // `dims_`, and if the `Index` is used in another array, and its layout and
204     // dimensions match, the linear index can be used, sparing the cost of
205     // computing `multidim_`, which LLVM DCE could potentially so delete.
206     // Modifying `multidim_` after construction nullifies `linear_`, lest it
207     // be used wrongly, as it would be valid no more.
208     // If a loop is emitted with a multidimensional index space, `linear_` would
209     // be null and `layout_` and `dims_` would be ignored.
210     llvm::Value* linear_ = nullptr;
211     Layout layout_;
212     std::vector<int64> dims_;
213 
214     llvm::Type* index_type_;
215   };
216 
217   // Default constructor. Constructs an IrArray in a null status.
IrArray()218   IrArray() : base_ptr_(nullptr) {}
219 
220   // Construct an IrArray with the given base pointer and shape. base_ptr is a
221   // pointer type pointing to the first element(lowest address) of the array.
222   IrArray(llvm::Value* base_ptr, Shape shape);
223 
224   // Default implementations of copying and moving.
225   IrArray(IrArray&& other) = default;
226   IrArray(const IrArray& other) = default;
227   IrArray& operator=(IrArray&& other) = default;
228   IrArray& operator=(const IrArray& other) = default;
229 
GetBasePointer()230   llvm::Value* GetBasePointer() const { return base_ptr_; }
GetElementLlvmType()231   llvm::Type* GetElementLlvmType() const { return element_type_; }
232 
GetShape()233   const Shape& GetShape() const { return shape_; }
234 
235   // Emit a sequence of instructions to compute the address of the element in
236   // the given array at the given index. Returns the address of the element as
237   // an LLVM Value.
238   //
239   // The optional name is useful for debugging when looking at
240   // the emitted LLVM IR.
241   llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b,
242                                        absl::string_view name = "",
243                                        bool use_linear_index = true) const;
244 
245   // Attach metadata this IrArray instance knows about to "instruction".
246   void AnnotateLoadStoreInstructionWithMetadata(
247       llvm::Instruction* instruction) const;
248 
249   // Emit IR to read an array element at the given index. Returns the read
250   // result (effectively, a Value loaded from memory). This method seamlessly
251   // handles scalar shapes by broadcasting their value to all indices (index is
252   // ignored).
253   //
254   // The optional name is useful for debugging when looking at
255   // the emitted LLVM IR.
256   // 'use_linear_index' can be used to specify whether the linear index (if
257   // available) or the multi-dimensional index should be used.
258   llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b,
259                                     absl::string_view name = "",
260                                     bool use_linear_index = true) const;
261 
262   // Emit IR to write the given value to the array element at the given index.
263   // 'use_linear_index' can be used to specify whether the linear index (if
264   // available) or the multi-dimensional index should be used.
265   void EmitWriteArrayElement(const Index& index, llvm::Value* value,
266                              llvm::IRBuilder<>* b,
267                              bool use_linear_index = true) const;
268 
269   // Returns a new IrArray whose shape is "new_shape" and base pointer is a
270   // bitcast of the base pointer of "this" IrArray.
271   // 'use_linear_index' can be used to specify whether the linear index (if
272   // available) or the multi-dimensional index should be used.
273   IrArray CastToShape(const Shape& new_shape, llvm::IRBuilder<>* b) const;
274 
AddAliasScopeMetadata(llvm::MDNode * alias_scope)275   void AddAliasScopeMetadata(llvm::MDNode* alias_scope) {
276     CHECK_NE(alias_scope, nullptr);
277     AddMetadata(llvm::LLVMContext::MD_alias_scope, alias_scope);
278   }
279 
AddNoaliasMetadata(llvm::MDNode * noalias)280   void AddNoaliasMetadata(llvm::MDNode* noalias) {
281     CHECK_NE(noalias, nullptr);
282     AddMetadata(llvm::LLVMContext::MD_noalias, noalias);
283   }
284 
285   // Promises LLVM that the data pointed to by this IrArray never changes after
286   // it's first loaded.
287   //
288   // The temporal scope of this promise is the "whole program" from LLVM's point
289   // of view, but how this translates to HLOs differs between backends.
290   //
291   // In the single-threaded CPU backend, we emit one function that
292   // runs all the HLOs in sequence, so the whole program is the whole HLO
293   // module.
294   //
295   // In the GPU backend, we emit one GPU kernel per top-level HLO (i.e. per HLO
296   // in the entry computation).  From LLVM's perspective, launching a new kernel
297   // is like launching a new program, and so the whole program is one top-level
298   // HLO.  Since the scope of the promise is smaller than in the CPU backend, we
299   // can mark more things as invariant in the GPU backend.
300   //
301   // Marking loads as invariant is particularly helpful on GPUs because
302   // invariant loads can be lowered to PTX ld.global.nc (equivalent to CUDA's
303   // __ldg intrinsic).  These loads use a special cache, and can be
304   // significantly faster than regular loads.
MarkInvariantOverWholeProgram(llvm::LLVMContext * context)305   void MarkInvariantOverWholeProgram(llvm::LLVMContext* context) {
306     if (is_invariant_) {
307       return;
308     }
309     is_invariant_ = true;
310     AddMetadata(llvm::LLVMContext::MD_invariant_load,
311                 llvm::MDNode::get(*context, {}));
312   }
313 
metadata()314   const std::map<int, llvm::MDNode*>& metadata() const { return metadata_; }
315 
316  private:
317   // Add the specified LLVM IR metadata to loads/stores associated with this
318   // IrArray.
AddMetadata(int kind,llvm::MDNode * md)319   void AddMetadata(int kind, llvm::MDNode* md) {
320     InsertOrDie(&metadata_, kind, md);
321   }
322 
323   // Address of the base of the array as an LLVM Value.
324   llvm::Value* base_ptr_;
325 
326   // The LLVM type of the elements in the array.
327   llvm::Type* element_type_;
328 
329   // Shape of the XLA array.
330   Shape shape_;
331 
332   // The list of key/value pairs used when attaching metadata to emitted
333   // loads/stores for this array.  They keys are the metadata kinds and the
334   // values are the metadata nodes.
335   std::map<int, llvm::MDNode*> metadata_;
336 
337   bool is_invariant_ = false;
338 };
339 
340 }  // namespace llvm_ir
341 }  // namespace xla
342 
343 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
344