/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_

#include <map>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace llvm_ir {

// IrArray represents an XLA array at the LLVM IR level. This class
// encapsulates a base pointer to the buffer holding the array (as an LLVM
// Value) and the shape of the array. The class includes methods for emitting
// LLVM IR sequences which access elements of the array at a multidimensional
// index (e.g., [x, y, z] in a 3-dimensional array). Arbitrary shapes and
// layouts are supported.
class IrArray {
 public:
  // A multidimensional index into an IrArray. All the runtime indices
  // (multidim) and dimensions (Shape::dimensions(), absl::Span<const int64>)
  // are major-first.
  //
  // This may also keep a linear index and the layout and dimensions it was
  // emitted for; if the shape where this `Index` is used matches, the linear
  // index may be used, potentially sparing the cost of computing the
  // multidimensional index, whose computation LLVM DCE can then delete.
  class Index {
   public:
    // Constructs an index for a scalar shape.
    explicit Index(llvm::Type* index_ty) : index_type_(index_ty) {
      CHECK(index_ty->isIntegerTy());
    }

    // Constructs an index from linear index "linear" and computes the
    // multi-dimensional index from "linear" and "shape". "b" is the IR
    // builder to emit the index of each dimension in the multi-dimensional
    // index.
    //
    // Precondition: "shape" has a layout.
    Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b);

    // Constructs an index from a multi-dimensional index. 'shape' is the shape
    // for which the multi-dimensional index is used. 'index_type' is the type
    // of the index.
    //
    // Precondition: "shape" has a layout.
    Index(absl::Span<llvm::Value* const> multidim, const Shape& shape,
          llvm::Type* index_type);
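
    // Example (illustrative sketch, not part of the API): for a
    // default-layout F32[2,3] shape, an Index for element [1, 2] could be
    // built from either a multi-dimensional or a linear form. The IRBuilder
    // `b` is assumed to exist at the call site.
    //
    //   llvm::Type* i64_ty = b->getInt64Ty();
    //   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
    //   IrArray::Index from_multidim({b->getInt64(1), b->getInt64(2)}, shape,
    //                                i64_ty);
    //   IrArray::Index from_linear(b->getInt64(1 * 3 + 2), shape, b);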

    // Same as the multi-dimensional constructor above, but only the dimensions
    // of the shape (without layout) are passed. The layout is assumed to be
    // the default (descending minor-to-major) layout.
    Index(absl::Span<llvm::Value* const> multidim,
          absl::Span<int64 const> dimensions, llvm::Type* index_type);

    // Returns an index that adds `addend` to the given `dim` of the object.
    Index AddOffsetToDim(llvm::Value* addend, int64 dim,
                         llvm::IRBuilder<>* b) const {
      Index with_offset = *this;
      with_offset.linear_ = nullptr;
      with_offset.multidim_[dim] =
          b->CreateAdd(with_offset.multidim_[dim], addend);
      return with_offset;
    }

    const std::vector<llvm::Value*>& multidim() const { return multidim_; }
    const std::vector<int64>& dims() const { return dims_; }
    llvm::Value* linear() const { return linear_; }

    size_t size() const { return multidim().size(); }

    llvm::Value* operator[](size_t i) const { return multidim()[i]; }

    using const_iterator = std::vector<llvm::Value*>::const_iterator;

    const_iterator begin() const { return multidim().begin(); }
    const_iterator end() const { return multidim().end(); }

    bool LinearValidOnShape(const Shape& a) const;

    bool ShapeIsCompatible(const Shape& a) const {
      Shape own_shape = ShapeUtil::MakeShape(a.element_type(), dims_);
      *own_shape.mutable_layout() = layout_;
      // The shape 'a' could have dynamic dimensions set. Before we check for
      // equality, we need to copy the information about which dimensions are
      // dynamic.
      for (int64 i = 0; i < a.rank(); ++i) {
        own_shape.set_dynamic_dimension(i, a.is_dynamic_dimension(i));
      }
      return ShapeUtil::Equal(own_shape, a);
    }

    // Given that "this" is the target index of a reshape from `input_shape`
    // to `output_shape`, returns the source index.
    Index SourceIndexOfReshape(const Shape& output_shape,
                               const Shape& input_shape,
                               llvm::IRBuilder<>* builder) const;

    // Returns the index into the source operand from which a slice operation
    // selects a value to be placed into index "this". The slice is described
    // by starting indices `starts` and stride values `strides`.
    //
    // Precondition: "this" is an index into a slice whose operand shape is
    // `operand_shape`.
    Index SourceIndexOfSlice(const Shape& operand_shape,
                             absl::Span<const int64> starts,
                             absl::Span<const int64> strides,
                             llvm::IRBuilder<>* builder) const;

    // Given that "this" is the target index of a transpose from
    // `operand_shape` to `shape` with the given dimension mapping, returns the
    // source index.
    Index SourceIndexOfTranspose(const Shape& shape, const Shape& operand_shape,
                                 absl::Span<const int64> dimension_mapping,
                                 llvm::IRBuilder<>* builder) const;

    // Given that "this" is the target index of a bitcast from `operand_shape`
    // to `shape`, returns the source index.
    Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape,
                               llvm::IRBuilder<>* builder) const;

    // Given that "this" is the target index of a broadcast from
    // `operand_shape` to `shape` with the given dimension mapping, returns the
    // source index.
    Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape,
                                 absl::Span<const int64> dimension_mapping,
                                 llvm::IRBuilder<>* builder) const;
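
    // Example (sketch only): if `output_index` is an Index into the result of
    // a reshape from F32[6] to F32[2,3], the corresponding operand index could
    // be recovered as follows. `b` and `output_index` are assumed to already
    // exist at the call site.
    //
    //   Shape input_shape = ShapeUtil::MakeShape(F32, {6});
    //   Shape output_shape = ShapeUtil::MakeShape(F32, {2, 3});
    //   IrArray::Index source_index =
    //       output_index.SourceIndexOfReshape(output_shape, input_shape, b);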

    // Linearizes the index into the given shape, i.e. reshapes it to rank-1
    // and returns the index into the sole dimension 0 of the new shape.
    llvm::Value* Linearize(absl::Span<const int64> dimensions,
                           llvm::IRBuilder<>* builder) const;

    llvm::Type* GetType() const { return index_type_; }

    llvm::Constant* GetConstantWithIndexType(int64 c) const {
      // The LLVM function makes sure that the value can be represented by the
      // specified type, see ConstantInt::ConstantInt(IntegerType *Ty, const
      // APInt &V).
      return llvm::ConstantInt::get(index_type_, c);
    }

   private:
    // Constructs an index from both a multi-dimensional index and a linear
    // index. 'shape' is the shape on which the index is used. 'index_type' is
    // the type of the index.
    //
    // Precondition: "shape" has a layout.
    Index(absl::Span<llvm::Value* const> multidim, llvm::Value* linear,
          const Shape& shape, llvm::Type* index_type);

    void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
                     const Shape& shape, llvm::IRBuilder<>* b) const;

    std::vector<llvm::Value*> multidim_;

    // These values are purely for efficiency; `multidim_` is enough to find
    // the element at a given `Index`. However, if a loop is emitted with a
    // linear index space, that linear index can be saved in `linear_`, and the
    // layout and dimensions of the shape the loop was emitted for in `layout_`
    // and `dims_`. If the `Index` is then used with another array whose layout
    // and dimensions match, the linear index can be reused, sparing the cost
    // of computing `multidim_`, which LLVM DCE can then delete.
    // Modifying `multidim_` after construction nullifies `linear_`, since the
    // linear index would no longer be valid.
    // If a loop is emitted with a multidimensional index space, `linear_` is
    // null and `layout_` and `dims_` are ignored.
    llvm::Value* linear_ = nullptr;
    Layout layout_;
    std::vector<int64> dims_;

    llvm::Type* index_type_;
  };

  // Default constructor. Constructs an IrArray in a null state.
  IrArray() : base_ptr_(nullptr) {}

  // Constructs an IrArray with the given base pointer and shape. base_ptr is
  // a pointer type pointing to the first element (lowest address) of the
  // array.
  IrArray(llvm::Value* base_ptr, Shape shape);

  // Default implementations of copying and moving.
  IrArray(IrArray&& other) = default;
  IrArray(const IrArray& other) = default;
  IrArray& operator=(IrArray&& other) = default;
  IrArray& operator=(const IrArray& other) = default;

  llvm::Value* GetBasePointer() const { return base_ptr_; }
  llvm::Type* GetElementLlvmType() const { return element_type_; }

  const Shape& GetShape() const { return shape_; }

  // Emit a sequence of instructions to compute the address of the element in
  // the given array at the given index. Returns the address of the element as
  // an LLVM Value.
  //
  // The optional name is useful for debugging when looking at
  // the emitted LLVM IR.
  llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b,
                                       absl::string_view name = "",
                                       bool use_linear_index = true) const;

  // Attach metadata this IrArray instance knows about to "instruction".
  void AnnotateLoadStoreInstructionWithMetadata(
      llvm::Instruction* instruction) const;
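
  // Example (sketch, assuming `buffer` is an llvm::Value* that points at the
  // first element of an F32[16] buffer and `b` is an existing IRBuilder):
  // wrap the buffer in an IrArray and compute the address of one element.
  //
  //   IrArray array(buffer, ShapeUtil::MakeShape(F32, {16}));
  //   IrArray::Index index({b->getInt64(3)}, array.GetShape(),
  //                        b->getInt64Ty());
  //   llvm::Value* addr = array.EmitArrayElementAddress(index, b, "elem3");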

  // Emit IR to read an array element at the given index. Returns the read
  // result (effectively, a Value loaded from memory). This method seamlessly
  // handles scalar shapes by broadcasting their value to all indices (the
  // index is ignored).
  //
  // The optional name is useful for debugging when looking at
  // the emitted LLVM IR.
  // 'use_linear_index' can be used to specify whether the linear index (if
  // available) or the multi-dimensional index should be used.
  llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b,
                                    absl::string_view name = "",
                                    bool use_linear_index = true) const;

  // Emit IR to write the given value to the array element at the given index.
  // 'use_linear_index' can be used to specify whether the linear index (if
  // available) or the multi-dimensional index should be used.
  void EmitWriteArrayElement(const Index& index, llvm::Value* value,
                             llvm::IRBuilder<>* b,
                             bool use_linear_index = true) const;

  // Returns a new IrArray whose shape is "new_shape" and base pointer is a
  // bitcast of the base pointer of "this" IrArray.
  IrArray CastToShape(const Shape& new_shape, llvm::IRBuilder<>* b) const;

  void AddAliasScopeMetadata(llvm::MDNode* alias_scope) {
    CHECK_NE(alias_scope, nullptr);
    AddMetadata(llvm::LLVMContext::MD_alias_scope, alias_scope);
  }

  void AddNoaliasMetadata(llvm::MDNode* noalias) {
    CHECK_NE(noalias, nullptr);
    AddMetadata(llvm::LLVMContext::MD_noalias, noalias);
  }

  // Promises LLVM that the data pointed to by this IrArray never changes after
  // it's first loaded.
  //
  // The temporal scope of this promise is the "whole program" from LLVM's
  // point of view, but how this translates to HLOs differs between backends.
  //
  // In the single-threaded CPU backend, we emit one function that runs all the
  // HLOs in sequence, so the whole program is the whole HLO module.
  //
  // In the GPU backend, we emit one GPU kernel per top-level HLO (i.e. per HLO
  // in the entry computation). From LLVM's perspective, launching a new kernel
  // is like launching a new program, and so the whole program is one top-level
  // HLO. Since the scope of the promise is smaller than in the CPU backend, we
  // can mark more things as invariant in the GPU backend.
  //
  // Marking loads as invariant is particularly helpful on GPUs because
  // invariant loads can be lowered to PTX ld.global.nc (equivalent to CUDA's
  // __ldg intrinsic). These loads use a special cache and can be significantly
  // faster than regular loads.
  void MarkInvariantOverWholeProgram(llvm::LLVMContext* context) {
    if (is_invariant_) {
      return;
    }
    is_invariant_ = true;
    AddMetadata(llvm::LLVMContext::MD_invariant_load,
                llvm::MDNode::get(*context, {}));
  }

  const std::map<int, llvm::MDNode*>& metadata() const { return metadata_; }
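
  // Example (sketch only): copying one element between two assumed IrArrays
  // `src` and `dst` at an existing `index`, with an IRBuilder `b`; loads from
  // `src` are marked invariant so LLVM may use the read-only load path
  // described above.
  //
  //   src.MarkInvariantOverWholeProgram(&b->getContext());
  //   llvm::Value* v = src.EmitReadArrayElement(index, b, "src_elem");
  //   dst.EmitWriteArrayElement(index, v, b);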

 private:
  // Add the specified LLVM IR metadata to loads/stores associated with this
  // IrArray.
  void AddMetadata(int kind, llvm::MDNode* md) {
    InsertOrDie(&metadata_, kind, md);
  }

  // Address of the base of the array as an LLVM Value.
  llvm::Value* base_ptr_;

  // The LLVM type of the elements in the array.
  llvm::Type* element_type_;

  // Shape of the XLA array.
  Shape shape_;

  // The list of key/value pairs used when attaching metadata to emitted
  // loads/stores for this array. The keys are the metadata kinds and the
  // values are the metadata nodes.
  std::map<int, llvm::MDNode*> metadata_;

  bool is_invariant_ = false;
};

}  // namespace llvm_ir
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_