• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_OPDEFS_TFRT_FALLBACK_COMMON_H_
16 #define TENSORFLOW_CORE_RUNTIME_FALLBACK_OPDEFS_TFRT_FALLBACK_COMMON_H_
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "mlir/IR/Attributes.h"  // from @llvm-project
20 #include "mlir/IR/Builders.h"  // from @llvm-project
21 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
22 #include "tfrt/basic_kernels/opdefs/types.h"  // from @tf_runtime
23 
24 namespace tfrt {
25 namespace fallback_common {
26 
27 template <typename OpTy>
VerifyExecuteOpCommon(OpTy op)28 mlir::LogicalResult VerifyExecuteOpCommon(OpTy op) {
29   auto op_attr_array = op.op_attrs().getValue();
30   for (auto op_attr : op_attr_array) {
31     auto key_value = op_attr.template dyn_cast<mlir::ArrayAttr>();
32     if (!key_value || key_value.getValue().size() != 2 ||
33         !key_value.getValue()[0].template isa<mlir::StringAttr>())
34       return op.emitOpError() << "each op_attr should be a key-value pair, "
35                                  "where the key is a string";
36   }
37   return mlir::success();
38 }
39 
40 template <typename OpTy>
VerifyFallbackExecuteOp(OpTy op)41 mlir::LogicalResult VerifyFallbackExecuteOp(OpTy op) {
42   auto result = VerifyExecuteOpCommon(op);
43   if (failed(result)) return result;
44 
45   // Verify function attributes.
46   auto op_func_attr_array = op.op_func_attrs().getValue();
47   for (auto op_attr : op_func_attr_array) {
48     auto key_value = op_attr.template dyn_cast<mlir::ArrayAttr>();
49     if (!key_value || key_value.getValue().size() != 2 ||
50         !key_value.getValue()[0].template isa<mlir::StringAttr>() ||
51         !key_value.getValue()[1].template isa<mlir::StringAttr>())
52       return op.emitOpError() << "each op_func_attr should be a key-value "
53                                  "pair, where both the key and the value are "
54                                  "strings";
55   }
56   return mlir::success();
57 }
58 
59 template <typename OpTy>
PrintExecuteOpFuncAttribute(mlir::OpAsmPrinter & p,OpTy op)60 void PrintExecuteOpFuncAttribute(mlir::OpAsmPrinter &p, OpTy op) {
61   auto op_func_attrs = op.op_func_attrs();
62   if (!op_func_attrs.empty()) {
63     auto print_key_value = [&](mlir::Attribute attr) {
64       auto key_value = attr.cast<mlir::ArrayAttr>().getValue();
65       auto key = key_value[0];
66       auto value = key_value[1];
67 
68       p << key.cast<mlir::StringAttr>().getValue();
69       p << " = ";
70       p << value;
71     };
72 
73     auto op_func_attr_array = op_func_attrs.getValue();
74     p << " {";
75     llvm::interleaveComma(op_func_attr_array, p, print_key_value);
76     p << '}';
77   }
78 }
79 
80 template <typename OpTy>
PrintExecuteOpCommon(mlir::OpAsmPrinter & p,OpTy op)81 void PrintExecuteOpCommon(mlir::OpAsmPrinter &p, OpTy op) {
82   auto op_attrs = op.op_attrs();
83   if (!op_attrs.empty()) {
84     auto print_key_value = [&](mlir::Attribute attr) {
85       auto key_value = attr.cast<mlir::ArrayAttr>().getValue();
86       auto key = key_value[0];
87       auto value = key_value[1];
88 
89       p << key.cast<mlir::StringAttr>().getValue();
90       p << " = ";
91       p << value;
92     };
93 
94     auto op_attr_array = op_attrs.getValue();
95     p << " {";
96     llvm::interleaveComma(op_attr_array, p, print_key_value);
97     p << '}';
98   }
99 }
100 
101 void GetExecuteOpAttrsCommon(
102     mlir::MLIRContext *context, llvm::ArrayRef<mlir::Attribute> op_attr_array,
103     llvm::SmallVectorImpl<std::pair<llvm::StringRef, mlir::Attribute>>
104         *op_attrs);
105 
// Options passed to ParseExecuteOpCommon.
// NOTE(review): the flag names suggest each one enables parsing of an optional
// component (chain, key, device, function attributes, cost); confirm against
// the parser's definition. All flags default to false.
struct ParseExecuteOpOptions {
  bool has_chain{false};
  bool has_key{false};
  bool has_device{false};
  bool has_func_attr{false};
  bool has_cost{false};
};
113 
114 mlir::ParseResult ParseExecuteOpCommon(mlir::OpAsmParser &parser,
115                                        mlir::Builder &builder,
116                                        mlir::OperationState &result,
117                                        mlir::Type tensor_type,
118                                        const ParseExecuteOpOptions &options);
119 }  // namespace fallback_common
120 }  // namespace tfrt
121 
122 #endif  // TENSORFLOW_CORE_RUNTIME_FALLBACK_OPDEFS_TFRT_FALLBACK_COMMON_H_
123