/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_RUNTIME_TYPES_H_
#define XLA_RUNTIME_TYPES_H_

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include "llvm/Support/raw_ostream.h"
#include "tfrt/dtype/dtype.h"  // from @tf_runtime

namespace xla {
namespace runtime {

//===----------------------------------------------------------------------===//
// Canonical XLA runtime types for the executable arguments.
//===----------------------------------------------------------------------===//

// Types supported by the compiled function signature. We do rely on the LLVM
// style RTTI (https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html) to avoid
// dependency on the MLIR types at runtime, because we don't want to depend
// on any of the compiler implementation details at runtime and we want to
// support lightweight loading and execution of AOT compiled programs.
//
// We rely on the RTTI for the open class hierarchies, because we want to allow
// users to define their own types for the arguments.
45 // 46 // If the type can be passed to the compiled function as an argument or returned 47 // as a result, it must define its own ABI. The ABI is defined by the MLIR to 48 // LLVM lowering pipeline and the runtime integration (see `runtime.h`). 49 class Type : public llvm::RTTIExtends<Type, llvm::RTTIRoot> { 50 public: 51 static constexpr char ID = 0; // NOLINT 52 53 // Arguments to compiled functions passed as a set of pointers. For example 54 // memref descriptor passed in as a set of pointers to data, sizes and 55 // strides. See `Argument::Pack` implementation for details (in `argument.h`). 56 struct ArgumentAbi { 57 size_t num_ptrs; 58 }; 59 60 // Compiled function returns results by writing into the pre-allocated storage 61 // of the given size with the requested alignment. Runtime pre-allocates 62 // memory required for all results in the call frame. 63 struct ResultAbi { 64 size_t size; 65 66 // TODO(ezhulenev): Add alignment to the result ABI. Alignment is an 67 // important part of the result ABI that we ignore today. It all doesn't 68 // crash only because all results happen to have a size that is multiple of 69 // 8 bytes, and because of that all of the results are properly aligned. 70 // Results memory layout in the call frame should take in account base 71 // pointer alignment and alignment requirements of all results. 72 }; 73 74 // Returns an Abi if the type can be used as an argument. AsArgument()75 virtual llvm::ErrorOr<ArgumentAbi> AsArgument() const { 76 return llvm::errc::not_supported; 77 } 78 79 // Returns an Abi if the type can be returned as a result. 
AsResult()80 virtual llvm::ErrorOr<ResultAbi> AsResult() const { 81 return llvm::errc::not_supported; 82 } 83 84 virtual llvm::raw_ostream& print(llvm::raw_ostream& os) const = 0; 85 86 protected: 87 Type() = default; 88 }; 89 90 inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, const Type& type) { 91 return type.print(os); 92 } 93 94 //===----------------------------------------------------------------------===// 95 // Async Token type corresponding to the mlir::async::TokenType 96 //===----------------------------------------------------------------------===// 97 98 class AsyncTokenType : public llvm::RTTIExtends<AsyncTokenType, Type> { 99 public: 100 static constexpr char ID = 0; // NOLINT 101 102 llvm::ErrorOr<ResultAbi> AsResult() const final; 103 104 llvm::raw_ostream& print(llvm::raw_ostream& os) const final; 105 }; 106 107 //===----------------------------------------------------------------------===// 108 // Async Value type corresponding to the mlir::async::ValueType. 109 //===----------------------------------------------------------------------===// 110 111 class AsyncValueType : public llvm::RTTIExtends<AsyncValueType, Type> { 112 public: 113 static constexpr char ID = 0; // NOLINT 114 AsyncValueType(std::unique_ptr<Type> value_type)115 explicit AsyncValueType(std::unique_ptr<Type> value_type) 116 : value_type_(std::move(value_type)) {} 117 value_type()118 const Type& value_type() const { return *value_type_; } 119 120 llvm::ErrorOr<ResultAbi> AsResult() const final; 121 122 llvm::raw_ostream& print(llvm::raw_ostream& os) const final; 123 124 private: 125 std::unique_ptr<Type> value_type_; 126 }; 127 128 //===----------------------------------------------------------------------===// 129 // Ranked Tensor type corresponding to the mlir::RankedTensorType. 
130 //===----------------------------------------------------------------------===// 131 132 class RankedTensorType : public llvm::RTTIExtends<RankedTensorType, Type> { 133 public: 134 static constexpr char ID = 0; // NOLINT 135 static constexpr int64_t kDynamicSize = -1; 136 IsDynamic(int64_t dim)137 static constexpr bool IsDynamic(int64_t dim) { return dim == kDynamicSize; } 138 RankedTensorType(llvm::ArrayRef<int64_t> sizes,tfrt::DType element_type)139 RankedTensorType(llvm::ArrayRef<int64_t> sizes, tfrt::DType element_type) 140 : sizes_(sizes.begin(), sizes.end()), element_type_(element_type) {} 141 sizes()142 llvm::ArrayRef<int64_t> sizes() const { return sizes_; } rank()143 unsigned rank() const { return sizes_.size(); } element_type()144 tfrt::DType element_type() const { return element_type_; } 145 146 llvm::raw_ostream& print(llvm::raw_ostream& os) const final; 147 148 private: 149 llvm::SmallVector<int64_t> sizes_; 150 tfrt::DType element_type_; 151 }; 152 153 //===----------------------------------------------------------------------===// 154 // Unranked Tensor type corresponding to the mlir::UnrankedTensorType. 155 //===----------------------------------------------------------------------===// 156 157 class UnrankedTensorType : public llvm::RTTIExtends<UnrankedTensorType, Type> { 158 public: 159 static constexpr char ID = 0; // NOLINT 160 UnrankedTensorType(tfrt::DType element_type)161 explicit UnrankedTensorType(tfrt::DType element_type) 162 : element_type_(element_type) {} 163 element_type()164 tfrt::DType element_type() const { return element_type_; } 165 166 llvm::raw_ostream& print(llvm::raw_ostream& os) const final; 167 168 private: 169 tfrt::DType element_type_; 170 }; 171 172 //===----------------------------------------------------------------------===// 173 // Ranked Memref type corresponding to the mlir::MemrefType. 
174 //===----------------------------------------------------------------------===// 175 176 class MemrefType : public llvm::RTTIExtends<MemrefType, Type> { 177 public: 178 static constexpr char ID = 0; // NOLINT 179 static constexpr int64_t kDynamicSize = -1; 180 IsDynamic(int64_t dim)181 static constexpr bool IsDynamic(int64_t dim) { return dim == kDynamicSize; } 182 MemrefType(llvm::ArrayRef<int64_t> sizes,tfrt::DType element_type)183 MemrefType(llvm::ArrayRef<int64_t> sizes, tfrt::DType element_type) 184 : sizes_(sizes.begin(), sizes.end()), element_type_(element_type) {} 185 sizes()186 llvm::ArrayRef<int64_t> sizes() const { return sizes_; } rank()187 unsigned rank() const { return sizes_.size(); } element_type()188 tfrt::DType element_type() const { return element_type_; } 189 190 llvm::ErrorOr<ArgumentAbi> AsArgument() const final; 191 llvm::ErrorOr<ResultAbi> AsResult() const final; 192 193 llvm::raw_ostream& print(llvm::raw_ostream& os) const final; 194 195 private: 196 llvm::SmallVector<int64_t> sizes_; 197 tfrt::DType element_type_; 198 }; 199 200 //===----------------------------------------------------------------------===// 201 // Unranked Memref type corresponding to the mlir::UnrankedMemrefType. 202 //===----------------------------------------------------------------------===// 203 204 class UnrankedMemrefType : public llvm::RTTIExtends<UnrankedMemrefType, Type> { 205 public: 206 static constexpr char ID = 0; // NOLINT 207 UnrankedMemrefType(tfrt::DType element_type)208 explicit UnrankedMemrefType(tfrt::DType element_type) 209 : element_type_(element_type) {} 210 element_type()211 tfrt::DType element_type() const { return element_type_; } 212 213 llvm::raw_ostream& print(llvm::raw_ostream& os) const final; 214 215 private: 216 tfrt::DType element_type_; 217 }; 218 219 //===----------------------------------------------------------------------===// 220 // Corresponds to the RT dialect's KernelContextType. 
221 //===----------------------------------------------------------------------===// 222 223 class KernelContextOperandType 224 : public llvm::RTTIExtends<KernelContextOperandType, Type> { 225 public: 226 static constexpr char ID = 0; // NOLINT 227 228 llvm::ErrorOr<ArgumentAbi> AsArgument() const final; 229 230 llvm::raw_ostream& print(llvm::raw_ostream& os) const final; 231 }; 232 233 //===----------------------------------------------------------------------===// 234 // Compiled function signature type corresponding to the mlir::FunctionType. 235 //===----------------------------------------------------------------------===// 236 237 class FunctionType { 238 public: operand(unsigned index)239 const Type* operand(unsigned index) const { return operands_[index].get(); } result(unsigned index)240 const Type* result(unsigned index) const { return results_[index].get(); } 241 num_operands()242 unsigned num_operands() const { return operands_.size(); } num_results()243 unsigned num_results() const { return results_.size(); } 244 FunctionType(llvm::SmallVector<std::unique_ptr<Type>> operands,llvm::SmallVector<std::unique_ptr<Type>> results)245 FunctionType(llvm::SmallVector<std::unique_ptr<Type>> operands, 246 llvm::SmallVector<std::unique_ptr<Type>> results) 247 : operands_(std::move(operands)), results_(std::move(results)) {} 248 249 private: 250 llvm::SmallVector<std::unique_ptr<Type>> operands_; 251 llvm::SmallVector<std::unique_ptr<Type>> results_; 252 }; 253 254 } // namespace runtime 255 } // namespace xla 256 257 #endif // XLA_RUNTIME_TYPES_H_ 258