1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_ 18 19 #include <map> 20 #include <vector> 21 22 #include "absl/algorithm/container.h" 23 #include "absl/strings/string_view.h" 24 #include "absl/types/span.h" 25 #include "llvm/IR/IRBuilder.h" 26 #include "llvm/IR/Value.h" 27 #include "tensorflow/compiler/xla/map_util.h" 28 #include "tensorflow/compiler/xla/shape.h" 29 #include "tensorflow/compiler/xla/shape_util.h" 30 #include "tensorflow/compiler/xla/types.h" 31 #include "tensorflow/compiler/xla/xla_data.pb.h" 32 #include "tensorflow/core/platform/logging.h" 33 #include "tensorflow/core/platform/types.h" 34 35 namespace xla { 36 namespace llvm_ir { 37 38 // IrArray represents an XLA array at the LLVM IR level. This class 39 // encapsulates a base pointer to the buffer holding the array (as an LLVM 40 // Value) and the shape of the array. The class includes methods for emitting 41 // LLVM IR sequences which access elements of the array at a multidimensional 42 // index (eg, [x, y, z] in a 3-dimensional array). Arbitrary shape and layouts 43 // are supported. 44 class IrArray { 45 public: 46 // A multidimensional index into an IrArray. All the runtime indices 47 // (multidim) and dimensions (Shape::dimensions(), absl::Span<const int64>) 48 // are major-first. 49 // 50 // This may also keep a linear index and the layout and dimensions it was 51 // emitted for; if the shape where this `Index` is used matches, the linear 52 // index may be used, potentially sparing the cost of computing the 53 // multidimensional index, which LLVM DCE can delete. 54 class Index { 55 public: 56 // Constructs an index for a scalar shape. Index(llvm::Type * index_ty)57 explicit Index(llvm::Type* index_ty) : index_type_(index_ty) { 58 CHECK(index_ty->isIntegerTy()); 59 } 60 61 // Constructs an index from linear index "linear" and computes the 62 // multi-dimensional index from "linear" and "shape". "b" is the IR 63 // builder to emit the index of each dimension in the multi-dimensional 64 // index. 65 // 66 // Precondition: "shape" has a layout. 67 Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b); 68 69 // As before, but also take a multidim to reuse. multidim.size() 70 // == shape.rank() must be true. If some of the multidim element 71 // are null we will use the value that would be used if 72 // deliearized from linear. 73 Index(llvm::Value* linear, absl::Span<llvm::Value* const> multidim, 74 const Shape& shape, llvm::IRBuilder<>* b); 75 76 // Similar to the above constructor except using "dynamic_dims" instead of 77 // shape's static dimension to constructs the index. 78 Index(llvm::Value* linear, const Shape& shape, 79 absl::Span<llvm::Value*> dynamic_dims, llvm::IRBuilder<>* b); 80 81 // Constructs an index from a multi-dimensional index. 'shape' is the shape 82 // for which the multi-dimensional index is used. 'index_type' is the type 83 // of the index. 84 // 85 // Precondition: "shape" has a layout. 86 Index(absl::Span<llvm::Value* const> multidim, const Shape& shape, 87 llvm::Type* index_type); 88 89 // Same as above, but only the dimensions of the shape without layout is 90 // passed. The layout is assumed to be the default (descending 91 // minor-to-major) layout. 92 Index(absl::Span<llvm::Value* const> multidim, 93 absl::Span<int64 const> dimensions, llvm::Type* index_type); 94 95 // Returns an index that adds `addend` to the given `dim` of the object. AddOffsetToDim(llvm::Value * addend,int64_t dim,llvm::IRBuilder<> * b)96 Index AddOffsetToDim(llvm::Value* addend, int64_t dim, 97 llvm::IRBuilder<>* b) const { 98 Index with_offset = *this; 99 with_offset.linear_ = nullptr; 100 with_offset.multidim_[dim] = 101 b->CreateAdd(with_offset.multidim_[dim], addend); 102 return with_offset; 103 } 104 multidim()105 const std::vector<llvm::Value*>& multidim() const { return multidim_; } dims()106 const std::vector<int64>& dims() const { return dims_; } linear()107 llvm::Value* linear() const { return linear_; } 108 size()109 size_t size() const { return multidim().size(); } 110 111 llvm::Value* operator[](size_t i) const { return multidim()[i]; } 112 113 using const_iterator = std::vector<llvm::Value*>::const_iterator; 114 begin()115 const_iterator begin() const { return multidim().begin(); } end()116 const_iterator end() const { return multidim().end(); } 117 118 bool LinearValidOnShape(const Shape& a) const; 119 120 static bool ShapeIsCompatible(const Shape& a, const Shape& b); 121 ShapeIsCompatible(const Shape & a)122 bool ShapeIsCompatible(const Shape& a) const { 123 return ShapeIsCompatible( 124 a, ShapeUtil::MakeShapeWithLayout(a.element_type(), dims_, 125 layout_.minor_to_major())); 126 } 127 128 // Given that "this" is the target index of a reshape from `input_shape` 129 // to `output_shape`, returns the source index. 130 Index SourceIndexOfReshape(const Shape& output_shape, 131 const Shape& input_shape, 132 llvm::IRBuilder<>* builder) const; 133 134 // Returns the index into the source operand from which a slice operation 135 // selects a value to be placed into index "this". The slice is described 136 // by starting indices `starts` and stride values `strides`. 137 // 138 // Precondition: "this" is an index into a slice whose operand shape is 139 // `operand_shape`. 140 Index SourceIndexOfSlice(const Shape& operand_shape, 141 absl::Span<const int64> starts, 142 absl::Span<const int64> strides, 143 llvm::IRBuilder<>* builder) const; 144 145 // Given that "this" is the target index of a transpose from `operand_shape` 146 // to `shape` with the given dimension mapping, returns the source index. 147 Index SourceIndexOfTranspose( 148 const Shape& shape, const Shape& operand_shape, 149 absl::Span<const int64> dimension_mapping) const; 150 151 // Given that "this" is the target index of a bitcast from `operand_shape` 152 // to `shape`, returns the source index. 153 Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape, 154 llvm::IRBuilder<>* builder) const; 155 156 // Given that "this" is the target index of a broadcast from `operand_shape` 157 // to `shape` with the given dimension mapping, returns the source index. 158 Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape, 159 absl::Span<const int64> dimension_mapping, 160 llvm::IRBuilder<>* builder) const; 161 162 // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and 163 // returns the index into the sole dimension 0 of the new shape. 164 llvm::Value* Linearize(absl::Span<const int64> dimensions, 165 llvm::IRBuilder<>* builder) const; 166 167 // Linearizes the index into the given dynamic dimensions. 168 llvm::Value* Linearize(const std::vector<llvm::Value*>& dynamic_dims, 169 llvm::IRBuilder<>* builder) const; 170 GetType()171 llvm::Type* GetType() const { return index_type_; } 172 GetConstantWithIndexType(int64_t c)173 llvm::Constant* GetConstantWithIndexType(int64_t c) const { 174 // The LLVM function makes sure that the value can be represented by the 175 // specified type, see ConstantInt::ConstantInt(IntegerType *Ty, const 176 // APInt &V). 177 return llvm::ConstantInt::get(index_type_, c); 178 } 179 180 private: 181 // Constructs an index from both a multi-dimensional index and a linear 182 // index. 'shape' is the shape on which the index is used. 'index_type' is 183 // the type of the index. 184 // 185 // Precondition: "shape" has a layout. 186 Index(absl::Span<llvm::Value* const> multidim, llvm::Value* linear, 187 const Shape& shape, llvm::Type* index_type); 188 189 void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear, 190 const Shape& shape, llvm::IRBuilder<>* b) const; 191 192 // Delinearize the linear index with the dynamic dimensions. 193 void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear, 194 const Shape& shape, absl::Span<llvm::Value*> dynamic_dims, 195 llvm::IRBuilder<>* b) const; 196 197 std::vector<llvm::Value*> multidim_; 198 199 // These values are purely for efficiency; `multidim_` is enough to find the 200 // element at a given `Index`, but if a loop is emitted with a linear index 201 // space, that linear index can be saved in `linear_`, and the layout and 202 // dimensions of the shape the loop was emitted for in `layout_` and 203 // `dims_`, and if the `Index` is used in another array, and its layout and 204 // dimensions match, the linear index can be used, sparing the cost of 205 // computing `multidim_`, which LLVM DCE could potentially so delete. 206 // Modifying `multidim_` after construction nullifies `linear_`, lest it 207 // be used wrongly, as it would be valid no more. 208 // If a loop is emitted with a multidimensional index space, `linear_` would 209 // be null and `layout_` and `dims_` would be ignored. 210 llvm::Value* linear_ = nullptr; 211 Layout layout_; 212 std::vector<int64> dims_; 213 214 llvm::Type* index_type_; 215 }; 216 217 // Default constructor. Constructs an IrArray in a null status. IrArray()218 IrArray() : base_ptr_(nullptr) {} 219 220 // Construct an IrArray with the given base pointer and shape. base_ptr is a 221 // pointer type pointing to the first element(lowest address) of the array. 222 IrArray(llvm::Value* base_ptr, Shape shape); 223 224 // Default implementations of copying and moving. 225 IrArray(IrArray&& other) = default; 226 IrArray(const IrArray& other) = default; 227 IrArray& operator=(IrArray&& other) = default; 228 IrArray& operator=(const IrArray& other) = default; 229 GetBasePointer()230 llvm::Value* GetBasePointer() const { return base_ptr_; } GetElementLlvmType()231 llvm::Type* GetElementLlvmType() const { return element_type_; } 232 GetShape()233 const Shape& GetShape() const { return shape_; } 234 235 // Emit a sequence of instructions to compute the address of the element in 236 // the given array at the given index. Returns the address of the element as 237 // an LLVM Value. 238 // 239 // The optional name is useful for debugging when looking at 240 // the emitted LLVM IR. 241 llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b, 242 absl::string_view name = "", 243 bool use_linear_index = true) const; 244 245 // Attach metadata this IrArray instance knows about to "instruction". 246 void AnnotateLoadStoreInstructionWithMetadata( 247 llvm::Instruction* instruction) const; 248 249 // Emit IR to read an array element at the given index. Returns the read 250 // result (effectively, a Value loaded from memory). This method seamlessly 251 // handles scalar shapes by broadcasting their value to all indices (index is 252 // ignored). 253 // 254 // The optional name is useful for debugging when looking at 255 // the emitted LLVM IR. 256 // 'use_linear_index' can be used to specify whether the linear index (if 257 // available) or the multi-dimensional index should be used. 258 llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, 259 absl::string_view name = "", 260 bool use_linear_index = true) const; 261 262 // Emit IR to write the given value to the array element at the given index. 263 // 'use_linear_index' can be used to specify whether the linear index (if 264 // available) or the multi-dimensional index should be used. 265 void EmitWriteArrayElement(const Index& index, llvm::Value* value, 266 llvm::IRBuilder<>* b, 267 bool use_linear_index = true) const; 268 269 // Returns a new IrArray whose shape is "new_shape" and base pointer is a 270 // bitcast of the base pointer of "this" IrArray. 271 // 'use_linear_index' can be used to specify whether the linear index (if 272 // available) or the multi-dimensional index should be used. 273 IrArray CastToShape(const Shape& new_shape, llvm::IRBuilder<>* b) const; 274 AddAliasScopeMetadata(llvm::MDNode * alias_scope)275 void AddAliasScopeMetadata(llvm::MDNode* alias_scope) { 276 CHECK_NE(alias_scope, nullptr); 277 AddMetadata(llvm::LLVMContext::MD_alias_scope, alias_scope); 278 } 279 AddNoaliasMetadata(llvm::MDNode * noalias)280 void AddNoaliasMetadata(llvm::MDNode* noalias) { 281 CHECK_NE(noalias, nullptr); 282 AddMetadata(llvm::LLVMContext::MD_noalias, noalias); 283 } 284 285 // Promises LLVM that the data pointed to by this IrArray never changes after 286 // it's first loaded. 287 // 288 // The temporal scope of this promise is the "whole program" from LLVM's point 289 // of view, but how this translates to HLOs differs between backends. 290 // 291 // In the single-threaded CPU backend, we emit one function that 292 // runs all the HLOs in sequence, so the whole program is the whole HLO 293 // module. 294 // 295 // In the GPU backend, we emit one GPU kernel per top-level HLO (i.e. per HLO 296 // in the entry computation). From LLVM's perspective, launching a new kernel 297 // is like launching a new program, and so the whole program is one top-level 298 // HLO. Since the scope of the promise is smaller than in the CPU backend, we 299 // can mark more things as invariant in the GPU backend. 300 // 301 // Marking loads as invariant is particularly helpful on GPUs because 302 // invariant loads can be lowered to PTX ld.global.nc (equivalent to CUDA's 303 // __ldg intrinsic). These loads use a special cache, and can be 304 // significantly faster than regular loads. MarkInvariantOverWholeProgram(llvm::LLVMContext * context)305 void MarkInvariantOverWholeProgram(llvm::LLVMContext* context) { 306 if (is_invariant_) { 307 return; 308 } 309 is_invariant_ = true; 310 AddMetadata(llvm::LLVMContext::MD_invariant_load, 311 llvm::MDNode::get(*context, {})); 312 } 313 metadata()314 const std::map<int, llvm::MDNode*>& metadata() const { return metadata_; } 315 316 private: 317 // Add the specified LLVM IR metadata to loads/stores associated with this 318 // IrArray. AddMetadata(int kind,llvm::MDNode * md)319 void AddMetadata(int kind, llvm::MDNode* md) { 320 InsertOrDie(&metadata_, kind, md); 321 } 322 323 // Address of the base of the array as an LLVM Value. 324 llvm::Value* base_ptr_; 325 326 // The LLVM type of the elements in the array. 327 llvm::Type* element_type_; 328 329 // Shape of the XLA array. 330 Shape shape_; 331 332 // The list of key/value pairs used when attaching metadata to emitted 333 // loads/stores for this array. They keys are the metadata kinds and the 334 // values are the metadata nodes. 335 std::map<int, llvm::MDNode*> metadata_; 336 337 bool is_invariant_ = false; 338 }; 339 340 } // namespace llvm_ir 341 } // namespace xla 342 343 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_ 344