/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"

#include <algorithm>
#include <set>
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {
namespace {
// Returns the static set of TensorFlow op prefixes.
std::set<std::string>* GlobalOpPrefixes() {
  static std::set<std::string>* global_op_prefixes = [] {
    std::set<std::string>* result = new std::set<std::string>;
    result->insert("tf.");
    result->insert("tf_executor.");
    return result;
  }();
  return global_op_prefixes;
}

// Converts a location to the debug information for the node def.
Status ConvertLocation(mlir::Location inst_loc,
                       NodeDef::ExperimentalDebugInfo* debug_info) {
  if (auto call_site = inst_loc.dyn_cast<mlir::CallSiteLoc>()) {
    if (auto name_loc = call_site.getCallee().dyn_cast<mlir::NameLoc>()) {
      debug_info->add_original_node_names(name_loc.getName().c_str());
    }
  } else if (auto fused = inst_loc.dyn_cast<mlir::FusedLoc>()) {
    auto locations = fused.getLocations();
    if (locations.size() <= 1)
      return errors::InvalidArgument("expected experimental debug info.");
    // Skip the last one, which is the name of the node_def.
    for (int i = 0, end = locations.size() - 1; i < end; ++i) {
      TF_RETURN_IF_ERROR(ConvertLocation(locations[i], debug_info));
    }
  }
  return Status::OK();
}

Status ConvertAttribute(const mlir::BoolAttr& attr, AttrValue* value) {
  value->set_b(attr.getValue());
  return Status::OK();
}

Status ConvertAttribute(const mlir::IntegerAttr& attr, AttrValue* value) {
  value->set_i(attr.getInt());
  return Status::OK();
}

Status ConvertAttribute(const mlir::FloatAttr& attr, AttrValue* value) {
  value->set_f(attr.getValueAsDouble());
  return Status::OK();
}

Status ConvertAttribute(const mlir::ElementsAttr& attr, AttrValue* value) {
  return ConvertToTensorProto(attr, value->mutable_tensor());
}

Status ConvertAttribute(const mlir::TF::ShapeAttr& attr, AttrValue* value) {
  auto* shape = value->mutable_shape();
  if (attr.hasRank()) {
    for (auto dim_size : attr.getShape()) {
      auto* dim = shape->add_dim();
      dim->set_size(dim_size);
    }
  } else {
    shape->set_unknown_rank(true);
  }
  return Status::OK();
}

Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) {
  value->mutable_func()->set_name(attr.getValue().str());
  return Status::OK();
}

Status ConvertAttribute(const mlir::TF::FuncAttr& attr, bool remove_ref_type,
                        AttrValue* value) {
  TF_RETURN_IF_ERROR(
      ConvertAttribute(attr.GetName().cast<mlir::FlatSymbolRefAttr>(), value));
  TF_RETURN_IF_ERROR(ConvertAttributes(attr.GetAttrs().getValue(),
                                       /*attrs_to_ignore=*/{}, remove_ref_type,
                                       value->mutable_func()->mutable_attr()));
  return Status::OK();
}

Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) {
  absl::string_view attr_value(attr.getValue().data(), attr.getValue().size());
  switch (mangling_util::GetMangledKind(attr_value)) {
    case mangling_util::MangledKind::kUnknown: {
      value->set_s(std::string(attr_value));
      return Status::OK();
    }
    case mangling_util::MangledKind::kDataType: {
      DataType dtype;
      TF_RETURN_IF_ERROR(mangling_util::DemangleDataType(attr_value, &dtype));
      value->set_type(dtype);
      return Status::OK();
    }
    case mangling_util::MangledKind::kTensorShape:
      TF_RETURN_IF_ERROR(
          mangling_util::DemangleShape(attr_value, value->mutable_shape()));
      return Status::OK();
    default:
      return errors::Unimplemented("Mangled string couldn't be handled!");
  }
  return Status::OK();
}

Status ConvertAttribute(mlir::Type type, bool remove_ref_type,
                        AttrValue* value) {
  DataType dtype;
  TF_RETURN_IF_ERROR(ConvertToDataType(type, &dtype));
  if (tensorflow::IsRefType(dtype)) dtype = tensorflow::RemoveRefType(dtype);
  value->set_type(dtype);
  return Status::OK();
}

Status ConvertAttribute(const mlir::TypeAttr& type, bool remove_ref_type,
                        AttrValue* value) {
  return ConvertAttribute(type.getValue(), remove_ref_type, value);
}

Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) {
  value->clear_value();
  return Status::OK();
}

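// Converts an MLIR ArrayAttr to an AttrValue list by dispatching on the
// element attribute kind. For example (illustrative), the MLIR array
// attribute [true, false] exports as the AttrValue
//   list { b: true b: false }
// and [0 : i64, 1 : i64] exports as
//   list { i: 0 i: 1 }
// in textproto form.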
Status ConvertAttribute(const mlir::ArrayAttr& attr, bool remove_ref_type,
                        AttrValue* value) {
  auto* list = value->mutable_list();
  for (mlir::Attribute a : attr.getValue()) {
    if (auto attr = a.dyn_cast<mlir::BoolAttr>()) {
      list->add_b(attr.getValue());
    } else if (auto attr = a.dyn_cast<mlir::IntegerAttr>()) {
      list->add_i(attr.getInt());
    } else if (auto attr = a.dyn_cast<mlir::FloatAttr>()) {
      list->add_f(attr.getValueAsDouble());
    } else if (auto attr = a.dyn_cast<mlir::StringAttr>()) {
      AttrValue nested_value;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &nested_value));
      switch (nested_value.value_case()) {
        case AttrValue::kS:
          list->add_s(nested_value.s());
          break;
        case AttrValue::kType:
          list->add_type(nested_value.type());
          break;
        case AttrValue::kShape:
          *list->add_shape() = nested_value.shape();
          break;
        default:
          return errors::Unimplemented("Unhandled nested attribute!");
      }
    } else if (auto attr = a.dyn_cast<mlir::ElementsAttr>()) {
      TensorProto tensor;
      TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor));
      *list->add_tensor() = tensor;
    } else if (auto attr = a.dyn_cast<mlir::FlatSymbolRefAttr>()) {
      AttrValue attr_val;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
      *list->add_func() = attr_val.func();
    } else if (auto attr = a.dyn_cast<mlir::TypeAttr>()) {
      AttrValue attr_val;
      // For type attributes, we only propagate the element type.
      mlir::Type elt_type = attr.getValue();
      if (auto shaped_type = elt_type.dyn_cast<mlir::ShapedType>()) {
        elt_type = shaped_type.getElementType();
      }
      TF_RETURN_IF_ERROR(
          ConvertAttribute(elt_type, remove_ref_type, &attr_val));
      list->add_type(attr_val.type());
    } else if (auto attr = a.dyn_cast<mlir::TF::ShapeAttr>()) {
      AttrValue attr_val;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
      *list->add_shape() = attr_val.shape();
    } else {
      return errors::Unimplemented("Unhandled attribute!");
    }
  }
  return Status::OK();
}

// Returns true if the executor/control dialect op should map to a Ref node in
// the TensorFlow Graph. For control dialect NextIteration it uses the 1st
// operand type. For executor dialect NextIteration it uses the 2nd operand
// type. For all others (Enter/Exit/Merge/Switch), if the output type is ref,
// they correspond to the Ref equivalent op in the TF Graph.
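// For example (illustrative): an executor dialect Enter op whose result
// element type is a TensorFlowRefType exports as the TensorFlow "RefEnter"
// node (see GetOperationNodeDef below), while a non-ref Enter exports as
// "Enter".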
static bool IsRefTypeControlOp(mlir::Operation* op) {
  if (auto next_iter_sink =
          llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(op))
    return mlir::getElementTypeOrSelf(next_iter_sink.input().getType())
        .isa<mlir::TF::TensorFlowRefType>();

  auto op_name_or_status = GetTensorFlowOpName(op->getName().getStringRef());
  if (!op_name_or_status.ok()) return false;

  auto op_name = op_name_or_status.ConsumeValueOrDie();
  if (op_name.equals("NextIteration"))
    return mlir::getElementTypeOrSelf(op->getOperand(0).getType())
        .isa<mlir::TF::TensorFlowRefType>();

  if (op_name.equals("Enter") || op_name.equals("Exit") ||
      op_name.equals("Switch") || op_name.equals("Merge")) {
    return mlir::getElementTypeOrSelf(op->getResult(0).getType())
        .isa<mlir::TF::TensorFlowRefType>();
  }
  return false;
}

}  // anonymous namespace

StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef op_name) {
  // When being converted to MLIR, some prefixes and suffixes are added to the
  // operation types, and we have to remove them when converting the
  // operations back to a graph:
  // - "tf." or "tf_executor." : every operation type has this prefix.
  // - ".sink" or ".Sink": only the NextIteration operation has this suffix. We
  // don't need to consider ".source"/".Source" because the nodes with this
  // suffix are skipped by the caller and will not be added to the graph.
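  // For example (illustrative): "tf.Identity" maps to "Identity", and
  // "tf_executor.NextIteration.Sink" maps to "NextIteration".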
  auto prefixes = GlobalOpPrefixes();
  if (std::none_of(prefixes->begin(), prefixes->end(), [&](std::string prefix) {
        return op_name.consume_front(prefix);
      })) {
    return errors::FailedPrecondition("op node '", op_name.str(),
                                      "' was not a TF op!");
  }
  // Control dialect NextIteration sink ends with ".sink" and Executor dialect
  // NextIteration sink ends with ".Sink".
  if (!op_name.consume_back(".sink")) op_name.consume_back(".Sink");
  return op_name;
}

StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
    mlir::Operation* inst, llvm::StringRef name) {
  auto node_def = absl::make_unique<NodeDef>();
  // Note: we do not use NodeBuilder or NodeDefBuilder as that would require
  // mapping back from the inputs to the input arguments.

  llvm::SmallString<64> op_name;
  if (IsLegacyCallInstruction(inst)) {
    // The op_name is the name of the function.
    op_name.append(
        inst->getAttrOfType<mlir::SymbolRefAttr>("f").getLeafReference());
    // Remove the attribute from the instruction as it is already converted to
    // op_name.
    auto attr_id = mlir::Identifier::get("f", inst->getContext());
    inst->removeAttr(attr_id);
  } else {
    // Some control flow ops in the TensorFlow Graph have their respective
    // "Ref" ops as well. For example, there are both Enter and RefEnter ops;
    // RefEnter forwards the input ref buffer to its output. However, both
    // Enter and RefEnter are mapped to tf_executor::EnterOp during import.
    // Check whether this is a Ref op so we can map to the correct TensorFlow
    // Graph op.
    if (IsRefTypeControlOp(inst)) op_name = "Ref";
    TF_ASSIGN_OR_RETURN(auto tf_name,
                        GetTensorFlowOpName(inst->getName().getStringRef()));
    op_name.append(tf_name);
  }

  node_def->set_name(name.str());
  node_def->set_op(std::string(op_name.str()));

  // Update a NodeDef constructed out of an MLIR Case/If/While op to map it to
  // either the TensorFlow StatelessX or X op, depending on the `is_stateless`
  // attribute.
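  // For example (illustrative): a tf.If op with is_stateless = true exports
  // as a "StatelessIf" node, while one with is_stateless = false exports as
  // "If".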
  if (llvm::isa<mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp>(inst)) {
    auto stateless = inst->getAttrOfType<mlir::BoolAttr>("is_stateless");
    if (stateless && stateless.getValue())
      *node_def->mutable_op() = "Stateless" + node_def->op();
  }

  // Add inputs to the NodeDef based on the number of operands. This is
  // required because edges added later to the Node via Graph::AddEdge do not
  // update the associated NodeDef.
  for (int i = 0, e = inst->getNumOperands(); i < e; ++i) {
    node_def->add_input();
  }
  if (auto attr = inst->getAttrOfType<mlir::StringAttr>("device")) {
    node_def->set_device(std::string(attr.getValue()));
  }

  // Add the node debug info.
  TF_RETURN_IF_ERROR(ConvertLocation(
      inst->getLoc(), node_def->mutable_experimental_debug_info()));

  return node_def;
}

Status ConvertAttributes(
    const llvm::ArrayRef<mlir::NamedAttribute> attrs,
    const absl::flat_hash_set<absl::string_view>& attrs_to_ignore,
    bool remove_ref_type, AttrValueMap* values) {
  AttrValueMap func_call_attrs;
  for (const mlir::NamedAttribute& named_attr : attrs) {
    auto name_strref = named_attr.first.str();
    auto attr = named_attr.second;
    absl::string_view name(name_strref.data(), name_strref.size());
    if (name == "name" || name == "device" || attrs_to_ignore.contains(name)) {
      // The name and device spec of a TF op or function are not stored as an
      // AttrValue inside NodeDef, but we model them as attributes in MLIR.
      // So we need to ignore them when converting back to AttrValue here.
      continue;
    }
    if (mangling_util::IsMangledAttributeName(name)) {
      // In MLIR, attributes for functions require a dialect prefix. We need to
      // remove the TF dialect prefix before converting to AttrValue.
      name = mangling_util::DemangleAttributeName(name);
    }
    AttrValue value;
    if (auto symbol_ref = attr.dyn_cast<mlir::SymbolRefAttr>()) {
      TF_RETURN_IF_ERROR(
          ConvertAttribute(symbol_ref.cast<mlir::FlatSymbolRefAttr>(), &value));
      func_call_attrs[string(name)] = value;
      continue;
    }
    if (auto func_attr = attr.dyn_cast<mlir::TF::FuncAttr>()) {
      TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, remove_ref_type, &value));
      func_call_attrs[string(name)] = value;
      continue;
    }
    if (attr.isa<mlir::AffineMapAttr>()) {
      // AffineMapAttr is not implemented.
      return errors::Unimplemented("AffineMap attribute (needed for '",
                                   name_strref, "') unimplemented");
    }
    TF_RETURN_IF_ERROR(
        llvm::TypeSwitch<mlir::Attribute, Status>(attr)
            .Case<mlir::BoolAttr, mlir::IntegerAttr, mlir::FloatAttr,
                  mlir::StringAttr, mlir::ElementsAttr, mlir::UnitAttr,
                  mlir::TF::ShapeAttr>([&](auto derived_attr) {
              return ConvertAttribute(derived_attr, &value);
            })
            .Case<mlir::ArrayAttr, mlir::TypeAttr>([&](auto derived_attr) {
              return ConvertAttribute(derived_attr, remove_ref_type, &value);
            })
            .Default([&](mlir::Attribute) {
              return errors::Unimplemented(
                  "Unhandled attribute kind for attribute '", name_strref,
                  '\'');
            }));

    // According to the NodeDef proto definition, an attribute name from the
    // input TensorFlow GraphDef shouldn't contain '.'. If a '.' does appear
    // in an attribute name coming from MLIR, the attribute is treated as
    // belonging to a function call.
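    // For example (illustrative): a hypothetical attribute named
    // "body.some_attr", where "body" was collected into func_call_attrs
    // above, is attached as attribute "some_attr" of the "body" function
    // attribute.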
    std::vector<string> name_tokens =
        absl::StrSplit(name, '.', absl::SkipEmpty());
    TF_RET_CHECK(name_tokens.size() <= 2);
    auto it = func_call_attrs.find(name_tokens[0]);
    if (it == func_call_attrs.end()) {
      (*values)[string(name)] = value;
    } else {
      (*it->second.mutable_func()->mutable_attr())[name_tokens[1]] = value;
    }
  }
  for (const auto& it : func_call_attrs) {
    (*values)[it.first] = it.second;
  }
  return Status::OK();
}

Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shaped_type,
                         AttrValueMap* values) {
  tensorflow::TensorShapeProto tshape;
  AttrValue value;
  if (shaped_type.hasRank()) {
    for (auto dim : shaped_type.getShape()) tshape.add_dim()->set_size(dim);
  } else {
    tshape.set_unknown_rank(true);
  }
  *value.mutable_shape() = tshape;

  auto result = values->insert({string(name), value});
  if (!result.second) {
    // This should be extremely rare as it means the same attribute is being
    // added multiple times/there is some redundancy in representing this
    // attribute.
    TensorShapeProto actual_shape = result.first->second.shape();
    // Just check via the string output: we shouldn't get here, and if we do
    // the shapes should be trivially the same; otherwise fail.
    if (actual_shape.ShortDebugString() != tshape.ShortDebugString()) {
      return errors::InvalidArgument("Expected ", tshape.ShortDebugString(),
                                     " '", name, "' attribute but found ",
                                     actual_shape.ShortDebugString());
    }
  }
  return Status::OK();
}

bool IsLegacyCallInstruction(mlir::Operation* inst) {
  return llvm::isa<mlir::TF::LegacyCallOp>(inst);
}

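// Illustrative usage sketch (the "my_dialect." prefix is hypothetical): a
// pipeline exporting a custom dialect to GraphDef could register its prefix
// once before export so GetTensorFlowOpName() accepts its ops:
//   TF_RETURN_IF_ERROR(tensorflow::AddTensorFlowOpPrefix("my_dialect."));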
Status AddTensorFlowOpPrefix(std::string prefix) {
  GlobalOpPrefixes()->insert(prefix);
  return Status::OK();
}

}  // namespace tensorflow