• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"

#include <algorithm>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/protobuf.h"
58 
59 namespace tensorflow {
60 namespace {
// Returns the process-wide set of op-name prefixes that identify TensorFlow
// ops ("tf." and "tf_executor." initially; more can be registered through
// AddTensorFlowOpPrefix). The set is created once on first use and is
// deliberately never freed, so it stays valid for the life of the process.
std::set<std::string>* GlobalOpPrefixes() {
  static std::set<std::string>* const prefixes =
      new std::set<std::string>{"tf.", "tf_executor."};
  return prefixes;
}
71 
72 // Converts a location to the debug information for the node def.
ConvertLocation(mlir::Location inst_loc,NodeDef::ExperimentalDebugInfo * debug_info)73 Status ConvertLocation(mlir::Location inst_loc,
74                        NodeDef::ExperimentalDebugInfo* debug_info) {
75   if (auto call_site = inst_loc.dyn_cast<mlir::CallSiteLoc>()) {
76     if (auto name_loc = call_site.getCallee().dyn_cast<mlir::NameLoc>()) {
77       debug_info->add_original_node_names(name_loc.getName().c_str());
78     }
79   } else if (auto fused = inst_loc.dyn_cast<mlir::FusedLoc>()) {
80     auto locations = fused.getLocations();
81     if (locations.size() <= 1)
82       return errors::InvalidArgument("expected experimental debuf info.");
83     // skip the first one, which is the name of the node_def.
84     for (int i = 0, end = locations.size() - 1; i < end; ++i) {
85       TF_RETURN_IF_ERROR(ConvertLocation(locations[i], debug_info));
86     }
87   }
88   return Status::OK();
89 }
90 
ConvertAttribute(const mlir::BoolAttr & attr,AttrValue * value)91 Status ConvertAttribute(const mlir::BoolAttr& attr, AttrValue* value) {
92   value->set_b(attr.getValue());
93   return Status::OK();
94 }
95 
ConvertAttribute(const mlir::IntegerAttr & attr,AttrValue * value)96 Status ConvertAttribute(const mlir::IntegerAttr& attr, AttrValue* value) {
97   value->set_i(attr.getInt());
98   return Status::OK();
99 }
100 
ConvertAttribute(const mlir::FloatAttr & attr,AttrValue * value)101 Status ConvertAttribute(const mlir::FloatAttr& attr, AttrValue* value) {
102   value->set_f(attr.getValueAsDouble());
103   return Status::OK();
104 }
105 
ConvertAttribute(const mlir::ElementsAttr & attr,AttrValue * value)106 Status ConvertAttribute(const mlir::ElementsAttr& attr, AttrValue* value) {
107   return ConvertToTensorProto(attr, value->mutable_tensor());
108 }
109 
ConvertAttribute(const mlir::TF::PlaceholderAttr & attr,AttrValue * value)110 Status ConvertAttribute(const mlir::TF::PlaceholderAttr& attr,
111                         AttrValue* value) {
112   value->set_placeholder(attr.getValue().str());
113   return Status::OK();
114 }
115 
ConvertAttribute(const mlir::TF::ShapeAttr & attr,AttrValue * value)116 Status ConvertAttribute(const mlir::TF::ShapeAttr& attr, AttrValue* value) {
117   SetTensorShapeProto(attr, value->mutable_shape());
118   return Status::OK();
119 }
120 
ConvertAttribute(const mlir::FlatSymbolRefAttr & attr,AttrValue * value)121 Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) {
122   value->mutable_func()->set_name(attr.getValue().str());
123   return Status::OK();
124 }
125 
ConvertAttribute(const mlir::TF::FuncAttr & attr,bool remove_ref_type,AttrValue * value)126 Status ConvertAttribute(const mlir::TF::FuncAttr& attr, bool remove_ref_type,
127                         AttrValue* value) {
128   TF_RETURN_IF_ERROR(
129       ConvertAttribute(attr.getName().cast<mlir::FlatSymbolRefAttr>(), value));
130   TF_RETURN_IF_ERROR(ConvertAttributes(attr.getAttrs().getValue(),
131                                        /*attrs_to_ignore=*/{}, remove_ref_type,
132                                        value->mutable_func()->mutable_attr()));
133   return Status::OK();
134 }
135 
ConvertAttribute(const mlir::StringAttr & attr,AttrValue * value)136 Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) {
137   absl::string_view attr_value(attr.getValue().data(), attr.getValue().size());
138   switch (mangling_util::GetMangledKind(attr_value)) {
139     case mangling_util::MangledKind::kUnknown: {
140       value->set_s(std::string(attr_value));
141       return Status::OK();
142     }
143     case mangling_util::MangledKind::kDataType: {
144       DataType dtype;
145       TF_RETURN_IF_ERROR(mangling_util::DemangleDataType(attr_value, &dtype));
146       value->set_type(dtype);
147       return Status::OK();
148     }
149     case mangling_util::MangledKind::kTensorShape:
150       TF_RETURN_IF_ERROR(
151           mangling_util::DemangleShape(attr_value, value->mutable_shape()));
152       return Status::OK();
153     default:
154       return errors::Unimplemented("Mangled string couldn't be handled!");
155   }
156   return Status::OK();
157 }
158 
ConvertAttribute(mlir::Type type,bool remove_ref_type,AttrValue * value)159 Status ConvertAttribute(mlir::Type type, bool remove_ref_type,
160                         AttrValue* value) {
161   DataType dtype;
162   TF_RETURN_IF_ERROR(ConvertToDataType(type, &dtype));
163   if (tensorflow::IsRefType(dtype)) dtype = tensorflow::RemoveRefType(dtype);
164   value->set_type(dtype);
165   return Status::OK();
166 }
167 
ConvertAttribute(const mlir::TypeAttr & type,bool remove_ref_type,AttrValue * value)168 Status ConvertAttribute(const mlir::TypeAttr& type, bool remove_ref_type,
169                         AttrValue* value) {
170   return ConvertAttribute(type.getValue(), remove_ref_type, value);
171 }
172 
ConvertAttribute(const mlir::UnitAttr & attr,AttrValue * value)173 Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) {
174   value->clear_value();
175   return Status::OK();
176 }
177 
ConvertAttribute(const mlir::ArrayAttr & attr,bool remove_ref_type,AttrValue * value)178 Status ConvertAttribute(const mlir::ArrayAttr& attr, bool remove_ref_type,
179                         AttrValue* value) {
180   auto* list = value->mutable_list();
181   for (mlir::Attribute a : attr.getValue()) {
182     if (auto attr = a.dyn_cast<mlir::BoolAttr>()) {
183       list->add_b(attr.getValue());
184     } else if (auto attr = a.dyn_cast<mlir::IntegerAttr>()) {
185       list->add_i(attr.getInt());
186     } else if (auto attr = a.dyn_cast<mlir::FloatAttr>()) {
187       list->add_f(attr.getValueAsDouble());
188     } else if (auto attr = a.dyn_cast<mlir::StringAttr>()) {
189       AttrValue nested_value;
190       TF_RETURN_IF_ERROR(ConvertAttribute(attr, &nested_value));
191       switch (nested_value.value_case()) {
192         case AttrValue::kS:
193           list->add_s(nested_value.s());
194           break;
195         case AttrValue::kType:
196           list->add_type(nested_value.type());
197           break;
198         case AttrValue::kShape:
199           *list->add_shape() = nested_value.shape();
200           break;
201         default:
202           return errors::Unimplemented("Unhandled nested attribute!");
203       }
204     } else if (auto attr = a.dyn_cast<mlir::ElementsAttr>()) {
205       TensorProto tensor;
206       TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor));
207       *list->add_tensor() = tensor;
208     } else if (auto attr = a.dyn_cast<mlir::FlatSymbolRefAttr>()) {
209       AttrValue attr_val;
210       TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
211       *list->add_func() = attr_val.func();
212     } else if (auto attr = a.dyn_cast<mlir::TypeAttr>()) {
213       AttrValue attr_val;
214       // For type attributes, we only propagate the element type.
215       mlir::Type elt_type = attr.getValue();
216       if (auto shaped_type = elt_type.dyn_cast<mlir::ShapedType>()) {
217         elt_type = shaped_type.getElementType();
218       }
219       TF_RETURN_IF_ERROR(
220           ConvertAttribute(elt_type, remove_ref_type, &attr_val));
221       list->add_type(attr_val.type());
222     } else if (auto attr = a.dyn_cast<mlir::TF::ShapeAttr>()) {
223       AttrValue attr_val;
224       TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
225       *list->add_shape() = attr_val.shape();
226     } else {
227       return errors::Unimplemented("Unhandled attribute!");
228     }
229   }
230   return Status::OK();
231 }
232 
233 // Returns true if the executor/control dialect op should map to Ref node in
234 // TensorFlow Graph. For control dialect NextIteration it uses the 1st operand
235 // type. For executor dialect NextIteration it uses the 2nd operand type. For
236 // all others (Enter/Exit/Merge/Switch), if the output type is ref, they
237 // correspond to the Ref equivalent op in TF Graph.
IsRefTypeControlOp(mlir::Operation * op)238 static bool IsRefTypeControlOp(mlir::Operation* op) {
239   if (auto next_iter_sink =
240           llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(op))
241     return mlir::getElementTypeOrSelf(next_iter_sink.input().getType())
242         .isa<mlir::TF::TensorFlowRefType>();
243 
244   auto op_name_or_status = GetTensorFlowOpName(op->getName().getStringRef());
245   if (!op_name_or_status.ok()) return false;
246 
247   auto op_name = op_name_or_status.ConsumeValueOrDie();
248   if (op_name.equals("NextIteration"))
249     return mlir::getElementTypeOrSelf(op->getOperand(0).getType())
250         .isa<mlir::TF::TensorFlowRefType>();
251 
252   if (op_name.equals("Enter") || op_name.equals("Exit") ||
253       op_name.equals("Switch") || op_name.equals("Merge")) {
254     return getElementTypeOrSelf(op->getResult(0).getType())
255         .isa<mlir::TF::TensorFlowRefType>();
256   }
257   return false;
258 }
259 
260 }  // anonymous namespace
261 
GetTensorFlowOpName(llvm::StringRef op_name)262 StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef op_name) {
263   // When being converted to MLIR, some prefixes and suffixes are added to the
264   // operation types, and we have to remove them when converting the
265   // operations back to a graph:
266   // - "tf." or "tf_executor." : every operation type has this prefix.
267   // - ".sink" or ".Sink": only the NextIteration operation has this suffix. We
268   // don't need to consider ".source"/".Source" because the nodes with this
269   // suffix are skipped by the caller and will not be added to the graph.
270   auto prefixes = GlobalOpPrefixes();
271   if (std::none_of(prefixes->begin(), prefixes->end(), [&](std::string prefix) {
272         return op_name.consume_front(prefix);
273       })) {
274     return errors::FailedPrecondition("op node '", op_name.str(),
275                                       "' was not a TF op!");
276   }
277   // Control dialect NextIteration sink ends with ".sink" and Executor dialect
278   // NextIteration sink ends with ".Sink".
279   if (!op_name.consume_back(".sink")) op_name.consume_back(".Sink");
280   return op_name;
281 }
282 
GetOperationNodeDef(mlir::Operation * inst,llvm::StringRef name)283 StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
284     mlir::Operation* inst, llvm::StringRef name) {
285   auto node_def = absl::make_unique<NodeDef>();
286   // Note: we do not use NodeBuilder or NodeDefBuilder as that would require
287   // mapping back from the inputs to the input arguments.
288 
289   llvm::SmallString<64> op_name;
290   if (IsLegacyCallInstruction(inst)) {
291     // The op_name is the name of the function.
292     op_name.append(
293         inst->getAttrOfType<mlir::SymbolRefAttr>("f").getLeafReference());
294     // Remove the attribute from the instruction as it is already converted to
295     // op_name.
296     auto attr_id = mlir::Identifier::get("f", inst->getContext());
297     inst->removeAttr(attr_id);
298   } else {
299     // Some control flow ops in TensorFlow Graph have their respective "Ref" ops
300     // as well. For example there is Enter and RefEnter op. RefEnter forwards
301     // the input ref buffer to output. However both Enter and RefEnter are
302     // mapped to tf_executor::EnterOp during import. Check if it is a Ref op to
303     // correctly map to the TensorFlow Graph op.
304     if (IsRefTypeControlOp(inst)) op_name = "Ref";
305     TF_ASSIGN_OR_RETURN(auto tf_name,
306                         GetTensorFlowOpName(inst->getName().getStringRef()));
307     op_name.append(tf_name);
308   }
309 
310   node_def->set_name(name.str());
311   node_def->set_op(std::string(op_name.str()));
312 
313   // Update NodeDef constructed out of an MLIR Case/If/While op to map it to
314   // either TensorFlow StatelessX or X op depending on the additional attribute.
315   if (llvm::isa<mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp>(inst)) {
316     auto stateless = inst->getAttrOfType<mlir::BoolAttr>("is_stateless");
317     if (stateless && stateless.getValue())
318       *node_def->mutable_op() = "Stateless" + node_def->op();
319   }
320 
321   // Add inputs to the NodeDef based on the number of operands. This is required
322   // as later when edges are added to the Node using Graph::AddEdge the
323   // associated NodeDef is not updated.
324   for (int i = 0, e = inst->getNumOperands(); i < e; ++i) {
325     node_def->add_input();
326   }
327   if (auto attr = inst->getAttrOfType<mlir::StringAttr>("device")) {
328     node_def->set_device(std::string(attr.getValue()));
329   }
330 
331   // Add the node debug info.
332   TF_RETURN_IF_ERROR(ConvertLocation(
333       inst->getLoc(), node_def->mutable_experimental_debug_info()));
334 
335   return node_def;
336 }
337 
ConvertAttributes(const llvm::ArrayRef<mlir::NamedAttribute> attrs,const absl::flat_hash_set<absl::string_view> & attrs_to_ignore,bool remove_ref_type,AttrValueMap * values)338 Status ConvertAttributes(
339     const llvm::ArrayRef<mlir::NamedAttribute> attrs,
340     const absl::flat_hash_set<absl::string_view>& attrs_to_ignore,
341     bool remove_ref_type, AttrValueMap* values) {
342   AttrValueMap func_call_attrs;
343   for (const mlir::NamedAttribute& named_attr : attrs) {
344     auto name_strref = named_attr.first.str();
345     auto attr = named_attr.second;
346     absl::string_view name(name_strref.data(), name_strref.size());
347     if (name == "name" || name == "device" || attrs_to_ignore.contains(name)) {
348       // The name, device spec of a TF op or function are not stored as
349       // AttrValue inside NodeDef, but we model them using attribute inside
350       // MLIR. So we need to ignore them when going back to AttrValue here.
351       continue;
352     }
353     if (mangling_util::IsMangledAttributeName(name)) {
354       // In MLIR, attributes for functions requires dialect prefix. We need to
355       // remove TF dialect prefix before converting to AttrValue.
356       name = mangling_util::DemangleAttributeName(name);
357     }
358     AttrValue value;
359     if (auto symbol_ref = attr.dyn_cast<mlir::SymbolRefAttr>()) {
360       TF_RETURN_IF_ERROR(
361           ConvertAttribute(symbol_ref.cast<mlir::FlatSymbolRefAttr>(), &value));
362       func_call_attrs[string(name)] = value;
363       continue;
364     }
365     if (auto func_attr = attr.dyn_cast<mlir::TF::FuncAttr>()) {
366       TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, remove_ref_type, &value));
367       func_call_attrs[string(name)] = value;
368       continue;
369     }
370     if (attr.isa<mlir::AffineMapAttr>()) {
371       // AffineMapAttr is not implemented.
372       return errors::Unimplemented("AffineMap attribute (needed for '",
373                                    name_strref, "') unimplemented");
374     }
375     TF_RETURN_IF_ERROR(
376         llvm::TypeSwitch<mlir::Attribute, Status>(attr)
377             .Case<mlir::BoolAttr, mlir::IntegerAttr, mlir::FloatAttr,
378                   mlir::StringAttr, mlir::ElementsAttr, mlir::UnitAttr,
379                   mlir::TF::ShapeAttr, mlir::TF::PlaceholderAttr>(
380                 [&](auto derived_attr) {
381                   return ConvertAttribute(derived_attr, &value);
382                 })
383             .Case<mlir::ArrayAttr, mlir::TypeAttr>([&](auto derived_attr) {
384               return ConvertAttribute(derived_attr, remove_ref_type, &value);
385             })
386             .Default([&](mlir::Attribute) {
387               return errors::Unimplemented(
388                   "Unhandled attribute kind for attribute '", name_strref,
389                   '\'');
390             }));
391 
392     // According to the NodeDef proto definition, an attribute name from the
393     // input TensorFlow GraphDef shouldn't contain '.'. If it does appear in
394     // the attribute from MLIR, it is treated as an attribute from function
395     // calls.
396     std::vector<string> name_tokens =
397         absl::StrSplit(name, '.', absl::SkipEmpty());
398     TF_RET_CHECK(name_tokens.size() <= 2);
399     auto it = func_call_attrs.find(name_tokens[0]);
400     if (it == func_call_attrs.end()) {
401       (*values)[string(name)] = value;
402     } else {
403       (*it->second.mutable_func()->mutable_attr())[name_tokens[1]] = value;
404     }
405   }
406   for (const auto& it : func_call_attrs) {
407     (*values)[it.first] = it.second;
408   }
409   return Status::OK();
410 }
411 
SetShapeAttribute(absl::string_view name,mlir::ShapedType shaped_type,AttrValueMap * values)412 Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shaped_type,
413                          AttrValueMap* values) {
414   AttrValue value;
415   SetTensorShapeProto(shaped_type, value.mutable_shape());
416 
417   auto result = values->insert({string(name), value});
418   if (!result.second) {
419     // This should be extremely rare as it means we are adding the same
420     // attribute multiple times/have some redundancy in representing this
421     // attribute.
422     TensorShapeProto actual_shape = result.first->second.shape();
423     // Just check via string output as we shouldn't get here and if we do they
424     // should be trivially the same, else fail.
425     std::string new_shape_string = value.shape().ShortDebugString();
426     if (actual_shape.ShortDebugString() != new_shape_string) {
427       return errors::InvalidArgument("Expected ", new_shape_string, " '", name,
428                                      "' attribute but found ",
429                                      actual_shape.ShortDebugString());
430     }
431   }
432   return Status::OK();
433 }
434 
IsLegacyCallInstruction(mlir::Operation * inst)435 bool IsLegacyCallInstruction(mlir::Operation* inst) {
436   return llvm::dyn_cast<mlir::TF::LegacyCallOp>(inst);
437 }
438 
AddTensorFlowOpPrefix(std::string prefix)439 Status AddTensorFlowOpPrefix(std::string prefix) {
440   GlobalOpPrefixes()->insert(prefix);
441   return Status::OK();
442 }
443 
444 }  // namespace tensorflow
445