• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
17 
18 #include <vector>
19 
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_split.h"
23 #include "absl/strings/string_view.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/Casting.h"
26 #include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
27 #include "mlir/IR/Attributes.h"  // TF:llvm-project
28 #include "mlir/IR/Function.h"  // TF:llvm-project
29 #include "mlir/IR/Identifier.h"  // TF:llvm-project
30 #include "mlir/IR/Location.h"  // TF:llvm-project
31 #include "mlir/IR/Module.h"  // TF:llvm-project
32 #include "mlir/IR/Operation.h"  // TF:llvm-project
33 #include "mlir/IR/OperationSupport.h"  // TF:llvm-project
34 #include "mlir/IR/StandardTypes.h"  // TF:llvm-project
35 #include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
36 #include "mlir/Support/DebugStringHelper.h"  // TF:llvm-project
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
40 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
41 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
42 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
43 #include "tensorflow/compiler/xla/status_macros.h"
44 #include "tensorflow/core/framework/attr_value.pb.h"
45 #include "tensorflow/core/framework/graph.pb.h"
46 #include "tensorflow/core/framework/graph_to_functiondef.h"
47 #include "tensorflow/core/framework/node_def.pb.h"
48 #include "tensorflow/core/framework/node_def_util.h"
49 #include "tensorflow/core/framework/op.h"
50 #include "tensorflow/core/framework/tensor.pb.h"
51 #include "tensorflow/core/framework/tensor_shape.pb.h"
52 #include "tensorflow/core/framework/types.pb.h"
53 #include "tensorflow/core/graph/algorithm.h"
54 #include "tensorflow/core/graph/graph.h"
55 #include "tensorflow/core/graph/graph_constructor.h"
56 #include "tensorflow/core/lib/core/errors.h"
57 #include "tensorflow/core/platform/protobuf.h"
58 
59 namespace tensorflow {
60 namespace {
61 // Converts a location to the debug information for the node def.
ConvertLocation(mlir::Location inst_loc,NodeDef::ExperimentalDebugInfo * debug_info)62 Status ConvertLocation(mlir::Location inst_loc,
63                        NodeDef::ExperimentalDebugInfo* debug_info) {
64   if (auto call_site = inst_loc.dyn_cast<mlir::CallSiteLoc>()) {
65     if (auto name_loc = call_site.getCallee().dyn_cast<mlir::NameLoc>()) {
66       debug_info->add_original_node_names(name_loc.getName().c_str());
67     }
68   } else if (auto fused = inst_loc.dyn_cast<mlir::FusedLoc>()) {
69     auto locations = fused.getLocations();
70     if (locations.size() <= 1)
71       return errors::InvalidArgument("expected experimental debuf info.");
72     // skip the first one, which is the name of the node_def.
73     for (int i = 0; i < locations.size() - 1; ++i) {
74       TF_RETURN_IF_ERROR(ConvertLocation(locations[i], debug_info));
75     }
76   }
77   return Status::OK();
78 }
79 
ConvertAttribute(const mlir::BoolAttr & attr,AttrValue * value)80 Status ConvertAttribute(const mlir::BoolAttr& attr, AttrValue* value) {
81   value->set_b(attr.getValue());
82   return Status::OK();
83 }
84 
ConvertAttribute(const mlir::IntegerAttr & attr,AttrValue * value)85 Status ConvertAttribute(const mlir::IntegerAttr& attr, AttrValue* value) {
86   value->set_i(attr.getInt());
87   return Status::OK();
88 }
89 
ConvertAttribute(const mlir::FloatAttr & attr,AttrValue * value)90 Status ConvertAttribute(const mlir::FloatAttr& attr, AttrValue* value) {
91   value->set_f(attr.getValueAsDouble());
92   return Status::OK();
93 }
94 
ConvertAttribute(const mlir::ElementsAttr & attr,AttrValue * value)95 Status ConvertAttribute(const mlir::ElementsAttr& attr, AttrValue* value) {
96   return ConvertToTensorProto(attr, value->mutable_tensor());
97 }
98 
ConvertAttribute(const mlir::StringAttr & attr,AttrValue * value)99 Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) {
100   absl::string_view attr_value(attr.getValue().data(), attr.getValue().size());
101   switch (mangling_util::GetMangledKind(attr_value)) {
102     case mangling_util::MangledKind::kUnknown: {
103       value->set_s(std::string(attr_value));
104       return Status::OK();
105     }
106     case mangling_util::MangledKind::kDataType: {
107       DataType dtype;
108       TF_RETURN_IF_ERROR(mangling_util::DemangleDataType(attr_value, &dtype));
109       value->set_type(dtype);
110       return Status::OK();
111     }
112     case mangling_util::MangledKind::kTensorShape:
113       TF_RETURN_IF_ERROR(
114           mangling_util::DemangleShape(attr_value, value->mutable_shape()));
115       return Status::OK();
116     default:
117       return errors::Unimplemented("Mangled string couldn't be handled!");
118   }
119   return Status::OK();
120 }
121 
ConvertAttribute(mlir::Type type,AttrValue * value)122 Status ConvertAttribute(mlir::Type type, AttrValue* value) {
123   DataType dtype;
124   TF_RETURN_IF_ERROR(ConvertToDataType(type, &dtype));
125   value->set_type(dtype);
126   return Status::OK();
127 }
128 
ConvertAttribute(const mlir::TypeAttr & type,AttrValue * value)129 Status ConvertAttribute(const mlir::TypeAttr& type, AttrValue* value) {
130   return ConvertAttribute(type.getValue(), value);
131 }
132 
ConvertAttribute(const mlir::UnitAttr & attr,AttrValue * value)133 Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) {
134   value->clear_value();
135   return Status::OK();
136 }
137 
ConvertAttribute(const mlir::FlatSymbolRefAttr & attr,AttrValue * value)138 Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) {
139   value->mutable_func()->set_name(std::string(attr.getValue()));
140   return Status::OK();
141 }
142 
// Converts an array attribute into the `list` field of the AttrValue proto.
// Each element is dispatched on its concrete attribute kind and appended to
// the matching repeated field of the list; an element of an unsupported kind
// yields an Unimplemented error. The order of the dyn_cast tests below is
// preserved deliberately — do not reorder.
Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) {
  auto* list = value->mutable_list();
  for (mlir::Attribute a : attr.getValue()) {
    if (auto attr = a.dyn_cast<mlir::BoolAttr>()) {
      list->add_b(attr.getValue());
    } else if (auto attr = a.dyn_cast<mlir::IntegerAttr>()) {
      list->add_i(attr.getInt());
    } else if (auto attr = a.dyn_cast<mlir::FloatAttr>()) {
      list->add_f(attr.getValueAsDouble());
    } else if (auto attr = a.dyn_cast<mlir::StringAttr>()) {
      // A string element may be a mangled encoding of another type; convert
      // it first, then route the result to the matching list field.
      AttrValue nested_value;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &nested_value));
      switch (nested_value.value_case()) {
        case AttrValue::kS:
          list->add_s(nested_value.s());
          break;
        case AttrValue::kType:
          list->add_type(nested_value.type());
          break;
        case AttrValue::kShape:
          *list->add_shape() = nested_value.shape();
          break;
        default:
          return errors::Unimplemented("Unhandled nested attribute!");
      }
    } else if (auto attr = a.dyn_cast<mlir::ElementsAttr>()) {
      TensorProto tensor;
      TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor));
      *list->add_tensor() = tensor;
    } else if (auto attr = a.dyn_cast<mlir::FlatSymbolRefAttr>()) {
      AttrValue attr_val;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
      *list->add_func() = attr_val.func();
    } else if (auto attr = a.dyn_cast<mlir::TypeAttr>()) {
      AttrValue attr_val;
      // For type attributes, we only propagate the element type.
      mlir::Type elt_type = attr.getValue();
      if (auto shaped_type = elt_type.dyn_cast<mlir::ShapedType>()) {
        elt_type = shaped_type.getElementType();
      }
      TF_RETURN_IF_ERROR(ConvertAttribute(elt_type, &attr_val));
      list->add_type(attr_val.type());
    } else {
      return errors::Unimplemented("Unhandled attribute!");
    }
  }
  return Status::OK();
}
191 
192 // Updates NodeDef constructed out of an MLIR If op to map it to either
193 // TensorFlow StatelessIf or If op depending on the additional attribute.
UpdateCompositeIfOp(NodeDef * node_def)194 void UpdateCompositeIfOp(NodeDef* node_def) {
195   auto it = node_def->mutable_attr()->find("is_stateless");
196   if (it != node_def->attr().end()) {
197     if (it->second.b()) {
198       *node_def->mutable_op() = "StatelessIf";
199     }
200     node_def->mutable_attr()->erase(it);
201   }
202 }
203 
204 // Updates NodeDef constructed out of an MLIR While op to map it to either
205 // TensorFlow StatelessWhile or While op depending on the additional attribute.
UpdateCompositeWhileOp(NodeDef * node_def)206 void UpdateCompositeWhileOp(NodeDef* node_def) {
207   auto it = node_def->mutable_attr()->find("is_stateless");
208   if (it != node_def->attr().end()) {
209     if (it->second.b()) {
210       *node_def->mutable_op() = "StatelessWhile";
211     }
212     node_def->mutable_attr()->erase(it);
213   }
214 }
215 
216 // Returns true if the executor/control dialect op should map to Ref node in
217 // TensorFlow Graph. For control dialect NextIteration it uses the 1st operand
218 // type. For executor dialect NextIteration it uses the 2nd operand type. For
219 // all others (Enter/Exit/Merge/Switch), if the output type is ref, they
220 // correspond to the Ref equivalent op in TF Graph.
IsRefTypeControlOp(mlir::Operation * op)221 static bool IsRefTypeControlOp(mlir::Operation* op) {
222   if (auto next_iter_sink =
223           llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(op))
224     return mlir::getElementTypeOrSelf(next_iter_sink.input().getType())
225         .isa<mlir::TF::TensorFlowRefType>();
226 
227   auto op_name_or_status = GetTensorFlowOpName(op->getName().getStringRef());
228   if (!op_name_or_status.ok()) return false;
229 
230   auto op_name = op_name_or_status.ConsumeValueOrDie();
231   if (op_name.equals("NextIteration"))
232     return mlir::getElementTypeOrSelf(op->getOperand(0).getType())
233         .isa<mlir::TF::TensorFlowRefType>();
234 
235   if (op_name.equals("Enter") || op_name.equals("Exit") ||
236       op_name.equals("Switch") || op_name.equals("Merge")) {
237     return getElementTypeOrSelf(op->getResult(0).getType())
238         .isa<mlir::TF::TensorFlowRefType>();
239   }
240   return false;
241 }
242 
243 }  // anonymous namespace
244 
GetTensorFlowOpName(llvm::StringRef op_name)245 StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef op_name) {
246   // When being converted to MLIR, some prefixes and suffixes are added to the
247   // operation types, and we have to remove them when converting the
248   // operations back to a graph:
249   // - "_tf.", "tf." or "tf_executor." : every operation type has this prefix.
250   // - ".sink" or ".Sink": only the NextIteration operation has this suffix. We
251   // don't need to consider ".source"/".Source" because the nodes with this
252   // suffix are skipped by the caller and will not be added to the graph.
253   if (!op_name.consume_front("_tf.") && !op_name.consume_front("tf.") &&
254       !op_name.consume_front("tf_executor.")) {
255     return errors::FailedPrecondition("op node '", op_name.str(),
256                                       "' was not a TF op!");
257   }
258   // Control dialect NextIteration sink ends with ".sink" and Executor dialect
259   // NextIteration sink ends with ".Sink".
260   if (!op_name.consume_back(".sink")) op_name.consume_back(".Sink");
261   return op_name;
262 }
263 
// Builds a NodeDef for the given MLIR operation. `name` becomes the node
// name; attributes listed in `attrs_to_ignore` are skipped during attribute
// conversion. Also fills in inputs (one empty entry per operand), the device,
// debug info derived from the op's location, and maps composite If/While ops
// to their stateless TF variants when flagged.
StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
    const absl::flat_hash_set<absl::string_view>& attrs_to_ignore,
    mlir::Operation* inst, llvm::StringRef name) {
  auto node_def = absl::make_unique<NodeDef>();
  // Note: we do not use NodeBuilder or NodeDefBuilder as that would require
  // mapping back from the inputs to the input arguments.

  llvm::SmallString<64> op_name;
  if (IsLegacyCallInstruction(inst)) {
    // The op_name is the name of the function.
    op_name.append(
        inst->getAttrOfType<mlir::SymbolRefAttr>("f").getLeafReference());
    // Remove the attribute from the instruction as it is already converted to
    // op_name.
    auto attr_id = mlir::Identifier::get("f", inst->getContext());
    inst->removeAttr(attr_id);
  } else {
    // Some control flow ops in TensorFlow Graph have their respective "Ref"
    // ops as well. For example there is Enter and RefEnter op. RefEnter
    // forwards the input ref buffer to output. However both Enter and
    // RefEnter are mapped to tf_executor::EnterOp during import and then to
    // _tf.Enter op in control dialect. Check if it is a Ref op to correctly
    // map to the TensorFlow Graph op.
    if (IsRefTypeControlOp(inst)) op_name = "Ref";
    TF_ASSIGN_OR_RETURN(auto tf_name,
                        GetTensorFlowOpName(inst->getName().getStringRef()));
    op_name.append(tf_name);
  }

  node_def->set_name(name.str());
  node_def->set_op(std::string(op_name.str()));

  // Add inputs to the NodeDef based on the number of operands. This is
  // required as later when edges are added to the Node using Graph::AddEdge
  // the associated NodeDef is not updated.
  for (int i = 0, e = inst->getNumOperands(); i < e; ++i) {
    node_def->add_input();
  }
  if (auto attr = inst->getAttrOfType<mlir::StringAttr>("device")) {
    node_def->set_device(std::string(attr.getValue()));
  }

  // Add the node attributes.
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      ConvertAttributes(inst->getAttrs(), attrs_to_ignore,
                        node_def->mutable_attr()),
      "while converting attributes for node: ", name.str());

  // Add the node debug info.
  TF_RETURN_IF_ERROR(ConvertLocation(
      inst->getLoc(), node_def->mutable_experimental_debug_info()));

  // Map composite ops to their stateless TF variants when applicable.
  if (node_def->op() == "If") UpdateCompositeIfOp(node_def.get());
  if (node_def->op() == "While") UpdateCompositeWhileOp(node_def.get());

  return node_def;
}
321 
// Converts the MLIR attribute list `attrs` into the AttrValueMap `values`,
// skipping "name", "device" and any attribute named in `attrs_to_ignore`.
// Attribute names containing a '.' are treated as attributes of a function
// call: "<func>.<attr>" is stored inside the AttrValue of the "<func>"
// SymbolRef attribute rather than directly in `values`.
Status ConvertAttributes(
    const llvm::ArrayRef<mlir::NamedAttribute> attrs,
    const absl::flat_hash_set<absl::string_view>& attrs_to_ignore,
    AttrValueMap* values) {
  // SymbolRef (function) attributes are collected here first so that later
  // "<func>.<attr>" entries can be attached to them; they are merged into
  // `values` at the end.
  AttrValueMap func_call_attrs;
  for (const mlir::NamedAttribute& named_attr : attrs) {
    auto name_strref = named_attr.first.str();
    auto attr = named_attr.second;
    absl::string_view name(name_strref.data(), name_strref.size());
    if (name == "name" || name == "device" || attrs_to_ignore.contains(name)) {
      // The name, device spec of a TF op or function are not stored as
      // AttrValue inside NodeDef, but we model them using attribute inside
      // MLIR. So we need to ignore them when going back to AttrValue here.
      continue;
    }
    if (mangling_util::IsMangledAttributeName(name)) {
      // In MLIR, attributes for functions requires dialect prefix. We need to
      // remove TF dialect prefix before converting to AttrValue.
      name = mangling_util::DemangleAttributeName(name);
    }
    AttrValue value;
    switch (attr.getKind()) {
      case mlir::StandardAttributes::SymbolRef: {
        auto func_attr = attr.cast<mlir::FlatSymbolRefAttr>();
        value.mutable_func()->set_name(std::string(func_attr.getValue()));
        // Stash the function attribute; it may still receive nested
        // "<func>.<attr>" entries from later iterations.
        func_call_attrs[string(name)] = value;
        continue;
      }
      case mlir::StandardAttributes::Bool:
        TF_RETURN_IF_ERROR(
            ConvertAttribute(attr.cast<mlir::BoolAttr>(), &value));
        break;
      case mlir::StandardAttributes::Integer:
        TF_RETURN_IF_ERROR(
            ConvertAttribute(attr.cast<mlir::IntegerAttr>(), &value));
        break;
      case mlir::StandardAttributes::Float:
        TF_RETURN_IF_ERROR(
            ConvertAttribute(attr.cast<mlir::FloatAttr>(), &value));
        break;
      case mlir::StandardAttributes::String:
        TF_RETURN_IF_ERROR(
            ConvertAttribute(attr.cast<mlir::StringAttr>(), &value));
        break;
      case mlir::StandardAttributes::Array:
        TF_RETURN_IF_ERROR(
            ConvertAttribute(attr.cast<mlir::ArrayAttr>(), &value));
        break;
      case mlir::StandardAttributes::DenseElements:
      case mlir::StandardAttributes::OpaqueElements:
        TF_RETURN_IF_ERROR(
            ConvertAttribute(attr.cast<mlir::ElementsAttr>(), &value));
        break;
      case mlir::StandardAttributes::Type:
        TF_RETURN_IF_ERROR(
            ConvertAttribute(attr.cast<mlir::TypeAttr>(), &value));
        break;
      case mlir::StandardAttributes::Unit:
        TF_RETURN_IF_ERROR(
            ConvertAttribute(attr.cast<mlir::UnitAttr>(), &value));
        break;
      // AffineMap kind is not implemented.
      case mlir::StandardAttributes::AffineMap:
        return errors::Unimplemented("AffineMap attribute (needed for '",
                                     name_strref, "') unimplemented");
      default:
        return errors::Unimplemented("Unhandled attribute kind for attribute '",
                                     name_strref, '\'');
    }
    // According to the NodeDef proto definition, an attribute name from the
    // input TensorFlow GraphDef shouldn't contain '.'. If it does appear in
    // the attribute from MLIR, it is treated as an attribute from function
    // calls.
    std::vector<string> name_tokens =
        absl::StrSplit(name, '.', absl::SkipEmpty());
    TF_RET_CHECK(name_tokens.size() <= 2);
    auto it = func_call_attrs.find(name_tokens[0]);
    if (it == func_call_attrs.end()) {
      // Plain attribute: store directly in the result map.
      (*values)[string(name)] = value;
    } else {
      // "<func>.<attr>": nest under the previously-seen function attribute.
      (*it->second.mutable_func()->mutable_attr())[name_tokens[1]] = value;
    }
  }
  // Merge the collected function attributes into the result map.
  for (const auto& it : func_call_attrs) {
    (*values)[it.first] = it.second;
  }
  return Status::OK();
}
410 
411 // Sets type attribute with the given name. If the attribute already exists with
412 // a different value, returns an error.
SetTypeAttribute(absl::string_view name,mlir::Type type,AttrValueMap * values)413 Status SetTypeAttribute(absl::string_view name, mlir::Type type,
414                         AttrValueMap* values) {
415   DataType dtype;
416   TF_RETURN_IF_ERROR(ConvertScalarTypeToDataType(type, &dtype));
417   if (tensorflow::IsRefType(dtype)) dtype = tensorflow::RemoveRefType(dtype);
418   AttrValue value;
419   value.set_type(dtype);
420 
421   auto result = values->insert({string(name), value});
422   if (!result.second) {
423     DataType actual_dtype = result.first->second.type();
424     if (actual_dtype != dtype) {
425       return errors::InvalidArgument("Expected ", DataType_Name(dtype), " '",
426                                      name, "' attribute but found ",
427                                      DataType_Name(actual_dtype));
428     }
429   }
430   return Status::OK();
431 }
432 
SetShapeAttribute(absl::string_view name,mlir::ShapedType shaped_type,AttrValueMap * values)433 Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shaped_type,
434                          AttrValueMap* values) {
435   tensorflow::TensorShapeProto tshape;
436   AttrValue value;
437   if (shaped_type.hasRank()) {
438     for (auto dim : shaped_type.getShape()) tshape.add_dim()->set_size(dim);
439   } else {
440     tshape.set_unknown_rank(true);
441   }
442   *value.mutable_shape() = tshape;
443 
444   auto result = values->insert({string(name), value});
445   if (!result.second) {
446     // This should be extremely rare as it means we are adding the same
447     // attribute multiple times/have some redundancy in representing this
448     // attribute.
449     TensorShapeProto actual_shape = result.first->second.shape();
450     // Just check via string output as we shouldn't get here and if we do they
451     // should be trivially the same, else fail.
452     if (actual_shape.ShortDebugString() != tshape.ShortDebugString()) {
453       return errors::InvalidArgument("Expected ", tshape.ShortDebugString(),
454                                      " '", name, "' attribute but found ",
455                                      actual_shape.ShortDebugString());
456     }
457   }
458   return Status::OK();
459 }
460 
SetSizeAttribute(absl::string_view name,size_t size,AttrValueMap * values)461 Status SetSizeAttribute(absl::string_view name, size_t size,
462                         AttrValueMap* values) {
463   AttrValue value;
464   value.set_i(size);
465 
466   auto result = values->insert({string(name), value});
467   if (!result.second) {
468     // This should be extremely rare as it means we are adding the same
469     // attribute multiple times/have some redundancy in representing this
470     // attribute.
471     int64 actual_size = result.first->second.i();
472     // Just check via string output as we shouldn't get here and if we do they
473     // should be trivially the same, else fail.
474     if (actual_size != size)
475       return errors::InvalidArgument("Expected '", name, "' attribute to be ",
476                                      size, " but found ", actual_size);
477   }
478   return Status::OK();
479 }
480 
IsLegacyCallInstruction(mlir::Operation * inst)481 bool IsLegacyCallInstruction(mlir::Operation* inst) {
482   return llvm::dyn_cast<mlir::TF::LegacyCallOp>(inst) ||
483          inst->getName().getStringRef().compare("_tf.LegacyCall") == 0;
484 }
485 
486 }  // namespace tensorflow
487