1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_common.h"
16
17 #include "mlir/IR/Builders.h" // from @llvm-project
18
19 namespace tfrt {
20 namespace fallback_common {
21
GetExecuteOpAttrsCommon(mlir::MLIRContext * context,llvm::ArrayRef<mlir::Attribute> op_attr_array,llvm::SmallVectorImpl<std::pair<llvm::StringRef,mlir::Attribute>> * op_attrs)22 void GetExecuteOpAttrsCommon(
23 mlir::MLIRContext *context, llvm::ArrayRef<mlir::Attribute> op_attr_array,
24 llvm::SmallVectorImpl<std::pair<llvm::StringRef, mlir::Attribute>>
25 *op_attrs) {
26 assert(op_attrs);
27 op_attrs->clear();
28
29 mlir::Builder builder(context);
30 for (auto iter : op_attr_array) {
31 auto key_value = iter.cast<mlir::ArrayAttr>().getValue();
32 llvm::StringRef key = key_value[0].cast<mlir::StringAttr>().getValue();
33 mlir::Attribute value = key_value[1];
34 op_attrs->push_back({key, value});
35 }
36 }
37
ParseExecuteOpCommon(mlir::OpAsmParser & parser,mlir::Builder & builder,mlir::OperationState & result,mlir::Type tensor_type,const ParseExecuteOpOptions & options)38 mlir::ParseResult ParseExecuteOpCommon(mlir::OpAsmParser &parser,
39 mlir::Builder &builder,
40 mlir::OperationState &result,
41 mlir::Type tensor_type,
42 const ParseExecuteOpOptions &options) {
43 auto chain_type = builder.getType<compiler::ChainType>();
44
45 mlir::IntegerAttr op_key;
46 mlir::IntegerAttr cost;
47 mlir::StringAttr device;
48 mlir::StringAttr op_name;
49 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> in_chains;
50 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands;
51 mlir::NamedAttrList op_attrs;
52 mlir::NamedAttrList op_func_attrs;
53 auto loc = parser.getNameLoc();
54
55 if (options.has_chain &&
56 parser.parseOperandList(in_chains,
57 /*requiredOperandCount=*/1,
58 mlir::OpAsmParser::Delimiter::Paren))
59 return mlir::failure();
60
61 if (options.has_key &&
62 (parser.parseKeyword("key") || parser.parseLParen() ||
63 parser.parseAttribute(op_key, "op_key", result.attributes) ||
64 parser.parseRParen()))
65 return mlir::failure();
66
67 if (options.has_cost &&
68 (parser.parseKeyword("cost") || parser.parseLParen() ||
69 parser.parseAttribute(cost, "_tfrt_cost", result.attributes) ||
70 parser.parseRParen()))
71 return mlir::failure();
72
73 if (options.has_device &&
74 (parser.parseKeyword("device") || parser.parseLParen() ||
75 parser.parseAttribute(device, "device", result.attributes) ||
76 parser.parseRParen()))
77 return mlir::failure();
78
79 if (parser.parseAttribute(op_name, "op_name", result.attributes) ||
80 parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) ||
81 parser.parseOptionalAttrDict(op_attrs) ||
82 parser.parseOptionalAttrDict(op_func_attrs))
83 return mlir::failure();
84
85 int64_t num_results = 0;
86 if (succeeded(parser.parseOptionalColon())) {
87 mlir::IntegerAttr attr;
88 mlir::NamedAttrList attrs;
89 if (failed(parser.parseAttribute(attr, "num_results", attrs)))
90 return mlir::failure();
91 num_results = attr.getValue().getSExtValue();
92 }
93
94 llvm::SmallVector<mlir::Type, 4> operand_types;
95 if (options.has_chain) operand_types.push_back(chain_type);
96 if (parser.resolveOperands(in_chains, operand_types, loc, result.operands) ||
97 parser.resolveOperands(operands, tensor_type, result.operands))
98 return mlir::failure();
99
100 if (options.has_chain) result.types.push_back(chain_type);
101 result.types.append(num_results, tensor_type);
102
103 llvm::SmallVector<mlir::Attribute, 4> op_attr_array;
104 for (const auto &key_value : op_attrs) {
105 auto key = builder.getStringAttr(key_value.first.strref());
106 auto value = key_value.second;
107 op_attr_array.push_back(builder.getArrayAttr({key, value}));
108 }
109
110 result.attributes.push_back(
111 builder.getNamedAttr("op_attrs", builder.getArrayAttr(op_attr_array)));
112
113 // TODO(tfrt-devs): support func attributes in tfrt_fallback_sync.
114 if (options.has_func_attr) {
115 llvm::SmallVector<mlir::Attribute, 4> op_func_attr_array;
116 for (const auto &key_value : op_func_attrs) {
117 auto key = builder.getStringAttr(key_value.first.strref());
118 auto value = key_value.second;
119 op_func_attr_array.push_back(builder.getArrayAttr({key, value}));
120 }
121
122 result.attributes.push_back(builder.getNamedAttr(
123 "op_func_attrs", builder.getArrayAttr(op_func_attr_array)));
124 }
125
126 return mlir::success();
127 }
128
129 } // namespace fallback_common
130 } // namespace tfrt
131