• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_common.h"
16 
17 #include "mlir/IR/Builders.h"  // from @llvm-project
18 
19 namespace tfrt {
20 namespace fallback_common {
21 
GetExecuteOpAttrsCommon(mlir::MLIRContext * context,llvm::ArrayRef<mlir::Attribute> op_attr_array,llvm::SmallVectorImpl<std::pair<llvm::StringRef,mlir::Attribute>> * op_attrs)22 void GetExecuteOpAttrsCommon(
23     mlir::MLIRContext *context, llvm::ArrayRef<mlir::Attribute> op_attr_array,
24     llvm::SmallVectorImpl<std::pair<llvm::StringRef, mlir::Attribute>>
25         *op_attrs) {
26   assert(op_attrs);
27   op_attrs->clear();
28 
29   mlir::Builder builder(context);
30   for (auto iter : op_attr_array) {
31     auto key_value = iter.cast<mlir::ArrayAttr>().getValue();
32     llvm::StringRef key = key_value[0].cast<mlir::StringAttr>().getValue();
33     mlir::Attribute value = key_value[1];
34     op_attrs->push_back({key, value});
35   }
36 }
37 
ParseExecuteOpCommon(mlir::OpAsmParser & parser,mlir::Builder & builder,mlir::OperationState & result,mlir::Type tensor_type,const ParseExecuteOpOptions & options)38 mlir::ParseResult ParseExecuteOpCommon(mlir::OpAsmParser &parser,
39                                        mlir::Builder &builder,
40                                        mlir::OperationState &result,
41                                        mlir::Type tensor_type,
42                                        const ParseExecuteOpOptions &options) {
43   auto chain_type = builder.getType<compiler::ChainType>();
44 
45   mlir::IntegerAttr op_key;
46   mlir::IntegerAttr cost;
47   mlir::StringAttr device;
48   mlir::StringAttr op_name;
49   llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> in_chains;
50   llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands;
51   mlir::NamedAttrList op_attrs;
52   mlir::NamedAttrList op_func_attrs;
53   auto loc = parser.getNameLoc();
54 
55   if (options.has_chain &&
56       parser.parseOperandList(in_chains,
57                               /*requiredOperandCount=*/1,
58                               mlir::OpAsmParser::Delimiter::Paren))
59     return mlir::failure();
60 
61   if (options.has_key &&
62       (parser.parseKeyword("key") || parser.parseLParen() ||
63        parser.parseAttribute(op_key, "op_key", result.attributes) ||
64        parser.parseRParen()))
65     return mlir::failure();
66 
67   if (options.has_cost &&
68       (parser.parseKeyword("cost") || parser.parseLParen() ||
69        parser.parseAttribute(cost, "_tfrt_cost", result.attributes) ||
70        parser.parseRParen()))
71     return mlir::failure();
72 
73   if (options.has_device &&
74       (parser.parseKeyword("device") || parser.parseLParen() ||
75        parser.parseAttribute(device, "device", result.attributes) ||
76        parser.parseRParen()))
77     return mlir::failure();
78 
79   if (parser.parseAttribute(op_name, "op_name", result.attributes) ||
80       parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) ||
81       parser.parseOptionalAttrDict(op_attrs) ||
82       parser.parseOptionalAttrDict(op_func_attrs))
83     return mlir::failure();
84 
85   int64_t num_results = 0;
86   if (succeeded(parser.parseOptionalColon())) {
87     mlir::IntegerAttr attr;
88     mlir::NamedAttrList attrs;
89     if (failed(parser.parseAttribute(attr, "num_results", attrs)))
90       return mlir::failure();
91     num_results = attr.getValue().getSExtValue();
92   }
93 
94   llvm::SmallVector<mlir::Type, 4> operand_types;
95   if (options.has_chain) operand_types.push_back(chain_type);
96   if (parser.resolveOperands(in_chains, operand_types, loc, result.operands) ||
97       parser.resolveOperands(operands, tensor_type, result.operands))
98     return mlir::failure();
99 
100   if (options.has_chain) result.types.push_back(chain_type);
101   result.types.append(num_results, tensor_type);
102 
103   llvm::SmallVector<mlir::Attribute, 4> op_attr_array;
104   for (const auto &key_value : op_attrs) {
105     auto key = builder.getStringAttr(key_value.first.strref());
106     auto value = key_value.second;
107     op_attr_array.push_back(builder.getArrayAttr({key, value}));
108   }
109 
110   result.attributes.push_back(
111       builder.getNamedAttr("op_attrs", builder.getArrayAttr(op_attr_array)));
112 
113   // TODO(tfrt-devs): support func attributes in tfrt_fallback_sync.
114   if (options.has_func_attr) {
115     llvm::SmallVector<mlir::Attribute, 4> op_func_attr_array;
116     for (const auto &key_value : op_func_attrs) {
117       auto key = builder.getStringAttr(key_value.first.strref());
118       auto value = key_value.second;
119       op_func_attr_array.push_back(builder.getArrayAttr({key, value}));
120     }
121 
122     result.attributes.push_back(builder.getNamedAttr(
123         "op_func_attrs", builder.getArrayAttr(op_func_attr_array)));
124   }
125 
126   return mlir::success();
127 }
128 
129 }  // namespace fallback_common
130 }  // namespace tfrt
131