1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_ 18 19 #include <map> 20 #include <vector> 21 22 #include "absl/algorithm/container.h" 23 #include "absl/strings/string_view.h" 24 #include "absl/types/span.h" 25 #include "llvm/IR/IRBuilder.h" 26 #include "llvm/IR/Value.h" 27 #include "tensorflow/compiler/xla/map_util.h" 28 #include "tensorflow/compiler/xla/shape.h" 29 #include "tensorflow/compiler/xla/shape_util.h" 30 #include "tensorflow/compiler/xla/types.h" 31 #include "tensorflow/compiler/xla/xla_data.pb.h" 32 #include "tensorflow/core/platform/logging.h" 33 34 namespace xla { 35 namespace llvm_ir { 36 37 // IrArray represents an XLA array at the LLVM IR level. This class 38 // encapsulates a base pointer to the buffer holding the array (as an LLVM 39 // Value) and the shape of the array. The class includes methods for emitting 40 // LLVM IR sequences which access elements of the array at a multidimensional 41 // index (eg, [x, y, z] in a 3-dimensional array). Arbitrary shape and layouts 42 // are supported. 43 class IrArray { 44 public: 45 // A multidimensional index into an IrArray. All the runtime indices 46 // (multidim) and dimensions (Shape::dimensions(), absl::Span<const int64_t>) 47 // are major-first. 48 // 49 // This may also keep a linear index and the layout and dimensions it was 50 // emitted for; if the shape where this `Index` is used matches, the linear 51 // index may be used, potentially sparing the cost of computing the 52 // multidimensional index, which LLVM DCE can delete. 53 class Index { 54 public: 55 // Constructs an index for a scalar shape. Index(llvm::Type * index_ty)56 explicit Index(llvm::Type* index_ty) : index_type_(index_ty) { 57 CHECK(index_ty->isIntegerTy()); 58 } 59 60 // Constructs an index from linear index "linear" and computes the 61 // multi-dimensional index from "linear" and "shape". "b" is the IR 62 // builder to emit the index of each dimension in the multi-dimensional 63 // index. 64 // 65 // Precondition: "shape" has a layout. 66 Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b); 67 68 // As before, but also take a multidim to reuse. multidim.size() 69 // == shape.rank() must be true. If some of the multidim element 70 // are null we will use the value that would be used if 71 // deliearized from linear. 72 Index(llvm::Value* linear, absl::Span<llvm::Value* const> multidim, 73 const Shape& shape, llvm::IRBuilder<>* b); 74 75 // Similar to the above constructor except using "dynamic_dims" instead of 76 // shape's static dimension to constructs the index. 77 Index(llvm::Value* linear, const Shape& shape, 78 absl::Span<llvm::Value*> dynamic_dims, llvm::IRBuilder<>* b); 79 80 // Constructs an index from a multi-dimensional index. 'shape' is the shape 81 // for which the multi-dimensional index is used. 'index_type' is the type 82 // of the index. 83 // 84 // Precondition: "shape" has a layout. 85 Index(absl::Span<llvm::Value* const> multidim, const Shape& shape, 86 llvm::Type* index_type); 87 88 // Same as above, but only the dimensions of the shape without layout is 89 // passed. The layout is assumed to be the default (descending 90 // minor-to-major) layout. 91 Index(absl::Span<llvm::Value* const> multidim, 92 absl::Span<int64_t const> dimensions, llvm::Type* index_type); 93 94 // Returns an index that adds `addend` to the given `dim` of the object. AddOffsetToDim(llvm::Value * addend,int64_t dim,llvm::IRBuilder<> * b)95 Index AddOffsetToDim(llvm::Value* addend, int64_t dim, 96 llvm::IRBuilder<>* b) const { 97 Index with_offset = *this; 98 with_offset.linear_ = nullptr; 99 with_offset.multidim_[dim] = 100 b->CreateAdd(with_offset.multidim_[dim], addend); 101 return with_offset; 102 } 103 multidim()104 const std::vector<llvm::Value*>& multidim() const { return multidim_; } dims()105 const std::vector<int64_t>& dims() const { return dims_; } linear()106 llvm::Value* linear() const { return linear_; } 107 size()108 size_t size() const { return multidim().size(); } 109 110 llvm::Value* operator[](size_t i) const { return multidim()[i]; } 111 112 using const_iterator = std::vector<llvm::Value*>::const_iterator; 113 begin()114 const_iterator begin() const { return multidim().begin(); } end()115 const_iterator end() const { return multidim().end(); } 116 117 bool LinearValidOnShape(const Shape& a) const; 118 119 static bool ShapeIsCompatible(const Shape& a, const Shape& b); 120 ShapeIsCompatible(const Shape & a)121 bool ShapeIsCompatible(const Shape& a) const { 122 return ShapeIsCompatible(a, AsShapeWithType(a.element_type())); 123 } 124 AsShapeWithType(PrimitiveType element_type)125 Shape AsShapeWithType(PrimitiveType element_type) const { 126 return ShapeUtil::MakeShapeWithLayout(element_type, dims_, 127 layout_.minor_to_major()); 128 } 129 130 // Given that "this" is the target index of a reshape from `input_shape` 131 // to `output_shape`, returns the source index. 132 Index SourceIndexOfReshape(const Shape& output_shape, 133 const Shape& input_shape, 134 llvm::IRBuilder<>* builder) const; 135 136 // Returns the index into the source operand from which a slice operation 137 // selects a value to be placed into index "this". The slice is described 138 // by starting indices `starts` and stride values `strides`. 139 // 140 // Precondition: "this" is an index into a slice whose operand shape is 141 // `operand_shape`. 142 Index SourceIndexOfSlice(const Shape& operand_shape, 143 absl::Span<const int64_t> starts, 144 absl::Span<const int64_t> strides, 145 llvm::IRBuilder<>* builder) const; 146 147 // Given that "this" is the target index of a transpose from `operand_shape` 148 // to `shape` with the given dimension mapping, returns the source index. 149 Index SourceIndexOfTranspose( 150 const Shape& shape, const Shape& operand_shape, 151 absl::Span<const int64_t> dimension_mapping) const; 152 153 // Given that "this" is the target index of a bitcast from `operand_shape` 154 // to `shape`, returns the source index. 155 Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape, 156 llvm::IRBuilder<>* builder) const; 157 158 // Given that "this" is the target index of a broadcast from `operand_shape` 159 // to `shape` with the given dimension mapping, returns the source index. 160 Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape, 161 absl::Span<const int64_t> dimension_mapping, 162 llvm::IRBuilder<>* builder) const; 163 164 // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and 165 // returns the index into the sole dimension 0 of the new shape. 166 llvm::Value* Linearize(absl::Span<const int64_t> dimensions, 167 llvm::IRBuilder<>* builder) const; 168 169 // Linearizes the index into the given dynamic dimensions. 170 llvm::Value* Linearize(const std::vector<llvm::Value*>& dynamic_dims, 171 llvm::IRBuilder<>* builder) const; 172 GetType()173 llvm::Type* GetType() const { return index_type_; } 174 GetConstantWithIndexType(int64_t c)175 llvm::Constant* GetConstantWithIndexType(int64_t c) const { 176 // The LLVM function makes sure that the value can be represented by the 177 // specified type, see ConstantInt::ConstantInt(IntegerType *Ty, const 178 // APInt &V). 179 return llvm::ConstantInt::get(index_type_, c); 180 } 181 182 private: 183 // Constructs an index from both a multi-dimensional index and a linear 184 // index. 'shape' is the shape on which the index is used. 'index_type' is 185 // the type of the index. 186 // 187 // Precondition: "shape" has a layout. 188 Index(absl::Span<llvm::Value* const> multidim, llvm::Value* linear, 189 const Shape& shape, llvm::Type* index_type); 190 191 void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear, 192 const Shape& shape, llvm::IRBuilder<>* b) const; 193 194 // Delinearize the linear index with the dynamic dimensions. 195 void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear, 196 const Shape& shape, absl::Span<llvm::Value*> dynamic_dims, 197 llvm::IRBuilder<>* b) const; 198 199 std::vector<llvm::Value*> multidim_; 200 201 // These values are purely for efficiency; `multidim_` is enough to find the 202 // element at a given `Index`, but if a loop is emitted with a linear index 203 // space, that linear index can be saved in `linear_`, and the layout and 204 // dimensions of the shape the loop was emitted for in `layout_` and 205 // `dims_`, and if the `Index` is used in another array, and its layout and 206 // dimensions match, the linear index can be used, sparing the cost of 207 // computing `multidim_`, which LLVM DCE could potentially so delete. 208 // Modifying `multidim_` after construction nullifies `linear_`, lest it 209 // be used wrongly, as it would be valid no more. 210 // If a loop is emitted with a multidimensional index space, `linear_` would 211 // be null and `layout_` and `dims_` would be ignored. 212 llvm::Value* linear_ = nullptr; 213 Layout layout_; 214 std::vector<int64_t> dims_; 215 216 llvm::Type* index_type_; 217 }; 218 219 // Default constructor. Constructs an IrArray in a null status. IrArray()220 IrArray() : base_ptr_(nullptr) {} 221 222 // Construct an IrArray with the given base pointer, pointee type, and shape. 223 // base_ptr is a pointer type pointing to the first element(lowest address) 224 // of the array. 225 IrArray(llvm::Value* base_ptr, llvm::Type* pointee_type, Shape shape); 226 227 // Default implementations of copying and moving. 228 IrArray(IrArray&& other) = default; 229 IrArray(const IrArray& other) = default; 230 IrArray& operator=(IrArray&& other) = default; 231 IrArray& operator=(const IrArray& other) = default; 232 GetBasePointer()233 llvm::Value* GetBasePointer() const { return base_ptr_; } GetBasePointeeType()234 llvm::Type* GetBasePointeeType() const { return pointee_type_; } GetElementLlvmType()235 llvm::Type* GetElementLlvmType() const { return element_type_; } 236 GetShape()237 const Shape& GetShape() const { return shape_; } 238 239 // Emit a sequence of instructions to compute the address of the element in 240 // the given array at the given index. Returns the address of the element as 241 // an LLVM Value. 242 // 243 // The optional name is useful for debugging when looking at 244 // the emitted LLVM IR. 245 llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b, 246 absl::string_view name = "", 247 bool use_linear_index = true) const; 248 249 // Attach metadata this IrArray instance knows about to "instruction". 250 void AnnotateLoadStoreInstructionWithMetadata( 251 llvm::Instruction* instruction) const; 252 253 // Emit IR to read an array element at the given index. Returns the read 254 // result (effectively, a Value loaded from memory). This method seamlessly 255 // handles scalar shapes by broadcasting their value to all indices (index is 256 // ignored). 257 // 258 // The optional name is useful for debugging when looking at 259 // the emitted LLVM IR. 260 // 'use_linear_index' can be used to specify whether the linear index (if 261 // available) or the multi-dimensional index should be used. 262 llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, 263 absl::string_view name = "", 264 bool use_linear_index = true) const; 265 266 // Emit IR to write the given value to the array element at the given index. 267 // 'use_linear_index' can be used to specify whether the linear index (if 268 // available) or the multi-dimensional index should be used. 269 void EmitWriteArrayElement(const Index& index, llvm::Value* value, 270 llvm::IRBuilder<>* b, 271 bool use_linear_index = true) const; 272 273 // Returns a new IrArray whose shape is "new_shape" and base pointer is a 274 // bitcast of the base pointer of "this" IrArray. 275 // 'use_linear_index' can be used to specify whether the linear index (if 276 // available) or the multi-dimensional index should be used. 277 IrArray CastToShape(const Shape& new_shape, llvm::IRBuilder<>* b) const; 278 AddAliasScopeMetadata(llvm::MDNode * alias_scope)279 void AddAliasScopeMetadata(llvm::MDNode* alias_scope) { 280 CHECK_NE(alias_scope, nullptr); 281 AddMetadata(llvm::LLVMContext::MD_alias_scope, alias_scope); 282 } 283 AddNoaliasMetadata(llvm::MDNode * noalias)284 void AddNoaliasMetadata(llvm::MDNode* noalias) { 285 CHECK_NE(noalias, nullptr); 286 AddMetadata(llvm::LLVMContext::MD_noalias, noalias); 287 } 288 289 // Promises LLVM that the data pointed to by this IrArray never changes after 290 // it's first loaded. 291 // 292 // The temporal scope of this promise is the "whole program" from LLVM's point 293 // of view, but how this translates to HLOs differs between backends. 294 // 295 // In the single-threaded CPU backend, we emit one function that 296 // runs all the HLOs in sequence, so the whole program is the whole HLO 297 // module. 298 // 299 // In the GPU backend, we emit one GPU kernel per top-level HLO (i.e. per HLO 300 // in the entry computation). From LLVM's perspective, launching a new kernel 301 // is like launching a new program, and so the whole program is one top-level 302 // HLO. Since the scope of the promise is smaller than in the CPU backend, we 303 // can mark more things as invariant in the GPU backend. 304 // 305 // Marking loads as invariant is particularly helpful on GPUs because 306 // invariant loads can be lowered to PTX ld.global.nc (equivalent to CUDA's 307 // __ldg intrinsic). These loads use a special cache, and can be 308 // significantly faster than regular loads. MarkInvariantOverWholeProgram(llvm::LLVMContext * context)309 void MarkInvariantOverWholeProgram(llvm::LLVMContext* context) { 310 if (is_invariant_) { 311 return; 312 } 313 is_invariant_ = true; 314 AddMetadata(llvm::LLVMContext::MD_invariant_load, 315 llvm::MDNode::get(*context, {})); 316 } 317 metadata()318 const std::map<int, llvm::MDNode*>& metadata() const { return metadata_; } 319 320 private: 321 // Add the specified LLVM IR metadata to loads/stores associated with this 322 // IrArray. AddMetadata(int kind,llvm::MDNode * md)323 void AddMetadata(int kind, llvm::MDNode* md) { 324 InsertOrDie(&metadata_, kind, md); 325 } 326 327 // Address of the base of the array as an LLVM Value. 328 llvm::Value* base_ptr_; 329 330 // The pointee type of base_ptr_; 331 llvm::Type* pointee_type_; 332 333 // The LLVM type of the elements in the array. 334 llvm::Type* element_type_; 335 336 // Shape of the XLA array. 337 Shape shape_; 338 339 // The list of key/value pairs used when attaching metadata to emitted 340 // loads/stores for this array. They keys are the metadata kinds and the 341 // values are the metadata nodes. 342 std::map<int, llvm::MDNode*> metadata_; 343 344 bool is_invariant_ = false; 345 }; 346 347 } // namespace llvm_ir 348 } // namespace xla 349 350 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_ 351