• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/ir/importexport/functiondef_export.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
22 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/op_def.pb.h"
26 #include "tensorflow/core/framework/types.pb.h"
27 #include "tensorflow/core/ir/dialect.h"
28 #include "tensorflow/core/ir/importexport/convert_attributes.h"
29 #include "tensorflow/core/ir/importexport/convert_types.h"
30 #include "tensorflow/core/ir/importexport/graphdef_export.h"
31 #include "tensorflow/core/ir/ops.h"
32 #include "tensorflow/core/platform/errors.h"
33 #include "tensorflow/core/platform/statusor.h"
34 
35 using tensorflow::FunctionDef;
36 using tensorflow::OpDef;
37 using tensorflow::OpDef_AttrDef;
38 using tensorflow::Status;
39 using tensorflow::errors::InvalidArgument;
40 
41 #define DEBUG_TYPE "mlir-to-graphdef"
42 
43 namespace mlir {
44 namespace tfg {
45 
46 // Compute the name to use in FunctionDef for a given Value (either the result
47 // of an operation or a block operand if a function argument) and store the
48 // result in the provided name string. The `control_ty` is the instance of the
49 // `ControlType` to compare against and detect a control dependency case.
GetValueName(Value operand,Type control_ty)50 static tensorflow::StatusOr<std::string> GetValueName(Value operand,
51                                                       Type control_ty) {
52   bool is_control = (operand.getType() == control_ty);
53   OpResult op_result = operand.dyn_cast<OpResult>();
54   if (!op_result) {
55     BlockArgument block_operand = operand.dyn_cast<BlockArgument>();
56     int arg_num = block_operand.getArgNumber();
57 
58     // Function arguments are coming as pair: the even are the actual tensors
59     // while the odd position are the associated control input.
60     std::string name;
61     if (is_control) name = "^";
62     DictionaryAttr arg_attrs = function_interface_impl::getArgAttrDict(
63         block_operand.getParentBlock()->getParentOp(), arg_num - is_control);
64     if (!arg_attrs)
65       return InvalidArgument("Missing attribute for argument #", arg_num);
66     StringAttr arg_name = arg_attrs.getAs<StringAttr>("tfg.name");
67     if (!arg_name)
68       return InvalidArgument(
69           "Can't export graph with missing op-name for function parameter #",
70           arg_num);
71     absl::StrAppend(&name, arg_name.getValue().str());
72     return name;
73   }
74   GetResultOp get_result = op_result.getDefiningOp<GetResultOp>();
75   Operation *producer;
76   if (is_control) {
77     producer = op_result.getDefiningOp();
78   } else {
79     if (!get_result)
80       return InvalidArgument("Missing get_result operation as input");
81     producer = get_result.value().getDefiningOp();
82     if (!producer)
83       return InvalidArgument("Expect a tfg operation as input to GetResultOp");
84   }
85 
86   auto name_attr =
87       producer->getAttrOfType<StringAttr>(TFGraphDialect::getNameAttrKey());
88   if (!name_attr)
89     return InvalidArgument("Can't export graph with missing op-name");
90 
91   std::string name;
92   if (is_control) name = "^";
93   absl::StrAppend(&name, name_attr.getValue().str());
94   if (get_result)
95     absl::StrAppend(&name, ":", get_result.name().str(), ":",
96                     get_result.number());
97   return name;
98 }
99 
100 // Export a function argument or returned value as an ArgDef entry.
101 // If arg_def_attrs is provided, it is populated with the extra attributes
102 // converted from MLIR to AttrValue proto representation. This is useful only
103 // for Function arguments to populate the `arg_attr` field.
104 //
ExportArgDef(OpDef::ArgDef * arg,DictionaryAttr arg_attrs,FunctionDef::ArgAttrs * arg_def_attrs=nullptr)105 static Status ExportArgDef(OpDef::ArgDef *arg, DictionaryAttr arg_attrs,
106                            FunctionDef::ArgAttrs *arg_def_attrs = nullptr) {
107   StringAttr arg_name = arg_attrs.getAs<StringAttr>("tfg.name");
108   if (!arg_name) return InvalidArgument("Missing \"tfg.name\" attribute");
109   arg->set_name(arg_name.getValue().str());
110   StringAttr description = arg_attrs.getAs<StringAttr>("tfg.description");
111   if (description) arg->set_description(description.getValue().str());
112   TypeAttr input_type = arg_attrs.getAs<TypeAttr>("tfg.type");
113   if (input_type) {
114     tensorflow::DataType dtype;
115     TF_RETURN_IF_ERROR(ConvertToDataType(input_type.getValue(), &dtype));
116     arg->set_type(dtype);
117   }
118   if (StringAttr type_attr = arg_attrs.getAs<StringAttr>("tfg.type_attr"))
119     arg->set_type_attr(type_attr.getValue().str());
120   if (StringAttr number_attr = arg_attrs.getAs<StringAttr>("tfg.number_attr"))
121     arg->set_number_attr(number_attr.getValue().str());
122   if (StringAttr type_list_attr =
123           arg_attrs.getAs<StringAttr>("tfg.type_list_attr"))
124     arg->set_type_attr(type_list_attr.getValue().str());
125   if (auto full_type = arg_attrs.getAs<tf_type::FullTypeAttr>(
126           "tfg.experimental_full_type")) {
127     TF_ASSIGN_OR_RETURN(*arg->mutable_experimental_full_type(),
128                         ConvertAttribute(full_type));
129   }
130   TF_RETURN_IF_ERROR(
131       ConvertHandleData(arg_attrs.getAs<ArrayAttr>("tfg.handle_data"), arg));
132   if (UnitAttr number_attr = arg_attrs.getAs<UnitAttr>("tfg.is_ref"))
133     arg->set_is_ref(true);
134 
135   auto sig_arg_attrs = arg_attrs.getAs<DictionaryAttr>("tfg.arg_attrs");
136   if (arg_def_attrs && sig_arg_attrs) {
137     TF_RETURN_IF_ERROR(ConvertAttributes(
138         sig_arg_attrs.getValue(), /*attrs_to_ignore=*/{},
139         /*remove_ref_type=*/false, arg_def_attrs->mutable_attr()));
140   }
141   return ::tensorflow::OkStatus();
142 }
143 
ConvertGenericFunctionToFunctionDef(GraphFuncOp func_op)144 tensorflow::StatusOr<FunctionDef> ConvertGenericFunctionToFunctionDef(
145     GraphFuncOp func_op) {
146   if (!func_op.generic())
147     return InvalidArgument(
148         "Expected a generic function in ConvertGenericFunctionToFunctionDef");
149   auto control_ty = tfg::ControlType::get(func_op.getContext());
150   auto *tfg_dialect = cast<TFGraphDialect>(func_op->getDialect());
151 
152   FunctionDef fdef;
153   for (Operation &op : func_op.getBody()->without_terminator()) {
154     if (op.getDialect() != tfg_dialect)
155       return InvalidArgument("Non tfg op encountered when exporting function");
156 
157     if (isa<GetResultOp>(&op)) continue;
158 
159     TF_RETURN_IF_ERROR(ConvertToNodeDef(
160         &op, fdef.add_node_def(), tfg_dialect,
161         [&](Value value) { return GetValueName(value, control_ty); }));
162   }
163 
164   const std::string func_name = func_op.getName().str();
165   OpDef *signature = fdef.mutable_signature();
166   signature->set_name(func_name);
167   if (func_op->getAttr("is_stateful")) signature->set_is_stateful(true);
168   if (auto description = func_op->getAttrOfType<StringAttr>("description"))
169     signature->set_description(description.getValue().str());
170 
171   if (auto attrs = func_op->getAttrOfType<DictionaryAttr>("tfg.func_attrs")) {
172     for (NamedAttribute attr : attrs) {
173       OpDef_AttrDef *func_attr = signature->add_attr();
174       func_attr->set_name(attr.getName().str());
175       DictionaryAttr dict_attr = attr.getValue().dyn_cast<DictionaryAttr>();
176       if (!dict_attr) return InvalidArgument("Expects dict attribute");
177       if (StringAttr type = dict_attr.getAs<StringAttr>("function_type"))
178         func_attr->set_type(type.getValue().str());
179       if (Attribute default_value = dict_attr.get("default_value")) {
180         TF_ASSIGN_OR_RETURN((*func_attr->mutable_default_value()),
181                             ConvertAttribute(default_value));
182       }
183       if (StringAttr description = dict_attr.getAs<StringAttr>("description"))
184         func_attr->set_description(description.getValue().str());
185       if (IntegerAttr minimum = dict_attr.getAs<IntegerAttr>("minimum")) {
186         func_attr->set_minimum(minimum.getInt());
187         func_attr->set_has_minimum(true);
188       }
189       if (Attribute allowed_values = dict_attr.get("allowed_values")) {
190         TF_ASSIGN_OR_RETURN((*func_attr->mutable_allowed_values()),
191                             ConvertAttribute(allowed_values));
192       }
193     }
194   }
195 
196   if (auto control_outputs =
197           func_op->getAttrOfType<ArrayAttr>("control_output")) {
198     for (Attribute attr : control_outputs) {
199       StringAttr output = attr.dyn_cast<StringAttr>();
200       if (!output)
201         return InvalidArgument(
202             "Can't export function with non-string \"control_output\" "
203             "attribute entry");
204       signature->add_control_output(output.getValue().str());
205     }
206   }
207 
208   // Convert the function argument into an OpDef::ArgDef in the signature.
209   ArrayAttr args_attr = func_op.getAllArgAttrs();
210   for (int arg_num : llvm::seq<int>(0, func_op.getNumArguments())) {
211     // Odd position are just for control dependencies.
212     if (arg_num % 2) continue;
213     OpDef::ArgDef *arg = signature->add_input_arg();
214     if (arg_num >= args_attr.size())
215       return InvalidArgument("Can't export function ", func_op.getName().str(),
216                              " because missing attributes for arg #", arg_num);
217     DictionaryAttr arg_attrs = args_attr[arg_num].cast<DictionaryAttr>();
218     FunctionDef::ArgAttrs func_def_arg_attrs;
219     TF_RETURN_WITH_CONTEXT_IF_ERROR(
220         ExportArgDef(arg, arg_attrs, &func_def_arg_attrs),
221         " when exporting argument ", arg_num, " for function ",
222         func_op.getName().str());
223 
224     // On top of the signature, function arguments can have attribute directul
225     // on the FunctionDef.
226     if (!func_def_arg_attrs.attr().empty())
227       (*fdef.mutable_arg_attr())[arg_num / 2] = std::move(func_def_arg_attrs);
228   }
229 
230   // Handle the results now.
231   // An ArgDef entry needs to be constructed for all non-control returned value,
232   // and a mapping from the output name to the signature is also recorded in the
233   // FunctionDef.
234   auto return_op =
235       llvm::cast<tfg::ReturnOp>(func_op.getBody()->getTerminator());
236   ArrayAttr results_attr = func_op.getAllResultAttrs();
237   for (auto &indexed_result : llvm::enumerate(return_op->getOperands())) {
238     int res_num = indexed_result.index();
239     if (res_num >= results_attr.size())
240       return InvalidArgument("Can't export function ", func_op.getName().str(),
241                              " because missing attributes for result #",
242                              res_num);
243     auto res_attrs = results_attr[res_num].cast<DictionaryAttr>();
244     auto name = res_attrs.getAs<StringAttr>("tfg.name");
245     if (!name)
246       return InvalidArgument(
247           "Can't export function ", func_op.getName().str(),
248           " because missing \"tfg.name\" attribute for result #", res_num);
249 
250     Value ret_val = indexed_result.value();
251     if (ret_val.getType() == control_ty) {
252       // When we return a control dependency, it is not really a returned value
253       // but it is added to the `control_ret` field of the FunctionDef.
254       TF_ASSIGN_OR_RETURN(std::string ret_name,
255                           GetValueName(ret_val, control_ty));
256       fdef.mutable_control_ret()->insert(
257           {name.getValue().str(), StringRef(ret_name).drop_front().str()});
258       continue;
259     }
260     // Tensor results are turned into an ArgDef in the `output_arg` field.
261     OpDef::ArgDef *output = signature->add_output_arg();
262     TF_RETURN_WITH_CONTEXT_IF_ERROR(ExportArgDef(output, res_attrs),
263                                     " when exporting result ", res_num,
264                                     " for function ", func_op.getName().str());
265 
266     // The `ret` field of the FunctionDef keeps a mapping of the returned value
267     // name to the entried in the FunctionDef signature.
268     TF_ASSIGN_OR_RETURN(std::string ret_name,
269                         GetValueName(ret_val, control_ty));
270     fdef.mutable_ret()->insert({name.getValue().str(), ret_name});
271   }
272 
273   // Handled the `resource_arg_unique_id` entries. At the moment it is
274   // represented as two vectors of integers which are expected of the same
275   // length.
276   auto unique_ids_keys = func_op->getAttrOfType<DenseIntElementsAttr>(
277       "resource_arg_unique_ids_keys");
278   if (unique_ids_keys) {
279     auto unique_ids_values = func_op->getAttrOfType<DenseIntElementsAttr>(
280         "resource_arg_unique_ids_values");
281     if (!unique_ids_values)
282       return InvalidArgument(
283           "Can't export function ", func_name,
284           " because \"resource_arg_unique_ids_keys\" attribute is present "
285           "but "
286           "\"resource_arg_unique_ids_values\" is missing");
287     if (unique_ids_keys.size() != unique_ids_values.size())
288       return InvalidArgument(
289           "Can't export function ", func_name,
290           " because \"resource_arg_unique_ids_keys\" array does not have the "
291           "same size as \"resource_arg_unique_ids_values\"");
292 
293     auto *unique_ids_map = fdef.mutable_resource_arg_unique_id();
294     for (auto key_value : llvm::zip(unique_ids_keys.getValues<int32_t>(),
295                                     unique_ids_values.getValues<int32_t>()))
296       (*unique_ids_map)[std::get<0>(key_value)] = std::get<1>(key_value);
297   }
298 
299   // Finally the dialect attributes (prefixed by `tf.` in general) are converted
300   // as-is and stored on the `attr` field of the FunctionDef.
301   SmallVector<NamedAttribute, 8> funcAttrs(func_op->getDialectAttrs());
302   TF_RETURN_IF_ERROR(ConvertAttributes(funcAttrs, {"tfg.func_attrs"},
303                                        /*remove_ref_type=*/false,
304                                        fdef.mutable_attr()));
305   return fdef;
306 }
307 
308 }  // namespace tfg
309 }  // namespace mlir
310