• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
18 
19 #include <map>
20 #include <vector>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/strings/string_view.h"
24 #include "absl/types/span.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/Value.h"
27 #include "tensorflow/compiler/xla/map_util.h"
28 #include "tensorflow/compiler/xla/shape.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/types.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace xla {
36 namespace llvm_ir {
37 
38 // IrArray represents an XLA array at the LLVM IR level. This class
39 // encapsulates a base pointer to the buffer holding the array (as an LLVM
40 // Value) and the shape of the array. The class includes methods for emitting
41 // LLVM IR sequences which access elements of the array at a multidimensional
42 // index (eg, [x, y, z] in a 3-dimensional array). Arbitrary shape and layouts
43 // are supported.
44 class IrArray {
45  public:
46   // A multidimensional index into an IrArray. All the runtime indices
47   // (multidim) and dimensions (Shape::dimensions(), absl::Span<const int64>)
48   // are major-first.
49   //
50   // This may also keep a linear index and the layout and dimensions it was
51   // emitted for; if the shape where this `Index` is used matches, the linear
52   // index may be used, potentially sparing the cost of computing the
53   // multidimensional index, which LLVM DCE can delete.
54   class Index {
55    public:
56     // Constructs an index for a scalar shape.
Index(llvm::Type * index_ty)57     explicit Index(llvm::Type* index_ty) : index_type_(index_ty) {
58       CHECK(index_ty->isIntegerTy());
59     }
60 
61     // Constructs an index from linear index "linear" and computes the
62     // multi-dimensional index from "linear" and "shape". "b" is the IR
63     // builder to emit the index of each dimension in the multi-dimensional
64     // index.
65     //
66     // Precondition: "shape" has a layout.
67     Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b);
68 
69     // Constructs an index from a multi-dimensional index. 'shape' is the shape
70     // for which the multi-dimensional index is used. 'index_type' is the type
71     // of the index.
72     //
73     // Precondition: "shape" has a layout.
74     Index(absl::Span<llvm::Value* const> multidim, const Shape& shape,
75           llvm::Type* index_type);
76 
77     // Same as above, but only the dimensions of the shape without layout is
78     // passed. The layout is assumed to be the default (descending
79     // minor-to-major) layout.
80     Index(absl::Span<llvm::Value* const> multidim,
81           absl::Span<int64 const> dimensions, llvm::Type* index_type);
82 
83     // Returns an index that adds `addend` to the given `dim` of the object.
AddOffsetToDim(llvm::Value * addend,int64 dim,llvm::IRBuilder<> * b)84     Index AddOffsetToDim(llvm::Value* addend, int64 dim,
85                          llvm::IRBuilder<>* b) const {
86       Index with_offset = *this;
87       with_offset.linear_ = nullptr;
88       with_offset.multidim_[dim] =
89           b->CreateAdd(with_offset.multidim_[dim], addend);
90       return with_offset;
91     }
92 
multidim()93     const std::vector<llvm::Value*>& multidim() const { return multidim_; }
dims()94     const std::vector<int64>& dims() const { return dims_; }
linear()95     llvm::Value* linear() const { return linear_; }
96 
size()97     size_t size() const { return multidim().size(); }
98 
99     llvm::Value* operator[](size_t i) const { return multidim()[i]; }
100 
101     using const_iterator = std::vector<llvm::Value*>::const_iterator;
102 
begin()103     const_iterator begin() const { return multidim().begin(); }
end()104     const_iterator end() const { return multidim().end(); }
105 
106     bool LinearValidOnShape(const Shape& a) const;
107 
ShapeIsCompatible(const Shape & a)108     bool ShapeIsCompatible(const Shape& a) const {
109       Shape own_shape = ShapeUtil::MakeShape(a.element_type(), dims_);
110       *own_shape.mutable_layout() = layout_;
111       // The shape 'a' could have dynamic dimensions set. Before we check for
112       // equality, we need to copy the information which dimensions are dynamic.
113       for (int64 i = 0; i < a.rank(); ++i) {
114         own_shape.set_dynamic_dimension(i, a.is_dynamic_dimension(i));
115       }
116       return ShapeUtil::Equal(own_shape, a);
117     }
118 
119     // Given that "this" is the target index of a reshape from `input_shape`
120     // to `output_shape`, returns the source index.
121     Index SourceIndexOfReshape(const Shape& output_shape,
122                                const Shape& input_shape,
123                                llvm::IRBuilder<>* builder) const;
124 
125     // Returns the index into the source operand from which a slice operation
126     // selects a value to be placed into index "this". The slice is described
127     // by starting indices `starts` and stride values `strides`.
128     //
129     // Precondition: "this" is an index into a slice whose operand shape is
130     // `operand_shape`.
131     Index SourceIndexOfSlice(const Shape& operand_shape,
132                              absl::Span<const int64> starts,
133                              absl::Span<const int64> strides,
134                              llvm::IRBuilder<>* builder) const;
135 
136     // Given that "this" is the target index of a transpose from `operand_shape`
137     // to `shape` with the given dimension mapping, returns the source index.
138     Index SourceIndexOfTranspose(const Shape& shape, const Shape& operand_shape,
139                                  absl::Span<const int64> dimension_mapping,
140                                  llvm::IRBuilder<>* builder) const;
141 
142     // Given that "this" is the target index of a bitcast from `operand_shape`
143     // to `shape`, returns the source index.
144     Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape,
145                                llvm::IRBuilder<>* builder) const;
146 
147     // Given that "this" is the target index of a broadcast from `operand_shape`
148     // to `shape` with the given dimension mapping, returns the source index.
149     Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape,
150                                  absl::Span<const int64> dimension_mapping,
151                                  llvm::IRBuilder<>* builder) const;
152 
153     // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and
154     // returns the index into the sole dimension 0 of the new shape.
155     llvm::Value* Linearize(absl::Span<const int64> dimensions,
156                            llvm::IRBuilder<>* builder) const;
157 
GetType()158     llvm::Type* GetType() const { return index_type_; }
159 
GetConstantWithIndexType(int64 c)160     llvm::Constant* GetConstantWithIndexType(int64 c) const {
161       // The LLVM function makes sure that the value can be represented by the
162       // specified type, see ConstantInt::ConstantInt(IntegerType *Ty, const
163       // APInt &V).
164       return llvm::ConstantInt::get(index_type_, c);
165     }
166 
167    private:
168     // Constructs an index from both a multi-dimensional index and a linear
169     // index. 'shape' is the shape on which the index is used. 'index_type' is
170     // the type of the index.
171     //
172     // Precondition: "shape" has a layout.
173     Index(absl::Span<llvm::Value* const> multidim, llvm::Value* linear,
174           const Shape& shape, llvm::Type* index_type);
175 
176     void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
177                      const Shape& shape, llvm::IRBuilder<>* b) const;
178 
179     std::vector<llvm::Value*> multidim_;
180 
181     // These values are purely for efficiency; `multidim_` is enough to find the
182     // element at a given `Index`, but if a loop is emitted with a linear index
183     // space, that linear index can be saved in `linear_`, and the layout and
184     // dimensions of the shape the loop was emitted for in `layout_` and
185     // `dims_`, and if the `Index` is used in another array, and its layout and
186     // dimensions match, the linear index can be used, sparing the cost of
187     // computing `multidim_`, which LLVM DCE could potentially so delete.
188     // Modifying `multidim_` after construction nullifies `linear_`, lest it
189     // be used wrongly, as it would be valid no more.
190     // If a loop is emitted with a multidimensional index space, `linear_` would
191     // be null and `layout_` and `dims_` would be ignored.
192     llvm::Value* linear_ = nullptr;
193     Layout layout_;
194     std::vector<int64> dims_;
195 
196     llvm::Type* index_type_;
197   };
198 
199   // Default constructor. Constructs an IrArray in a null status.
IrArray()200   IrArray() : base_ptr_(nullptr) {}
201 
202   // Construct an IrArray with the given base pointer and shape. base_ptr is a
203   // pointer type pointing to the first element(lowest address) of the array.
204   IrArray(llvm::Value* base_ptr, Shape shape);
205 
206   // Default implementations of copying and moving.
207   IrArray(IrArray&& other) = default;
208   IrArray(const IrArray& other) = default;
209   IrArray& operator=(IrArray&& other) = default;
210   IrArray& operator=(const IrArray& other) = default;
211 
GetBasePointer()212   llvm::Value* GetBasePointer() const { return base_ptr_; }
GetElementLlvmType()213   llvm::Type* GetElementLlvmType() const { return element_type_; }
214 
GetShape()215   const Shape& GetShape() const { return shape_; }
216 
217   // Emit a sequence of instructions to compute the address of the element in
218   // the given array at the given index. Returns the address of the element as
219   // an LLVM Value.
220   //
221   // The optional name is useful for debugging when looking at
222   // the emitted LLVM IR.
223   llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b,
224                                        absl::string_view name = "",
225                                        bool use_linear_index = true) const;
226 
227   // Attach metadata this IrArray instance knows about to "instruction".
228   void AnnotateLoadStoreInstructionWithMetadata(
229       llvm::Instruction* instruction) const;
230 
231   // Emit IR to read an array element at the given index. Returns the read
232   // result (effectively, a Value loaded from memory). This method seamlessly
233   // handles scalar shapes by broadcasting their value to all indices (index is
234   // ignored).
235   //
236   // The optional name is useful for debugging when looking at
237   // the emitted LLVM IR.
238   // 'use_linear_index' can be used to specify whether the linear index (if
239   // available) or the multi-dimensional index should be used.
240   llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b,
241                                     absl::string_view name = "",
242                                     bool use_linear_index = true) const;
243 
244   // Emit IR to write the given value to the array element at the given index.
245   // 'use_linear_index' can be used to specify whether the linear index (if
246   // available) or the multi-dimensional index should be used.
247   void EmitWriteArrayElement(const Index& index, llvm::Value* value,
248                              llvm::IRBuilder<>* b,
249                              bool use_linear_index = true) const;
250 
251   // Returns a new IrArray whose shape is "new_shape" and base pointer is a
252   // bitcast of the base pointer of "this" IrArray.
253   // 'use_linear_index' can be used to specify whether the linear index (if
254   // available) or the multi-dimensional index should be used.
255   IrArray CastToShape(const Shape& new_shape, llvm::IRBuilder<>* b) const;
256 
AddAliasScopeMetadata(llvm::MDNode * alias_scope)257   void AddAliasScopeMetadata(llvm::MDNode* alias_scope) {
258     CHECK_NE(alias_scope, nullptr);
259     AddMetadata(llvm::LLVMContext::MD_alias_scope, alias_scope);
260   }
261 
AddNoaliasMetadata(llvm::MDNode * noalias)262   void AddNoaliasMetadata(llvm::MDNode* noalias) {
263     CHECK_NE(noalias, nullptr);
264     AddMetadata(llvm::LLVMContext::MD_noalias, noalias);
265   }
266 
267   // Promises LLVM that the data pointed to by this IrArray never changes after
268   // it's first loaded.
269   //
270   // The temporal scope of this promise is the "whole program" from LLVM's point
271   // of view, but how this translates to HLOs differs between backends.
272   //
273   // In the single-threaded CPU backend, we emit one function that
274   // runs all the HLOs in sequence, so the whole program is the whole HLO
275   // module.
276   //
277   // In the GPU backend, we emit one GPU kernel per top-level HLO (i.e. per HLO
278   // in the entry computation).  From LLVM's perspective, launching a new kernel
279   // is like launching a new program, and so the whole program is one top-level
280   // HLO.  Since the scope of the promise is smaller than in the CPU backend, we
281   // can mark more things as invariant in the GPU backend.
282   //
283   // Marking loads as invariant is particularly helpful on GPUs because
284   // invariant loads can be lowered to PTX ld.global.nc (equivalent to CUDA's
285   // __ldg intrinsic).  These loads use a special cache, and can be
286   // significantly faster than regular loads.
MarkInvariantOverWholeProgram(llvm::LLVMContext * context)287   void MarkInvariantOverWholeProgram(llvm::LLVMContext* context) {
288     if (is_invariant_) {
289       return;
290     }
291     is_invariant_ = true;
292     AddMetadata(llvm::LLVMContext::MD_invariant_load,
293                 llvm::MDNode::get(*context, {}));
294   }
295 
metadata()296   const std::map<int, llvm::MDNode*>& metadata() const { return metadata_; }
297 
298  private:
299   // Add the specified LLVM IR metadata to loads/stores associated with this
300   // IrArray.
AddMetadata(int kind,llvm::MDNode * md)301   void AddMetadata(int kind, llvm::MDNode* md) {
302     InsertOrDie(&metadata_, kind, md);
303   }
304 
305   // Address of the base of the array as an LLVM Value.
306   llvm::Value* base_ptr_;
307 
308   // The LLVM type of the elements in the array.
309   llvm::Type* element_type_;
310 
311   // Shape of the XLA array.
312   Shape shape_;
313 
314   // The list of key/value pairs used when attaching metadata to emitted
315   // loads/stores for this array.  They keys are the metadata kinds and the
316   // values are the metadata nodes.
317   std::map<int, llvm::MDNode*> metadata_;
318 
319   bool is_invariant_ = false;
320 };
321 
322 }  // namespace llvm_ir
323 }  // namespace xla
324 
325 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
326