1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/ir/importexport/functiondef_export.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
22 #include "mlir/IR/OperationSupport.h" // from @llvm-project
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/op_def.pb.h"
26 #include "tensorflow/core/framework/types.pb.h"
27 #include "tensorflow/core/ir/dialect.h"
28 #include "tensorflow/core/ir/importexport/convert_attributes.h"
29 #include "tensorflow/core/ir/importexport/convert_types.h"
30 #include "tensorflow/core/ir/importexport/graphdef_export.h"
31 #include "tensorflow/core/ir/ops.h"
32 #include "tensorflow/core/platform/errors.h"
33 #include "tensorflow/core/platform/statusor.h"
34
35 using tensorflow::FunctionDef;
36 using tensorflow::OpDef;
37 using tensorflow::OpDef_AttrDef;
38 using tensorflow::Status;
39 using tensorflow::errors::InvalidArgument;
40
41 #define DEBUG_TYPE "mlir-to-graphdef"
42
43 namespace mlir {
44 namespace tfg {
45
46 // Compute the name to use in FunctionDef for a given Value (either the result
47 // of an operation or a block operand if a function argument) and store the
48 // result in the provided name string. The `control_ty` is the instance of the
49 // `ControlType` to compare against and detect a control dependency case.
GetValueName(Value operand,Type control_ty)50 static tensorflow::StatusOr<std::string> GetValueName(Value operand,
51 Type control_ty) {
52 bool is_control = (operand.getType() == control_ty);
53 OpResult op_result = operand.dyn_cast<OpResult>();
54 if (!op_result) {
55 BlockArgument block_operand = operand.dyn_cast<BlockArgument>();
56 int arg_num = block_operand.getArgNumber();
57
58 // Function arguments are coming as pair: the even are the actual tensors
59 // while the odd position are the associated control input.
60 std::string name;
61 if (is_control) name = "^";
62 DictionaryAttr arg_attrs = function_interface_impl::getArgAttrDict(
63 block_operand.getParentBlock()->getParentOp(), arg_num - is_control);
64 if (!arg_attrs)
65 return InvalidArgument("Missing attribute for argument #", arg_num);
66 StringAttr arg_name = arg_attrs.getAs<StringAttr>("tfg.name");
67 if (!arg_name)
68 return InvalidArgument(
69 "Can't export graph with missing op-name for function parameter #",
70 arg_num);
71 absl::StrAppend(&name, arg_name.getValue().str());
72 return name;
73 }
74 GetResultOp get_result = op_result.getDefiningOp<GetResultOp>();
75 Operation *producer;
76 if (is_control) {
77 producer = op_result.getDefiningOp();
78 } else {
79 if (!get_result)
80 return InvalidArgument("Missing get_result operation as input");
81 producer = get_result.value().getDefiningOp();
82 if (!producer)
83 return InvalidArgument("Expect a tfg operation as input to GetResultOp");
84 }
85
86 auto name_attr =
87 producer->getAttrOfType<StringAttr>(TFGraphDialect::getNameAttrKey());
88 if (!name_attr)
89 return InvalidArgument("Can't export graph with missing op-name");
90
91 std::string name;
92 if (is_control) name = "^";
93 absl::StrAppend(&name, name_attr.getValue().str());
94 if (get_result)
95 absl::StrAppend(&name, ":", get_result.name().str(), ":",
96 get_result.number());
97 return name;
98 }
99
100 // Export a function argument or returned value as an ArgDef entry.
101 // If arg_def_attrs is provided, it is populated with the extra attributes
102 // converted from MLIR to AttrValue proto representation. This is useful only
103 // for Function arguments to populate the `arg_attr` field.
104 //
ExportArgDef(OpDef::ArgDef * arg,DictionaryAttr arg_attrs,FunctionDef::ArgAttrs * arg_def_attrs=nullptr)105 static Status ExportArgDef(OpDef::ArgDef *arg, DictionaryAttr arg_attrs,
106 FunctionDef::ArgAttrs *arg_def_attrs = nullptr) {
107 StringAttr arg_name = arg_attrs.getAs<StringAttr>("tfg.name");
108 if (!arg_name) return InvalidArgument("Missing \"tfg.name\" attribute");
109 arg->set_name(arg_name.getValue().str());
110 StringAttr description = arg_attrs.getAs<StringAttr>("tfg.description");
111 if (description) arg->set_description(description.getValue().str());
112 TypeAttr input_type = arg_attrs.getAs<TypeAttr>("tfg.type");
113 if (input_type) {
114 tensorflow::DataType dtype;
115 TF_RETURN_IF_ERROR(ConvertToDataType(input_type.getValue(), &dtype));
116 arg->set_type(dtype);
117 }
118 if (StringAttr type_attr = arg_attrs.getAs<StringAttr>("tfg.type_attr"))
119 arg->set_type_attr(type_attr.getValue().str());
120 if (StringAttr number_attr = arg_attrs.getAs<StringAttr>("tfg.number_attr"))
121 arg->set_number_attr(number_attr.getValue().str());
122 if (StringAttr type_list_attr =
123 arg_attrs.getAs<StringAttr>("tfg.type_list_attr"))
124 arg->set_type_attr(type_list_attr.getValue().str());
125 if (auto full_type = arg_attrs.getAs<tf_type::FullTypeAttr>(
126 "tfg.experimental_full_type")) {
127 TF_ASSIGN_OR_RETURN(*arg->mutable_experimental_full_type(),
128 ConvertAttribute(full_type));
129 }
130 TF_RETURN_IF_ERROR(
131 ConvertHandleData(arg_attrs.getAs<ArrayAttr>("tfg.handle_data"), arg));
132 if (UnitAttr number_attr = arg_attrs.getAs<UnitAttr>("tfg.is_ref"))
133 arg->set_is_ref(true);
134
135 auto sig_arg_attrs = arg_attrs.getAs<DictionaryAttr>("tfg.arg_attrs");
136 if (arg_def_attrs && sig_arg_attrs) {
137 TF_RETURN_IF_ERROR(ConvertAttributes(
138 sig_arg_attrs.getValue(), /*attrs_to_ignore=*/{},
139 /*remove_ref_type=*/false, arg_def_attrs->mutable_attr()));
140 }
141 return ::tensorflow::OkStatus();
142 }
143
ConvertGenericFunctionToFunctionDef(GraphFuncOp func_op)144 tensorflow::StatusOr<FunctionDef> ConvertGenericFunctionToFunctionDef(
145 GraphFuncOp func_op) {
146 if (!func_op.generic())
147 return InvalidArgument(
148 "Expected a generic function in ConvertGenericFunctionToFunctionDef");
149 auto control_ty = tfg::ControlType::get(func_op.getContext());
150 auto *tfg_dialect = cast<TFGraphDialect>(func_op->getDialect());
151
152 FunctionDef fdef;
153 for (Operation &op : func_op.getBody()->without_terminator()) {
154 if (op.getDialect() != tfg_dialect)
155 return InvalidArgument("Non tfg op encountered when exporting function");
156
157 if (isa<GetResultOp>(&op)) continue;
158
159 TF_RETURN_IF_ERROR(ConvertToNodeDef(
160 &op, fdef.add_node_def(), tfg_dialect,
161 [&](Value value) { return GetValueName(value, control_ty); }));
162 }
163
164 const std::string func_name = func_op.getName().str();
165 OpDef *signature = fdef.mutable_signature();
166 signature->set_name(func_name);
167 if (func_op->getAttr("is_stateful")) signature->set_is_stateful(true);
168 if (auto description = func_op->getAttrOfType<StringAttr>("description"))
169 signature->set_description(description.getValue().str());
170
171 if (auto attrs = func_op->getAttrOfType<DictionaryAttr>("tfg.func_attrs")) {
172 for (NamedAttribute attr : attrs) {
173 OpDef_AttrDef *func_attr = signature->add_attr();
174 func_attr->set_name(attr.getName().str());
175 DictionaryAttr dict_attr = attr.getValue().dyn_cast<DictionaryAttr>();
176 if (!dict_attr) return InvalidArgument("Expects dict attribute");
177 if (StringAttr type = dict_attr.getAs<StringAttr>("function_type"))
178 func_attr->set_type(type.getValue().str());
179 if (Attribute default_value = dict_attr.get("default_value")) {
180 TF_ASSIGN_OR_RETURN((*func_attr->mutable_default_value()),
181 ConvertAttribute(default_value));
182 }
183 if (StringAttr description = dict_attr.getAs<StringAttr>("description"))
184 func_attr->set_description(description.getValue().str());
185 if (IntegerAttr minimum = dict_attr.getAs<IntegerAttr>("minimum")) {
186 func_attr->set_minimum(minimum.getInt());
187 func_attr->set_has_minimum(true);
188 }
189 if (Attribute allowed_values = dict_attr.get("allowed_values")) {
190 TF_ASSIGN_OR_RETURN((*func_attr->mutable_allowed_values()),
191 ConvertAttribute(allowed_values));
192 }
193 }
194 }
195
196 if (auto control_outputs =
197 func_op->getAttrOfType<ArrayAttr>("control_output")) {
198 for (Attribute attr : control_outputs) {
199 StringAttr output = attr.dyn_cast<StringAttr>();
200 if (!output)
201 return InvalidArgument(
202 "Can't export function with non-string \"control_output\" "
203 "attribute entry");
204 signature->add_control_output(output.getValue().str());
205 }
206 }
207
208 // Convert the function argument into an OpDef::ArgDef in the signature.
209 ArrayAttr args_attr = func_op.getAllArgAttrs();
210 for (int arg_num : llvm::seq<int>(0, func_op.getNumArguments())) {
211 // Odd position are just for control dependencies.
212 if (arg_num % 2) continue;
213 OpDef::ArgDef *arg = signature->add_input_arg();
214 if (arg_num >= args_attr.size())
215 return InvalidArgument("Can't export function ", func_op.getName().str(),
216 " because missing attributes for arg #", arg_num);
217 DictionaryAttr arg_attrs = args_attr[arg_num].cast<DictionaryAttr>();
218 FunctionDef::ArgAttrs func_def_arg_attrs;
219 TF_RETURN_WITH_CONTEXT_IF_ERROR(
220 ExportArgDef(arg, arg_attrs, &func_def_arg_attrs),
221 " when exporting argument ", arg_num, " for function ",
222 func_op.getName().str());
223
224 // On top of the signature, function arguments can have attribute directul
225 // on the FunctionDef.
226 if (!func_def_arg_attrs.attr().empty())
227 (*fdef.mutable_arg_attr())[arg_num / 2] = std::move(func_def_arg_attrs);
228 }
229
230 // Handle the results now.
231 // An ArgDef entry needs to be constructed for all non-control returned value,
232 // and a mapping from the output name to the signature is also recorded in the
233 // FunctionDef.
234 auto return_op =
235 llvm::cast<tfg::ReturnOp>(func_op.getBody()->getTerminator());
236 ArrayAttr results_attr = func_op.getAllResultAttrs();
237 for (auto &indexed_result : llvm::enumerate(return_op->getOperands())) {
238 int res_num = indexed_result.index();
239 if (res_num >= results_attr.size())
240 return InvalidArgument("Can't export function ", func_op.getName().str(),
241 " because missing attributes for result #",
242 res_num);
243 auto res_attrs = results_attr[res_num].cast<DictionaryAttr>();
244 auto name = res_attrs.getAs<StringAttr>("tfg.name");
245 if (!name)
246 return InvalidArgument(
247 "Can't export function ", func_op.getName().str(),
248 " because missing \"tfg.name\" attribute for result #", res_num);
249
250 Value ret_val = indexed_result.value();
251 if (ret_val.getType() == control_ty) {
252 // When we return a control dependency, it is not really a returned value
253 // but it is added to the `control_ret` field of the FunctionDef.
254 TF_ASSIGN_OR_RETURN(std::string ret_name,
255 GetValueName(ret_val, control_ty));
256 fdef.mutable_control_ret()->insert(
257 {name.getValue().str(), StringRef(ret_name).drop_front().str()});
258 continue;
259 }
260 // Tensor results are turned into an ArgDef in the `output_arg` field.
261 OpDef::ArgDef *output = signature->add_output_arg();
262 TF_RETURN_WITH_CONTEXT_IF_ERROR(ExportArgDef(output, res_attrs),
263 " when exporting result ", res_num,
264 " for function ", func_op.getName().str());
265
266 // The `ret` field of the FunctionDef keeps a mapping of the returned value
267 // name to the entried in the FunctionDef signature.
268 TF_ASSIGN_OR_RETURN(std::string ret_name,
269 GetValueName(ret_val, control_ty));
270 fdef.mutable_ret()->insert({name.getValue().str(), ret_name});
271 }
272
273 // Handled the `resource_arg_unique_id` entries. At the moment it is
274 // represented as two vectors of integers which are expected of the same
275 // length.
276 auto unique_ids_keys = func_op->getAttrOfType<DenseIntElementsAttr>(
277 "resource_arg_unique_ids_keys");
278 if (unique_ids_keys) {
279 auto unique_ids_values = func_op->getAttrOfType<DenseIntElementsAttr>(
280 "resource_arg_unique_ids_values");
281 if (!unique_ids_values)
282 return InvalidArgument(
283 "Can't export function ", func_name,
284 " because \"resource_arg_unique_ids_keys\" attribute is present "
285 "but "
286 "\"resource_arg_unique_ids_values\" is missing");
287 if (unique_ids_keys.size() != unique_ids_values.size())
288 return InvalidArgument(
289 "Can't export function ", func_name,
290 " because \"resource_arg_unique_ids_keys\" array does not have the "
291 "same size as \"resource_arg_unique_ids_values\"");
292
293 auto *unique_ids_map = fdef.mutable_resource_arg_unique_id();
294 for (auto key_value : llvm::zip(unique_ids_keys.getValues<int32_t>(),
295 unique_ids_values.getValues<int32_t>()))
296 (*unique_ids_map)[std::get<0>(key_value)] = std::get<1>(key_value);
297 }
298
299 // Finally the dialect attributes (prefixed by `tf.` in general) are converted
300 // as-is and stored on the `attr` field of the FunctionDef.
301 SmallVector<NamedAttribute, 8> funcAttrs(func_op->getDialectAttrs());
302 TF_RETURN_IF_ERROR(ConvertAttributes(funcAttrs, {"tfg.func_attrs"},
303 /*remove_ref_type=*/false,
304 fdef.mutable_attr()));
305 return fdef;
306 }
307
308 } // namespace tfg
309 } // namespace mlir
310