1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_ 18 19 #include <map> 20 #include <vector> 21 22 #include "absl/algorithm/container.h" 23 #include "absl/strings/string_view.h" 24 #include "absl/types/span.h" 25 #include "llvm/IR/IRBuilder.h" 26 #include "llvm/IR/Value.h" 27 #include "tensorflow/compiler/xla/map_util.h" 28 #include "tensorflow/compiler/xla/shape.h" 29 #include "tensorflow/compiler/xla/shape_util.h" 30 #include "tensorflow/compiler/xla/types.h" 31 #include "tensorflow/compiler/xla/xla_data.pb.h" 32 #include "tensorflow/core/platform/logging.h" 33 #include "tensorflow/core/platform/types.h" 34 35 namespace xla { 36 namespace llvm_ir { 37 38 // IrArray represents an XLA array at the LLVM IR level. This class 39 // encapsulates a base pointer to the buffer holding the array (as an LLVM 40 // Value) and the shape of the array. The class includes methods for emitting 41 // LLVM IR sequences which access elements of the array at a multidimensional 42 // index (eg, [x, y, z] in a 3-dimensional array). Arbitrary shape and layouts 43 // are supported. 44 class IrArray { 45 public: 46 // A multidimensional index into an IrArray. All the runtime indices 47 // (multidim) and dimensions (Shape::dimensions(), absl::Span<const int64>) 48 // are major-first. 49 // 50 // This may also keep a linear index and the layout and dimensions it was 51 // emitted for; if the shape where this `Index` is used matches, the linear 52 // index may be used, potentially sparing the cost of computing the 53 // multidimensional index, which LLVM DCE can delete. 54 class Index { 55 public: 56 // Constructs an index for a scalar shape. Index(llvm::Type * index_ty)57 explicit Index(llvm::Type* index_ty) : index_type_(index_ty) { 58 CHECK(index_ty->isIntegerTy()); 59 } 60 61 // Constructs an index from linear index "linear" and computes the 62 // multi-dimensional index from "linear" and "shape". "b" is the IR 63 // builder to emit the index of each dimension in the multi-dimensional 64 // index. 65 // 66 // Precondition: "shape" has a layout. 67 Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b); 68 69 // Similar to the above constructor except using "dynamic_dims" instead of 70 // shape's static dimension to constructs the index. 71 Index(llvm::Value* linear, const Shape& shape, 72 absl::Span<llvm::Value*> dynamic_dims, llvm::IRBuilder<>* b); 73 74 // Constructs an index from a multi-dimensional index. 'shape' is the shape 75 // for which the multi-dimensional index is used. 'index_type' is the type 76 // of the index. 77 // 78 // Precondition: "shape" has a layout. 79 Index(absl::Span<llvm::Value* const> multidim, const Shape& shape, 80 llvm::Type* index_type); 81 82 // Same as above, but only the dimensions of the shape without layout is 83 // passed. The layout is assumed to be the default (descending 84 // minor-to-major) layout. 85 Index(absl::Span<llvm::Value* const> multidim, 86 absl::Span<int64 const> dimensions, llvm::Type* index_type); 87 88 // Returns an index that adds `addend` to the given `dim` of the object. AddOffsetToDim(llvm::Value * addend,int64 dim,llvm::IRBuilder<> * b)89 Index AddOffsetToDim(llvm::Value* addend, int64 dim, 90 llvm::IRBuilder<>* b) const { 91 Index with_offset = *this; 92 with_offset.linear_ = nullptr; 93 with_offset.multidim_[dim] = 94 b->CreateAdd(with_offset.multidim_[dim], addend); 95 return with_offset; 96 } 97 multidim()98 const std::vector<llvm::Value*>& multidim() const { return multidim_; } dims()99 const std::vector<int64>& dims() const { return dims_; } linear()100 llvm::Value* linear() const { return linear_; } 101 size()102 size_t size() const { return multidim().size(); } 103 104 llvm::Value* operator[](size_t i) const { return multidim()[i]; } 105 106 using const_iterator = std::vector<llvm::Value*>::const_iterator; 107 begin()108 const_iterator begin() const { return multidim().begin(); } end()109 const_iterator end() const { return multidim().end(); } 110 111 bool LinearValidOnShape(const Shape& a) const; 112 113 static bool ShapeIsCompatible(const Shape& a, const Shape& b); 114 ShapeIsCompatible(const Shape & a)115 bool ShapeIsCompatible(const Shape& a) const { 116 return ShapeIsCompatible( 117 a, ShapeUtil::MakeShapeWithLayout(a.element_type(), dims_, 118 layout_.minor_to_major())); 119 } 120 121 // Given that "this" is the target index of a reshape from `input_shape` 122 // to `output_shape`, returns the source index. 123 Index SourceIndexOfReshape(const Shape& output_shape, 124 const Shape& input_shape, 125 llvm::IRBuilder<>* builder) const; 126 127 // Returns the index into the source operand from which a slice operation 128 // selects a value to be placed into index "this". The slice is described 129 // by starting indices `starts` and stride values `strides`. 130 // 131 // Precondition: "this" is an index into a slice whose operand shape is 132 // `operand_shape`. 133 Index SourceIndexOfSlice(const Shape& operand_shape, 134 absl::Span<const int64> starts, 135 absl::Span<const int64> strides, 136 llvm::IRBuilder<>* builder) const; 137 138 // Given that "this" is the target index of a transpose from `operand_shape` 139 // to `shape` with the given dimension mapping, returns the source index. 140 Index SourceIndexOfTranspose( 141 const Shape& shape, const Shape& operand_shape, 142 absl::Span<const int64> dimension_mapping) const; 143 144 // Given that "this" is the target index of a bitcast from `operand_shape` 145 // to `shape`, returns the source index. 146 Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape, 147 llvm::IRBuilder<>* builder) const; 148 149 // Given that "this" is the target index of a broadcast from `operand_shape` 150 // to `shape` with the given dimension mapping, returns the source index. 151 Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape, 152 absl::Span<const int64> dimension_mapping, 153 llvm::IRBuilder<>* builder) const; 154 155 // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and 156 // returns the index into the sole dimension 0 of the new shape. 157 llvm::Value* Linearize(absl::Span<const int64> dimensions, 158 llvm::IRBuilder<>* builder) const; 159 160 // Linearizes the index into the given dynamic dimensions. 161 llvm::Value* Linearize(const std::vector<llvm::Value*>& dynamic_dims, 162 llvm::IRBuilder<>* builder) const; 163 GetType()164 llvm::Type* GetType() const { return index_type_; } 165 GetConstantWithIndexType(int64 c)166 llvm::Constant* GetConstantWithIndexType(int64 c) const { 167 // The LLVM function makes sure that the value can be represented by the 168 // specified type, see ConstantInt::ConstantInt(IntegerType *Ty, const 169 // APInt &V). 170 return llvm::ConstantInt::get(index_type_, c); 171 } 172 173 private: 174 // Constructs an index from both a multi-dimensional index and a linear 175 // index. 'shape' is the shape on which the index is used. 'index_type' is 176 // the type of the index. 177 // 178 // Precondition: "shape" has a layout. 179 Index(absl::Span<llvm::Value* const> multidim, llvm::Value* linear, 180 const Shape& shape, llvm::Type* index_type); 181 182 void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear, 183 const Shape& shape, llvm::IRBuilder<>* b) const; 184 185 // Delinearize the linear index with the dynamic dimensions. 186 void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear, 187 const Shape& shape, absl::Span<llvm::Value*> dynamic_dims, 188 llvm::IRBuilder<>* b) const; 189 190 std::vector<llvm::Value*> multidim_; 191 192 // These values are purely for efficiency; `multidim_` is enough to find the 193 // element at a given `Index`, but if a loop is emitted with a linear index 194 // space, that linear index can be saved in `linear_`, and the layout and 195 // dimensions of the shape the loop was emitted for in `layout_` and 196 // `dims_`, and if the `Index` is used in another array, and its layout and 197 // dimensions match, the linear index can be used, sparing the cost of 198 // computing `multidim_`, which LLVM DCE could potentially so delete. 199 // Modifying `multidim_` after construction nullifies `linear_`, lest it 200 // be used wrongly, as it would be valid no more. 201 // If a loop is emitted with a multidimensional index space, `linear_` would 202 // be null and `layout_` and `dims_` would be ignored. 203 llvm::Value* linear_ = nullptr; 204 Layout layout_; 205 std::vector<int64> dims_; 206 207 llvm::Type* index_type_; 208 }; 209 210 // Default constructor. Constructs an IrArray in a null status. IrArray()211 IrArray() : base_ptr_(nullptr) {} 212 213 // Construct an IrArray with the given base pointer and shape. base_ptr is a 214 // pointer type pointing to the first element(lowest address) of the array. 215 IrArray(llvm::Value* base_ptr, Shape shape); 216 217 // Default implementations of copying and moving. 218 IrArray(IrArray&& other) = default; 219 IrArray(const IrArray& other) = default; 220 IrArray& operator=(IrArray&& other) = default; 221 IrArray& operator=(const IrArray& other) = default; 222 GetBasePointer()223 llvm::Value* GetBasePointer() const { return base_ptr_; } GetElementLlvmType()224 llvm::Type* GetElementLlvmType() const { return element_type_; } 225 GetShape()226 const Shape& GetShape() const { return shape_; } 227 228 // Emit a sequence of instructions to compute the address of the element in 229 // the given array at the given index. Returns the address of the element as 230 // an LLVM Value. 231 // 232 // The optional name is useful for debugging when looking at 233 // the emitted LLVM IR. 234 llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b, 235 absl::string_view name = "", 236 bool use_linear_index = true) const; 237 238 // Attach metadata this IrArray instance knows about to "instruction". 239 void AnnotateLoadStoreInstructionWithMetadata( 240 llvm::Instruction* instruction) const; 241 242 // Emit IR to read an array element at the given index. Returns the read 243 // result (effectively, a Value loaded from memory). This method seamlessly 244 // handles scalar shapes by broadcasting their value to all indices (index is 245 // ignored). 246 // 247 // The optional name is useful for debugging when looking at 248 // the emitted LLVM IR. 249 // 'use_linear_index' can be used to specify whether the linear index (if 250 // available) or the multi-dimensional index should be used. 251 llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, 252 absl::string_view name = "", 253 bool use_linear_index = true) const; 254 255 // Emit IR to write the given value to the array element at the given index. 256 // 'use_linear_index' can be used to specify whether the linear index (if 257 // available) or the multi-dimensional index should be used. 258 void EmitWriteArrayElement(const Index& index, llvm::Value* value, 259 llvm::IRBuilder<>* b, 260 bool use_linear_index = true) const; 261 262 // Returns a new IrArray whose shape is "new_shape" and base pointer is a 263 // bitcast of the base pointer of "this" IrArray. 264 // 'use_linear_index' can be used to specify whether the linear index (if 265 // available) or the multi-dimensional index should be used. 266 IrArray CastToShape(const Shape& new_shape, llvm::IRBuilder<>* b) const; 267 AddAliasScopeMetadata(llvm::MDNode * alias_scope)268 void AddAliasScopeMetadata(llvm::MDNode* alias_scope) { 269 CHECK_NE(alias_scope, nullptr); 270 AddMetadata(llvm::LLVMContext::MD_alias_scope, alias_scope); 271 } 272 AddNoaliasMetadata(llvm::MDNode * noalias)273 void AddNoaliasMetadata(llvm::MDNode* noalias) { 274 CHECK_NE(noalias, nullptr); 275 AddMetadata(llvm::LLVMContext::MD_noalias, noalias); 276 } 277 278 // Promises LLVM that the data pointed to by this IrArray never changes after 279 // it's first loaded. 280 // 281 // The temporal scope of this promise is the "whole program" from LLVM's point 282 // of view, but how this translates to HLOs differs between backends. 283 // 284 // In the single-threaded CPU backend, we emit one function that 285 // runs all the HLOs in sequence, so the whole program is the whole HLO 286 // module. 287 // 288 // In the GPU backend, we emit one GPU kernel per top-level HLO (i.e. per HLO 289 // in the entry computation). From LLVM's perspective, launching a new kernel 290 // is like launching a new program, and so the whole program is one top-level 291 // HLO. Since the scope of the promise is smaller than in the CPU backend, we 292 // can mark more things as invariant in the GPU backend. 293 // 294 // Marking loads as invariant is particularly helpful on GPUs because 295 // invariant loads can be lowered to PTX ld.global.nc (equivalent to CUDA's 296 // __ldg intrinsic). These loads use a special cache, and can be 297 // significantly faster than regular loads. MarkInvariantOverWholeProgram(llvm::LLVMContext * context)298 void MarkInvariantOverWholeProgram(llvm::LLVMContext* context) { 299 if (is_invariant_) { 300 return; 301 } 302 is_invariant_ = true; 303 AddMetadata(llvm::LLVMContext::MD_invariant_load, 304 llvm::MDNode::get(*context, {})); 305 } 306 metadata()307 const std::map<int, llvm::MDNode*>& metadata() const { return metadata_; } 308 309 private: 310 // Add the specified LLVM IR metadata to loads/stores associated with this 311 // IrArray. AddMetadata(int kind,llvm::MDNode * md)312 void AddMetadata(int kind, llvm::MDNode* md) { 313 InsertOrDie(&metadata_, kind, md); 314 } 315 316 // Address of the base of the array as an LLVM Value. 317 llvm::Value* base_ptr_; 318 319 // The LLVM type of the elements in the array. 320 llvm::Type* element_type_; 321 322 // Shape of the XLA array. 323 Shape shape_; 324 325 // The list of key/value pairs used when attaching metadata to emitted 326 // loads/stores for this array. They keys are the metadata kinds and the 327 // values are the metadata nodes. 328 std::map<int, llvm::MDNode*> metadata_; 329 330 bool is_invariant_ = false; 331 }; 332 333 } // namespace llvm_ir 334 } // namespace xla 335 336 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_ 337