/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"

#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {
namespace {
// Returns the global set of recognized TensorFlow op prefixes, lazily
// initialized with the standard "tf." and "tf_executor." dialect prefixes.
std::set<std::string>* GlobalOpPrefixes() {
  static std::set<std::string>* global_op_prefixes = [] {
    std::set<std::string>* result = new std::set<std::string>;
    result->insert("tf.");
    result->insert("tf_executor.");
    return result;
  }();
  return global_op_prefixes;
}

// Converts a location to the debug information for the node def.
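// Note: only names reachable as the callee of a mlir::CallSiteLoc are
// recorded; other location kinds nested inside a FusedLoc (e.g. bare NameLocs
// or file locations) contribute nothing.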
Status ConvertLocation(mlir::Location inst_loc,
                       NodeDef::ExperimentalDebugInfo* debug_info) {
  if (auto call_site = inst_loc.dyn_cast<mlir::CallSiteLoc>()) {
    if (auto name_loc = call_site.getCallee().dyn_cast<mlir::NameLoc>()) {
      debug_info->add_original_node_names(name_loc.getName().c_str());
    }
  } else if (auto fused = inst_loc.dyn_cast<mlir::FusedLoc>()) {
    auto locations = fused.getLocations();
    if (locations.size() <= 1)
      return errors::InvalidArgument("expected experimental debug info.");
    // Skip the last location, which is the name of the node_def.
    for (int i = 0, end = locations.size() - 1; i < end; ++i) {
      TF_RETURN_IF_ERROR(ConvertLocation(locations[i], debug_info));
    }
  }
  return Status::OK();
}

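// Each of the following ConvertAttribute overloads maps one MLIR attribute
// kind onto the corresponding field of the AttrValue proto.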
Status ConvertAttribute(const mlir::BoolAttr& attr, AttrValue* value) {
  value->set_b(attr.getValue());
  return Status::OK();
}

Status ConvertAttribute(const mlir::IntegerAttr& attr, AttrValue* value) {
  value->set_i(attr.getInt());
  return Status::OK();
}

Status ConvertAttribute(const mlir::FloatAttr& attr, AttrValue* value) {
  value->set_f(attr.getValueAsDouble());
  return Status::OK();
}

Status ConvertAttribute(const mlir::ElementsAttr& attr, AttrValue* value) {
  return ConvertToTensorProto(attr, value->mutable_tensor());
}

Status ConvertAttribute(const mlir::TF::PlaceholderAttr& attr,
                        AttrValue* value) {
  value->set_placeholder(attr.getValue().str());
  return Status::OK();
}

Status ConvertAttribute(const mlir::TF::ShapeAttr& attr, AttrValue* value) {
  SetTensorShapeProto(attr, value->mutable_shape());
  return Status::OK();
}

Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) {
  value->mutable_func()->set_name(attr.getValue().str());
  return Status::OK();
}

Status ConvertAttribute(const mlir::TF::FuncAttr& attr, bool remove_ref_type,
                        AttrValue* value) {
  TF_RETURN_IF_ERROR(
      ConvertAttribute(attr.getName().cast<mlir::FlatSymbolRefAttr>(), value));
  TF_RETURN_IF_ERROR(ConvertAttributes(attr.getAttrs().getValue(),
                                       /*attrs_to_ignore=*/{}, remove_ref_type,
                                       value->mutable_func()->mutable_attr()));
  return Status::OK();
}

Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) {
  absl::string_view attr_value(attr.getValue().data(), attr.getValue().size());
  switch (mangling_util::GetMangledKind(attr_value)) {
    case mangling_util::MangledKind::kUnknown: {
      value->set_s(std::string(attr_value));
      return Status::OK();
    }
    case mangling_util::MangledKind::kDataType: {
      DataType dtype;
      TF_RETURN_IF_ERROR(mangling_util::DemangleDataType(attr_value, &dtype));
      value->set_type(dtype);
      return Status::OK();
    }
    case mangling_util::MangledKind::kTensorShape:
      TF_RETURN_IF_ERROR(
          mangling_util::DemangleShape(attr_value, value->mutable_shape()));
      return Status::OK();
    default:
      return errors::Unimplemented("Mangled string couldn't be handled!");
  }
  return Status::OK();
}

Status ConvertAttribute(mlir::Type type, bool remove_ref_type,
                        AttrValue* value) {
  DataType dtype;
  TF_RETURN_IF_ERROR(ConvertToDataType(type, &dtype));
  if (tensorflow::IsRefType(dtype)) dtype = tensorflow::RemoveRefType(dtype);
  value->set_type(dtype);
  return Status::OK();
}

Status ConvertAttribute(const mlir::TypeAttr& type, bool remove_ref_type,
                        AttrValue* value) {
  return ConvertAttribute(type.getValue(), remove_ref_type, value);
}

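// A UnitAttr carries no payload; it is exported as an empty AttrValue so that
// only the presence of the attribute is recorded.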
Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) {
  value->clear_value();
  return Status::OK();
}

Status ConvertAttribute(const mlir::ArrayAttr& attr, bool remove_ref_type,
                        AttrValue* value) {
  auto* list = value->mutable_list();
  for (mlir::Attribute a : attr.getValue()) {
    if (auto attr = a.dyn_cast<mlir::BoolAttr>()) {
      list->add_b(attr.getValue());
    } else if (auto attr = a.dyn_cast<mlir::IntegerAttr>()) {
      list->add_i(attr.getInt());
    } else if (auto attr = a.dyn_cast<mlir::FloatAttr>()) {
      list->add_f(attr.getValueAsDouble());
    } else if (auto attr = a.dyn_cast<mlir::StringAttr>()) {
      AttrValue nested_value;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &nested_value));
      switch (nested_value.value_case()) {
        case AttrValue::kS:
          list->add_s(nested_value.s());
          break;
        case AttrValue::kType:
          list->add_type(nested_value.type());
          break;
        case AttrValue::kShape:
          *list->add_shape() = nested_value.shape();
          break;
        default:
          return errors::Unimplemented("Unhandled nested attribute!");
      }
    } else if (auto attr = a.dyn_cast<mlir::ElementsAttr>()) {
      TensorProto tensor;
      TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor));
      *list->add_tensor() = tensor;
    } else if (auto attr = a.dyn_cast<mlir::FlatSymbolRefAttr>()) {
      AttrValue attr_val;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
      *list->add_func() = attr_val.func();
    } else if (auto attr = a.dyn_cast<mlir::TypeAttr>()) {
      AttrValue attr_val;
      // For type attributes, we only propagate the element type.
      mlir::Type elt_type = attr.getValue();
      if (auto shaped_type = elt_type.dyn_cast<mlir::ShapedType>()) {
        elt_type = shaped_type.getElementType();
      }
      TF_RETURN_IF_ERROR(
          ConvertAttribute(elt_type, remove_ref_type, &attr_val));
      list->add_type(attr_val.type());
    } else if (auto attr = a.dyn_cast<mlir::TF::ShapeAttr>()) {
      AttrValue attr_val;
      TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val));
      *list->add_shape() = attr_val.shape();
    } else {
      return errors::Unimplemented("Unhandled attribute!");
    }
  }
  return Status::OK();
}

// Returns true if the executor/control dialect op should map to a Ref node in
// the TensorFlow Graph. For control dialect NextIteration it uses the 1st
// operand type. For executor dialect NextIteration it uses the 2nd operand
// type. For all others (Enter/Exit/Merge/Switch), if the output type is a ref
// type, they correspond to the Ref-equivalent op in the TF Graph.
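// For example, an Enter whose result is a ref tensor is exported as RefEnter.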
static bool IsRefTypeControlOp(mlir::Operation* op) {
  if (auto next_iter_sink =
          llvm::dyn_cast<mlir::tf_executor::NextIterationSinkOp>(op))
    return mlir::getElementTypeOrSelf(next_iter_sink.input().getType())
        .isa<mlir::TF::TensorFlowRefType>();

  auto op_name_or_status = GetTensorFlowOpName(op->getName().getStringRef());
  if (!op_name_or_status.ok()) return false;

  auto op_name = op_name_or_status.ConsumeValueOrDie();
  if (op_name.equals("NextIteration"))
    return mlir::getElementTypeOrSelf(op->getOperand(0).getType())
        .isa<mlir::TF::TensorFlowRefType>();

  if (op_name.equals("Enter") || op_name.equals("Exit") ||
      op_name.equals("Switch") || op_name.equals("Merge")) {
    return mlir::getElementTypeOrSelf(op->getResult(0).getType())
        .isa<mlir::TF::TensorFlowRefType>();
  }
  return false;
}

}  // anonymous namespace

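// For example, "tf.MatMul" is returned as "MatMul", and
// "tf_executor.NextIteration.Sink" as "NextIteration".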
StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef op_name) {
  // When being converted to MLIR, some prefixes and suffixes are added to the
  // operation types, and we have to remove them when converting the
  // operations back to a graph:
  // - "tf." or "tf_executor.": every operation type has this prefix.
  // - ".sink" or ".Sink": only the NextIteration operation has this suffix. We
  //   don't need to consider ".source"/".Source" because the nodes with this
  //   suffix are skipped by the caller and will not be added to the graph.
  auto prefixes = GlobalOpPrefixes();
  if (std::none_of(prefixes->begin(), prefixes->end(), [&](std::string prefix) {
        return op_name.consume_front(prefix);
      })) {
    return errors::FailedPrecondition("op node '", op_name.str(),
                                      "' was not a TF op!");
  }
  // Control dialect NextIteration sink ends with ".sink" and Executor dialect
  // NextIteration sink ends with ".Sink".
  if (!op_name.consume_back(".sink")) op_name.consume_back(".Sink");
  return op_name;
}

StatusOr<std::unique_ptr<NodeDef>> GetOperationNodeDef(
    mlir::Operation* inst, llvm::StringRef name) {
  auto node_def = absl::make_unique<NodeDef>();
  // Note: we do not use NodeBuilder or NodeDefBuilder as that would require
  // mapping back from the inputs to the input arguments.

  llvm::SmallString<64> op_name;
  if (IsLegacyCallInstruction(inst)) {
    // The op_name is the name of the function.
    op_name.append(
        inst->getAttrOfType<mlir::SymbolRefAttr>("f").getLeafReference());
    // Remove the attribute from the instruction as it is already converted to
    // op_name.
    auto attr_id = mlir::Identifier::get("f", inst->getContext());
    inst->removeAttr(attr_id);
  } else {
    // Some control flow ops in the TensorFlow Graph have their respective
    // "Ref" ops as well. For example, there are both Enter and RefEnter ops;
    // RefEnter forwards its input ref buffer to its output. However, both
    // Enter and RefEnter are mapped to tf_executor::EnterOp during import.
    // Check whether this is a Ref op to correctly map it back to the
    // TensorFlow Graph op.
    if (IsRefTypeControlOp(inst)) op_name = "Ref";
    TF_ASSIGN_OR_RETURN(auto tf_name,
                        GetTensorFlowOpName(inst->getName().getStringRef()));
    op_name.append(tf_name);
  }

  node_def->set_name(name.str());
  node_def->set_op(std::string(op_name.str()));

  // Update a NodeDef constructed out of an MLIR Case/If/While op to map it to
  // either the TensorFlow StatelessX or X op, depending on the "is_stateless"
  // attribute; e.g. a tf.If with is_stateless = true is exported as
  // StatelessIf.
  if (llvm::isa<mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp>(inst)) {
    auto stateless = inst->getAttrOfType<mlir::BoolAttr>("is_stateless");
    if (stateless && stateless.getValue())
      *node_def->mutable_op() = "Stateless" + node_def->op();
  }

  // Add inputs to the NodeDef based on the number of operands. This is
  // required as later, when edges are added to the Node using Graph::AddEdge,
  // the associated NodeDef is not updated.
  for (int i = 0, e = inst->getNumOperands(); i < e; ++i) {
    node_def->add_input();
  }
  if (auto attr = inst->getAttrOfType<mlir::StringAttr>("device")) {
    node_def->set_device(std::string(attr.getValue()));
  }

  // Add the node debug info.
  TF_RETURN_IF_ERROR(ConvertLocation(
      inst->getLoc(), node_def->mutable_experimental_debug_info()));

  return node_def;
}

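// Converts the MLIR attributes in `attrs` to entries of the AttrValueMap
// `values`, skipping "name", "device", and anything in `attrs_to_ignore`;
// symbol-ref and function attributes are collected separately and merged in
// as function-call attributes at the end.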
Status ConvertAttributes(
    const llvm::ArrayRef<mlir::NamedAttribute> attrs,
    const absl::flat_hash_set<absl::string_view>& attrs_to_ignore,
    bool remove_ref_type, AttrValueMap* values) {
  AttrValueMap func_call_attrs;
  for (const mlir::NamedAttribute& named_attr : attrs) {
    auto name_strref = named_attr.first.str();
    auto attr = named_attr.second;
    absl::string_view name(name_strref.data(), name_strref.size());
    if (name == "name" || name == "device" || attrs_to_ignore.contains(name)) {
      // The name and device spec of a TF op or function are not stored as
      // AttrValue inside a NodeDef, but we model them as attributes in MLIR.
      // So we need to ignore them here when going back to AttrValue.
      continue;
    }
    if (mangling_util::IsMangledAttributeName(name)) {
      // In MLIR, attributes for functions require a dialect prefix. We need
      // to remove the TF dialect prefix before converting to AttrValue.
      name = mangling_util::DemangleAttributeName(name);
    }
    AttrValue value;
    if (auto symbol_ref = attr.dyn_cast<mlir::SymbolRefAttr>()) {
      TF_RETURN_IF_ERROR(
          ConvertAttribute(symbol_ref.cast<mlir::FlatSymbolRefAttr>(), &value));
      func_call_attrs[string(name)] = value;
      continue;
    }
    if (auto func_attr = attr.dyn_cast<mlir::TF::FuncAttr>()) {
      TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, remove_ref_type, &value));
      func_call_attrs[string(name)] = value;
      continue;
    }
    if (attr.isa<mlir::AffineMapAttr>()) {
      // AffineMapAttr is not implemented.
      return errors::Unimplemented("AffineMap attribute (needed for '",
                                   name_strref, "') unimplemented");
    }
    TF_RETURN_IF_ERROR(
        llvm::TypeSwitch<mlir::Attribute, Status>(attr)
            .Case<mlir::BoolAttr, mlir::IntegerAttr, mlir::FloatAttr,
                  mlir::StringAttr, mlir::ElementsAttr, mlir::UnitAttr,
                  mlir::TF::ShapeAttr, mlir::TF::PlaceholderAttr>(
                [&](auto derived_attr) {
                  return ConvertAttribute(derived_attr, &value);
                })
            .Case<mlir::ArrayAttr, mlir::TypeAttr>([&](auto derived_attr) {
              return ConvertAttribute(derived_attr, remove_ref_type, &value);
            })
            .Default([&](mlir::Attribute) {
              return errors::Unimplemented(
                  "Unhandled attribute kind for attribute '", name_strref,
                  '\'');
            }));

    // According to the NodeDef proto definition, an attribute name from the
    // input TensorFlow GraphDef shouldn't contain '.'. If it does appear in
    // the attribute from MLIR, it is treated as an attribute from function
    // calls.
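    // For example (names illustrative only): if the function attribute
    // "then_branch" was collected above, an MLIR attribute named
    // "then_branch.foo" is exported as the attr "foo" of the "then_branch"
    // function call.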
    std::vector<string> name_tokens =
        absl::StrSplit(name, '.', absl::SkipEmpty());
    TF_RET_CHECK(name_tokens.size() <= 2);
    auto it = func_call_attrs.find(name_tokens[0]);
    if (it == func_call_attrs.end()) {
      (*values)[string(name)] = value;
    } else {
      (*it->second.mutable_func()->mutable_attr())[name_tokens[1]] = value;
    }
  }
  for (const auto& it : func_call_attrs) {
    (*values)[it.first] = it.second;
  }
  return Status::OK();
}

Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shaped_type,
                         AttrValueMap* values) {
  AttrValue value;
  SetTensorShapeProto(shaped_type, value.mutable_shape());

  auto result = values->insert({string(name), value});
  if (!result.second) {
    // This should be extremely rare, as it means the same attribute is being
    // added multiple times, i.e. there is some redundancy in how this
    // attribute is represented.
    TensorShapeProto actual_shape = result.first->second.shape();
    // Just compare the string output, as we shouldn't get here; if we do, the
    // shapes should be trivially the same, else fail.
    std::string new_shape_string = value.shape().ShortDebugString();
    if (actual_shape.ShortDebugString() != new_shape_string) {
      return errors::InvalidArgument("Expected ", new_shape_string, " '", name,
                                     "' attribute but found ",
                                     actual_shape.ShortDebugString());
    }
  }
  return Status::OK();
}

bool IsLegacyCallInstruction(mlir::Operation* inst) {
  return llvm::isa<mlir::TF::LegacyCallOp>(inst);
}

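// Registers an additional op prefix to be recognized by GetTensorFlowOpName.
// For example (with a hypothetical prefix), after
// AddTensorFlowOpPrefix("tf_custom."), GetTensorFlowOpName("tf_custom.Foo")
// would return "Foo".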
Status AddTensorFlowOpPrefix(std::string prefix) {
  GlobalOpPrefixes()->insert(prefix);
  return Status::OK();
}

}  // namespace tensorflow