• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_RUNTIME_TYPES_H_
17 #define XLA_RUNTIME_TYPES_H_
18 
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include "tfrt/dtype/dtype.h"  // from @tf_runtime
29 
30 namespace xla {
31 namespace runtime {
32 
33 //===----------------------------------------------------------------------===//
34 // Canonical XLA runtime types for the executable arguments.
35 //===----------------------------------------------------------------------===//
36 
37 // Types supported by the compiled function signature. We do rely on the LLVM
38 // style RTTI (https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html) to avoid
39 // dependency on the MLIR types at runtime, because we don't want to depend
40 // on any of the compiler implementation details at runtime and we want to
41 // support lightweight loading and execution of AOT compiled programs.
42 //
43 // We rely on the RTTI for the open class hierarchies, because we want to allow
44 // users to define their own types for the arguments.
45 //
46 // If the type can be passed to the compiled function as an argument or returned
47 // as a result, it must define its own ABI. The ABI is defined by the MLIR to
48 // LLVM lowering pipeline and the runtime integration (see `runtime.h`).
49 class Type : public llvm::RTTIExtends<Type, llvm::RTTIRoot> {
50  public:
51   static constexpr char ID = 0;  // NOLINT
52 
53   // Arguments to compiled functions passed as a set of pointers. For example
54   // memref descriptor passed in as a set of pointers to data, sizes and
55   // strides. See `Argument::Pack` implementation for details (in `argument.h`).
56   struct ArgumentAbi {
57     size_t num_ptrs;
58   };
59 
60   // Compiled function returns results by writing into the pre-allocated storage
61   // of the given size with the requested alignment. Runtime pre-allocates
62   // memory required for all results in the call frame.
63   struct ResultAbi {
64     size_t size;
65 
66     // TODO(ezhulenev): Add alignment to the result ABI. Alignment is an
67     // important part of the result ABI that we ignore today. It all doesn't
68     // crash only because all results happen to have a size that is multiple of
69     // 8 bytes, and because of that all of the results are properly aligned.
70     // Results memory layout in the call frame should take in account base
71     // pointer alignment and alignment requirements of all results.
72   };
73 
74   // Returns an Abi if the type can be used as an argument.
AsArgument()75   virtual llvm::ErrorOr<ArgumentAbi> AsArgument() const {
76     return llvm::errc::not_supported;
77   }
78 
79   // Returns an Abi if the type can be returned as a result.
AsResult()80   virtual llvm::ErrorOr<ResultAbi> AsResult() const {
81     return llvm::errc::not_supported;
82   }
83 
84   virtual llvm::raw_ostream& print(llvm::raw_ostream& os) const = 0;
85 
86  protected:
87   Type() = default;
88 };
89 
90 inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, const Type& type) {
91   return type.print(os);
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // Async Token type corresponding to the mlir::async::TokenType
96 //===----------------------------------------------------------------------===//
97 
98 class AsyncTokenType : public llvm::RTTIExtends<AsyncTokenType, Type> {
99  public:
100   static constexpr char ID = 0;  // NOLINT
101 
102   llvm::ErrorOr<ResultAbi> AsResult() const final;
103 
104   llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
105 };
106 
107 //===----------------------------------------------------------------------===//
108 // Async Value type corresponding to the mlir::async::ValueType.
109 //===----------------------------------------------------------------------===//
110 
111 class AsyncValueType : public llvm::RTTIExtends<AsyncValueType, Type> {
112  public:
113   static constexpr char ID = 0;  // NOLINT
114 
AsyncValueType(std::unique_ptr<Type> value_type)115   explicit AsyncValueType(std::unique_ptr<Type> value_type)
116       : value_type_(std::move(value_type)) {}
117 
value_type()118   const Type& value_type() const { return *value_type_; }
119 
120   llvm::ErrorOr<ResultAbi> AsResult() const final;
121 
122   llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
123 
124  private:
125   std::unique_ptr<Type> value_type_;
126 };
127 
128 //===----------------------------------------------------------------------===//
129 // Ranked Tensor type corresponding to the mlir::RankedTensorType.
130 //===----------------------------------------------------------------------===//
131 
132 class RankedTensorType : public llvm::RTTIExtends<RankedTensorType, Type> {
133  public:
134   static constexpr char ID = 0;  // NOLINT
135   static constexpr int64_t kDynamicSize = -1;
136 
IsDynamic(int64_t dim)137   static constexpr bool IsDynamic(int64_t dim) { return dim == kDynamicSize; }
138 
RankedTensorType(llvm::ArrayRef<int64_t> sizes,tfrt::DType element_type)139   RankedTensorType(llvm::ArrayRef<int64_t> sizes, tfrt::DType element_type)
140       : sizes_(sizes.begin(), sizes.end()), element_type_(element_type) {}
141 
sizes()142   llvm::ArrayRef<int64_t> sizes() const { return sizes_; }
rank()143   unsigned rank() const { return sizes_.size(); }
element_type()144   tfrt::DType element_type() const { return element_type_; }
145 
146   llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
147 
148  private:
149   llvm::SmallVector<int64_t> sizes_;
150   tfrt::DType element_type_;
151 };
152 
153 //===----------------------------------------------------------------------===//
154 // Unranked Tensor type corresponding to the mlir::UnrankedTensorType.
155 //===----------------------------------------------------------------------===//
156 
157 class UnrankedTensorType : public llvm::RTTIExtends<UnrankedTensorType, Type> {
158  public:
159   static constexpr char ID = 0;  // NOLINT
160 
UnrankedTensorType(tfrt::DType element_type)161   explicit UnrankedTensorType(tfrt::DType element_type)
162       : element_type_(element_type) {}
163 
element_type()164   tfrt::DType element_type() const { return element_type_; }
165 
166   llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
167 
168  private:
169   tfrt::DType element_type_;
170 };
171 
172 //===----------------------------------------------------------------------===//
// Ranked Memref type corresponding to the mlir::MemRefType.
174 //===----------------------------------------------------------------------===//
175 
176 class MemrefType : public llvm::RTTIExtends<MemrefType, Type> {
177  public:
178   static constexpr char ID = 0;  // NOLINT
179   static constexpr int64_t kDynamicSize = -1;
180 
IsDynamic(int64_t dim)181   static constexpr bool IsDynamic(int64_t dim) { return dim == kDynamicSize; }
182 
MemrefType(llvm::ArrayRef<int64_t> sizes,tfrt::DType element_type)183   MemrefType(llvm::ArrayRef<int64_t> sizes, tfrt::DType element_type)
184       : sizes_(sizes.begin(), sizes.end()), element_type_(element_type) {}
185 
sizes()186   llvm::ArrayRef<int64_t> sizes() const { return sizes_; }
rank()187   unsigned rank() const { return sizes_.size(); }
element_type()188   tfrt::DType element_type() const { return element_type_; }
189 
190   llvm::ErrorOr<ArgumentAbi> AsArgument() const final;
191   llvm::ErrorOr<ResultAbi> AsResult() const final;
192 
193   llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
194 
195  private:
196   llvm::SmallVector<int64_t> sizes_;
197   tfrt::DType element_type_;
198 };
199 
200 //===----------------------------------------------------------------------===//
// Unranked Memref type corresponding to the mlir::UnrankedMemRefType.
202 //===----------------------------------------------------------------------===//
203 
204 class UnrankedMemrefType : public llvm::RTTIExtends<UnrankedMemrefType, Type> {
205  public:
206   static constexpr char ID = 0;  // NOLINT
207 
UnrankedMemrefType(tfrt::DType element_type)208   explicit UnrankedMemrefType(tfrt::DType element_type)
209       : element_type_(element_type) {}
210 
element_type()211   tfrt::DType element_type() const { return element_type_; }
212 
213   llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
214 
215  private:
216   tfrt::DType element_type_;
217 };
218 
219 //===----------------------------------------------------------------------===//
220 // Corresponds to the RT dialect's KernelContextType.
221 //===----------------------------------------------------------------------===//
222 
223 class KernelContextOperandType
224     : public llvm::RTTIExtends<KernelContextOperandType, Type> {
225  public:
226   static constexpr char ID = 0;  // NOLINT
227 
228   llvm::ErrorOr<ArgumentAbi> AsArgument() const final;
229 
230   llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
231 };
232 
233 //===----------------------------------------------------------------------===//
234 // Compiled function signature type corresponding to the mlir::FunctionType.
235 //===----------------------------------------------------------------------===//
236 
237 class FunctionType {
238  public:
operand(unsigned index)239   const Type* operand(unsigned index) const { return operands_[index].get(); }
result(unsigned index)240   const Type* result(unsigned index) const { return results_[index].get(); }
241 
num_operands()242   unsigned num_operands() const { return operands_.size(); }
num_results()243   unsigned num_results() const { return results_.size(); }
244 
FunctionType(llvm::SmallVector<std::unique_ptr<Type>> operands,llvm::SmallVector<std::unique_ptr<Type>> results)245   FunctionType(llvm::SmallVector<std::unique_ptr<Type>> operands,
246                llvm::SmallVector<std::unique_ptr<Type>> results)
247       : operands_(std::move(operands)), results_(std::move(results)) {}
248 
249  private:
250   llvm::SmallVector<std::unique_ptr<Type>> operands_;
251   llvm::SmallVector<std::unique_ptr<Type>> results_;
252 };
253 
254 }  // namespace runtime
255 }  // namespace xla
256 
257 #endif  // XLA_RUNTIME_TYPES_H_
258