1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
17
18 #include <string>
19
20 #include "absl/strings/string_view.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallString.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "mlir/IR/Location.h" // from @llvm-project
29 #include "mlir/IR/Operation.h" // from @llvm-project
30 #include "mlir/IR/Value.h" // from @llvm-project
31 #include "tensorflow/compiler/mlir/utils/name_utils.h"
32
StringRefToView(llvm::StringRef ref)33 static inline absl::string_view StringRefToView(llvm::StringRef ref) {
34 return absl::string_view(ref.data(), ref.size());
35 }
36
StringViewToRef(absl::string_view view)37 static inline llvm::StringRef StringViewToRef(absl::string_view view) {
38 return llvm::StringRef(view.data(), view.size());
39 }
40
41 namespace tensorflow {
42
~OpOrArgNameMapper()43 OpOrArgNameMapper::~OpOrArgNameMapper() {}
44
GetUniqueName(llvm::StringRef prefix)45 llvm::StringRef OpOrArgNameMapper::GetUniqueName(llvm::StringRef prefix) {
46 // Insert/find if prefix is unique.
47 auto prefix_it = name_to_count_.try_emplace(prefix, 0);
48 if (prefix_it.second && IsUnique(prefix)) {
49 // Name is unique, increment count and return string name backed by
50 // `name_to_count_`.
51 ++prefix_it.first->second;
52 return prefix_it.first->first();
53 }
54
55 // Add increasing number (count) to end of prefix until it is determined
56 // to be unique.
57 auto& val = prefix_it.first->second;
58 llvm::SmallString<64> probe_name(prefix);
59 probe_name.append(GetSuffixSeparator());
60 const int probe_prefix_size = probe_name.size();
61 while (true) {
62 probe_name.resize(probe_prefix_size);
63 // TODO(jpienaar): Subtract one so that the initial suffix is 0 instead
64 // of 1.
65 // TODO(jpienaar): Switch to radix 36 and update tests.
66 llvm::APInt(32, val++).toString(probe_name, /*Radix=*/10, /*Signed=*/false);
67 if (IsUnique(probe_name)) {
68 // Insert/find if prefix with appended number is unique.
69 auto probe_name_it = name_to_count_.try_emplace(probe_name, 1);
70 if (probe_name_it.second) {
71 // Name is unique, return string name backed by `name_to_count_`.
72 return probe_name_it.first->first();
73 }
74 }
75 }
76 }
77
GetUniqueName(OpOrVal op_or_val)78 llvm::StringRef OpOrArgNameMapper::GetUniqueName(OpOrVal op_or_val) {
79 auto& name = op_or_val_to_name_[op_or_val];
80 if (!name.empty()) return StringViewToRef(name);
81 // Update the value in the map with unique name.
82 llvm::StringRef ref = GetUniqueName(GetName(op_or_val));
83 name = StringRefToView(ref);
84 return ref;
85 }
86
GetUniqueNameView(OpOrVal op_or_val)87 absl::string_view OpOrArgNameMapper::GetUniqueNameView(OpOrVal op_or_val) {
88 auto& name = op_or_val_to_name_[op_or_val];
89 if (!name.empty()) return name;
90 // Update the value in the map with unique name.
91 name = StringRefToView(GetUniqueName(GetName(op_or_val)));
92 return name;
93 }
94
InitOpName(OpOrVal op_or_val,llvm::StringRef name)95 int OpOrArgNameMapper::InitOpName(OpOrVal op_or_val, llvm::StringRef name) {
96 auto it = name_to_count_.try_emplace(name, 0);
97 auto inserted = op_or_val_to_name_.try_emplace(
98 op_or_val, StringRefToView(it.first->first()));
99 (void)inserted;
100 // TODO(jpienaar): Debug cases where we expect this behavior.
101 // assert(inserted.second && "op_or_val already initialized");
102 return it.first->second++;
103 }
104
IsUnique(llvm::StringRef name)105 bool OpOrArgNameMapper::IsUnique(llvm::StringRef name) { return true; }
106
GetName(OpOrVal op_or_val)107 std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
108 if (auto* op = op_or_val.dyn_cast<mlir::Operation*>()) {
109 auto name_from_loc = mlir::GetNameFromLoc(op->getLoc());
110 if (!name_from_loc.empty()) return name_from_loc;
111 // If the location is none of the expected types, then simply use name
112 // generated using the op type.
113 return std::string(op->getName().getStringRef());
114 }
115 auto val = op_or_val.dyn_cast<mlir::Value>();
116 auto name_from_loc = mlir::GetNameFromLoc(val.getLoc());
117 if (!name_from_loc.empty()) return name_from_loc;
118 // If the location is none of the expected types, then simply use name
119 // generated using the op type. Follow TF convention and append the result
120 // index unless 0.
121 if (auto result = val.dyn_cast<mlir::OpResult>()) {
122 if (result.getResultNumber() > 0)
123 return llvm::formatv("{0}:{1}",
124 result.getOwner()->getName().getStringRef(),
125 result.getResultNumber());
126 return std::string(result.getOwner()->getName().getStringRef());
127 }
128 // Use the ASM syntax for BlockArgument
129 if (auto arg = val.dyn_cast<mlir::BlockArgument>()) {
130 return "arg" + std::to_string(arg.getArgNumber());
131 }
132 return "";
133 }
134
GetName(OpOrVal op_or_val)135 std::string OpOrArgStripNameMapper::GetName(OpOrVal op_or_val) {
136 return llvm::APInt(32, count_++).toString(/*Radix=*/36, /*Signed=*/false);
137 }
138
139 } // namespace tensorflow
140