/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_

#include <map>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace llvm_ir {

// IrArray represents an XLA array at the LLVM IR level. This class
// encapsulates a base pointer to the buffer holding the array (as an LLVM
// Value) and the shape of the array. The class includes methods for emitting
// LLVM IR sequences which access elements of the array at a multidimensional
// index (e.g., [x, y, z] in a 3-dimensional array). Arbitrary shapes and
// layouts are supported.
class IrArray {
 public:
  // A multidimensional index into an IrArray. The index for dimension zero is
  // first in the vector. This is the reverse of the order used when
  // describing the dimensions of an array. That is, for a [4 x 3 x 2] array
  // dimension zero has size 2, dimension one has size 3, and dimension two
  // has size 4. Thus the index {1, 2, 3} indexes the last element of this
  // [4 x 3 x 2] array.
  //
  // This may also keep a linear index, together with the layout and
  // dimensions it was emitted for; if the shape where this `Index` is used
  // matches, the linear index may be used, potentially sparing the cost of
  // computing the multidimensional index, which LLVM DCE can then delete.
  class Index {
   public:
    // Constructs an index for a scalar shape.
    explicit Index(llvm::Type* index_ty) : index_type_(index_ty) {
      CHECK(index_ty->isIntegerTy());
    }

    // Constructs an index from the multi-dimensional index "multidim". The
    // linear index is set to nullptr.
    explicit Index(absl::Span<llvm::Value* const> multidim,
                   llvm::Type* index_ty = nullptr)
        : multidim_(multidim.begin(), multidim.end()) {
      if (size() == 0) {
        index_type_ = index_ty;
      } else {
        for (const auto* dim : multidim) {
          CHECK_NE(dim, nullptr);
        }
        index_type_ = multidim[0]->getType();
        if (index_ty != nullptr) {
          CHECK_EQ(index_type_, index_ty);
        }
      }
      CHECK_NE(index_type_, nullptr);
      CHECK(absl::c_all_of(multidim, [&](llvm::Value* v) {
        return index_type_ == v->getType();
      }));
    }
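
    // A minimal usage sketch (illustrative only; `b` names a hypothetical
    // llvm::IRBuilder<> available to the caller). It builds the {1, 2, 3}
    // index into the [4 x 3 x 2] array described above:
    //
    //   IrArray::Index index({b.getInt64(1), b.getInt64(2), b.getInt64(3)});
    //
    // All index values must share one integer type, which becomes the index
    // type of the resulting Index.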
"b" is the IR 87 // builder to emit the index of each dimension in the multi-dimensional 88 // index. 89 // 90 // Precondition: "shape" has a layout. 91 Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b); 92 93 // Constructs an index from a multi-dimensional index. 'shape' is the shape 94 // for which the multi-dimensional index is used. 'index_type' is the type 95 // of the index. 96 // 97 // Precondition: "shape" has a layout. 98 Index(absl::Span<llvm::Value* const> multidim, const Shape& shape, 99 llvm::Type* index_type); 100 101 // Returns an index that adds `addend` to the given `dim` of the object. AddOffsetToDim(llvm::Value * addend,int64 dim,llvm::IRBuilder<> * b)102 Index AddOffsetToDim(llvm::Value* addend, int64 dim, 103 llvm::IRBuilder<>* b) const { 104 std::vector<llvm::Value*> multi_index = multidim(); 105 multi_index[dim] = b->CreateAdd(multi_index[dim], addend); 106 return Index(multi_index, index_type_); 107 } 108 multidim()109 const std::vector<llvm::Value*>& multidim() const { return multidim_; } linear()110 llvm::Value* linear() const { return linear_; } 111 size()112 size_t size() const { return multidim().size(); } 113 114 llvm::Value* operator[](size_t i) const { return multidim()[i]; } 115 116 using const_iterator = std::vector<llvm::Value*>::const_iterator; 117 begin()118 const_iterator begin() const { return multidim().begin(); } end()119 const_iterator end() const { return multidim().end(); } 120 121 bool LinearValidOnShape(const Shape& a) const; 122 123 // Given that "this" is the target index of a reshape from `operand_shape` 124 // to `shape`, returns the source index. 125 Index SourceIndexOfReshape(const Shape& output_shape, 126 const Shape& input_shape, 127 llvm::IRBuilder<>* builder) const; 128 129 // Returns the index into the source operand from which a slice operation 130 // selects a value to be placed into index "this". The slice is described 131 // by starting indices `starts` and stride values `strides`. 132 // 133 // Precondition: "this" is an index into a slice whose operand shape is 134 // `operand_shape`. 135 Index SourceIndexOfSlice(const Shape& operand_shape, 136 absl::Span<const int64> starts, 137 absl::Span<const int64> strides, 138 llvm::IRBuilder<>* builder) const; 139 140 // Given that "this" is the target index of a transpose from `operand_shape` 141 // to `shape` with the given dimension mapping, returns the source index. 142 Index SourceIndexOfTranspose(const Shape& shape, const Shape& operand_shape, 143 absl::Span<const int64> dimension_mapping, 144 llvm::IRBuilder<>* builder) const; 145 146 // Given that "this" is the target index of a bitcast from `operand_shape` 147 // to `shape`, returns the source index. 148 Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape, 149 llvm::IRBuilder<>* builder) const; 150 151 // Given that "this" is the target index of a broadcast from `operand_shape` 152 // to `shape` with the given dimension mapping, returns the source index. 153 Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape, 154 absl::Span<const int64> dimension_mapping, 155 llvm::IRBuilder<>* builder) const; 156 157 // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and 158 // returns the index into the sole dimension 0 of the new shape. 

    // Linearizes the index into the given shape, i.e. reshapes it to rank-1
    // and returns the index into the sole dimension 0 of the new shape.
    llvm::Value* Linearize(absl::Span<const int64> dimensions,
                           llvm::IRBuilder<>* builder) const;

    llvm::Type* GetType() const { return index_type_; }

    llvm::Constant* GetConstantWithIndexType(int64 c) const {
      // The LLVM function makes sure that the value can be represented by the
      // specified type; see ConstantInt::ConstantInt(IntegerType *Ty, const
      // APInt &V).
      return llvm::ConstantInt::get(index_type_, c);
    }

   private:
    // Constructs an index from both a multi-dimensional index and a linear
    // index. 'shape' is the shape on which the index is used. 'index_type' is
    // the type of the index.
    //
    // Precondition: "shape" has a layout.
    Index(absl::Span<llvm::Value* const> multidim, llvm::Value* linear,
          const Shape& shape, llvm::Type* index_type);

    void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
                     const Shape& shape, llvm::IRBuilder<>* b) const;

    std::vector<llvm::Value*> multidim_;

    // These values exist purely for efficiency; `multidim_` is enough to find
    // the element at a given `Index`. If a loop is emitted with a linear
    // index space, that linear index can be saved in `linear_`, and the
    // layout and dimensions of the shape the loop was emitted for saved in
    // `layout_` and `dims_`. If the `Index` is then used with another array
    // whose layout and dimensions match, the linear index can be reused,
    // sparing the cost of computing `multidim_`, which LLVM DCE can then
    // delete. Modifying `multidim_` after construction nullifies `linear_`,
    // lest it be used incorrectly, as it would no longer be valid.
    //
    // If a loop is emitted with a multidimensional index space, `linear_` is
    // null and `layout_` and `dims_` are ignored.
    llvm::Value* linear_ = nullptr;
    Layout layout_;
    std::vector<int64> dims_;

    llvm::Type* index_type_;
  };

  // Default constructor. Constructs an IrArray in a null status.
  IrArray() : base_ptr_(nullptr) {}

  // Constructs an IrArray with the given base pointer and shape. base_ptr is
  // a pointer type pointing to the first element (lowest address) of the
  // array.
  IrArray(llvm::Value* base_ptr, Shape shape);

  // Default implementations of copying and moving.
  IrArray(IrArray&& other) = default;
  IrArray(const IrArray& other) = default;
  IrArray& operator=(IrArray&& other) = default;
  IrArray& operator=(const IrArray& other) = default;

  llvm::Value* GetBasePointer() const { return base_ptr_; }
  llvm::Type* GetElementLlvmType() const { return element_type_; }

  const Shape& GetShape() const { return shape_; }

  // Emits a sequence of instructions to compute the address of the element in
  // the given array at the given index. Returns the address of the element as
  // an LLVM Value.
  //
  // The optional name is useful for debugging when looking at the emitted
  // LLVM IR.
  // 'use_linear_index' specifies whether the linear index (if available) or
  // the multi-dimensional index should be used.
  llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b,
                                       absl::string_view name = "",
                                       bool use_linear_index = true) const;

  // Attaches the metadata this IrArray instance knows about to "instruction".
  void AnnotateLoadStoreInstructionWithMetadata(
      llvm::Instruction* instruction) const;
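
  // A minimal sketch of constructing an IrArray and computing an element
  // address (`base_ptr`, `shape`, `index`, and `b` are hypothetical values
  // produced by the surrounding emitter):
  //
  //   IrArray array(base_ptr, shape);
  //   llvm::Value* address = array.EmitArrayElementAddress(index, b, "elem");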

  // Emits IR to read an array element at the given index. Returns the read
  // result (effectively, a Value loaded from memory). This method seamlessly
  // handles scalar shapes by broadcasting their value to all indices (the
  // index is ignored).
  //
  // The optional name is useful for debugging when looking at the emitted
  // LLVM IR.
  // 'use_linear_index' specifies whether the linear index (if available) or
  // the multi-dimensional index should be used.
  llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b,
                                    absl::string_view name = "",
                                    bool use_linear_index = true) const;

  // Emits IR to write the given value to the array element at the given
  // index.
  // 'use_linear_index' specifies whether the linear index (if available) or
  // the multi-dimensional index should be used.
  void EmitWriteArrayElement(const Index& index, llvm::Value* value,
                             llvm::IRBuilder<>* b,
                             bool use_linear_index = true) const;

  // Returns a new IrArray whose shape is "new_shape" and base pointer is a
  // bitcast of the base pointer of "this" IrArray.
  IrArray CastToShape(const Shape& new_shape, llvm::IRBuilder<>* b) const;

  void AddAliasScopeMetadata(llvm::MDNode* alias_scope) {
    CHECK_NE(alias_scope, nullptr);
    AddMetadata(llvm::LLVMContext::MD_alias_scope, alias_scope);
  }

  void AddNoaliasMetadata(llvm::MDNode* noalias) {
    CHECK_NE(noalias, nullptr);
    AddMetadata(llvm::LLVMContext::MD_noalias, noalias);
  }

  // Promises LLVM that the data pointed to by this IrArray never changes
  // after it's first loaded.
  //
  // The temporal scope of this promise is the "whole program" from LLVM's
  // point of view, but how this translates to HLOs differs between backends.
  //
  // In the single-threaded CPU backend, we emit one function that runs all
  // the HLOs in sequence, so the whole program is the whole HLO module.
  //
  // In the GPU backend, we emit one GPU kernel per top-level HLO (i.e. per
  // HLO in the entry computation). From LLVM's perspective, launching a new
  // kernel is like launching a new program, so the whole program is one
  // top-level HLO. Since the scope of the promise is smaller than in the CPU
  // backend, we can mark more things as invariant in the GPU backend.
  //
  // Marking loads as invariant is particularly helpful on GPUs because
  // invariant loads can be lowered to PTX ld.global.nc (equivalent to CUDA's
  // __ldg intrinsic). These loads use a special cache, and can be
  // significantly faster than regular loads.
  void MarkInvariantOverWholeProgram(llvm::LLVMContext* context) {
    if (is_invariant_) {
      return;
    }
    is_invariant_ = true;
    AddMetadata(llvm::LLVMContext::MD_invariant_load,
                llvm::MDNode::get(*context, {}));
  }

  const std::map<int, llvm::MDNode*>& metadata() const { return metadata_; }
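
  // A minimal read-modify-write sketch (illustrative only; assumes `source`,
  // `destination`, `index`, and `b` are provided by the caller and the
  // element type is floating point):
  //
  //   llvm::Value* value = source.EmitReadArrayElement(index, b);
  //   llvm::Value* doubled = b->CreateFAdd(value, value);
  //   destination.EmitWriteArrayElement(index, doubled, b);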

 private:
  // Adds the specified LLVM IR metadata to loads/stores associated with this
  // IrArray.
  void AddMetadata(int kind, llvm::MDNode* md) {
    InsertOrDie(&metadata_, kind, md);
  }

  // Address of the base of the array as an LLVM Value.
  llvm::Value* base_ptr_;

  // The LLVM type of the elements in the array.
  llvm::Type* element_type_;

  // Shape of the XLA array.
  Shape shape_;

  // The list of key/value pairs used when attaching metadata to emitted
  // loads/stores for this array. The keys are the metadata kinds and the
  // values are the metadata nodes.
  std::map<int, llvm::MDNode*> metadata_;

  bool is_invariant_ = false;
};

}  // namespace llvm_ir
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_