1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
17
18 #include <string>
19
20 #include "absl/strings/string_view.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallString.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "mlir/IR/Location.h" // TF:llvm-project
29 #include "mlir/IR/Operation.h" // TF:llvm-project
30 #include "mlir/IR/Value.h" // TF:llvm-project
31
StringRefToView(llvm::StringRef ref)32 static inline absl::string_view StringRefToView(llvm::StringRef ref) {
33 return absl::string_view(ref.data(), ref.size());
34 }
35
StringViewToRef(absl::string_view view)36 static inline llvm::StringRef StringViewToRef(absl::string_view view) {
37 return llvm::StringRef(view.data(), view.size());
38 }
39
40 namespace tensorflow {
41
~OpOrArgNameMapper()42 OpOrArgNameMapper::~OpOrArgNameMapper() {}
43
GetUniqueName(llvm::StringRef prefix)44 llvm::StringRef OpOrArgNameMapper::GetUniqueName(llvm::StringRef prefix) {
45 // Insert/find if prefix is unique.
46 auto prefix_it = name_to_count_.try_emplace(prefix, 0);
47 if (prefix_it.second && IsUnique(prefix)) {
48 // Name is unique, increment count and return string name backed by
49 // `name_to_count_`.
50 ++prefix_it.first->second;
51 return prefix_it.first->first();
52 }
53
54 // Add increasing number (count) to end of prefix until it is determined
55 // to be unique.
56 auto& val = prefix_it.first->second;
57 llvm::SmallString<64> probe_name(prefix);
58 while (true) {
59 probe_name.resize(prefix.size());
60 // TODO(jpienaar): Subtract one so that the initial suffix is 0 instead
61 // of 1.
62 // TODO(jpienaar): Switch to radix 36 and update tests.
63 llvm::APInt(32, val++).toString(probe_name, /*Radix=*/10, /*Signed=*/false);
64 if (IsUnique(probe_name)) {
65 // Insert/find if prefix with appended number is unique.
66 auto probe_name_it = name_to_count_.try_emplace(probe_name, 1);
67 if (probe_name_it.second) {
68 // Name is unique, return string name backed by `name_to_count_`.
69 return probe_name_it.first->first();
70 }
71 }
72 }
73 }
74
GetUniqueName(OpOrVal op_or_val)75 llvm::StringRef OpOrArgNameMapper::GetUniqueName(OpOrVal op_or_val) {
76 auto& name = op_or_val_to_name_[op_or_val];
77 if (!name.empty()) return StringViewToRef(name);
78 // Update the value in the map with unique name.
79 llvm::StringRef ref = GetUniqueName(GetName(op_or_val));
80 name = StringRefToView(ref);
81 return ref;
82 }
83
GetUniqueNameView(OpOrVal op_or_val)84 absl::string_view OpOrArgNameMapper::GetUniqueNameView(OpOrVal op_or_val) {
85 auto& name = op_or_val_to_name_[op_or_val];
86 if (!name.empty()) return name;
87 // Update the value in the map with unique name.
88 name = StringRefToView(GetUniqueName(GetName(op_or_val)));
89 return name;
90 }
91
InitOpName(OpOrVal op_or_val,llvm::StringRef name)92 int OpOrArgNameMapper::InitOpName(OpOrVal op_or_val, llvm::StringRef name) {
93 auto it = name_to_count_.try_emplace(name, 0);
94 auto inserted = op_or_val_to_name_.try_emplace(
95 op_or_val, StringRefToView(it.first->first()));
96 (void)inserted;
97 // TODO(jpienaar): Debug cases where we expect this behavior.
98 // assert(inserted.second && "op_or_val already initialized");
99 return it.first->second++;
100 }
101
IsUnique(llvm::StringRef name)102 bool OpOrArgNameMapper::IsUnique(llvm::StringRef name) { return true; }
103
104 namespace {
105 // Derives name from location.
GetNameFromLoc(mlir::Location loc)106 std::string GetNameFromLoc(mlir::Location loc) {
107 llvm::SmallVector<llvm::StringRef, 8> loc_names;
108 llvm::SmallVector<mlir::Location, 8> locs;
109 locs.push_back(loc);
110 bool names_is_nonempty = false;
111
112 while (!locs.empty()) {
113 mlir::Location curr_loc = locs.pop_back_val();
114
115 if (auto name_loc = curr_loc.dyn_cast<mlir::NameLoc>()) {
116 // Add name in NameLoc. For NameLoc we also account for names due to ops
117 // in functions where the op's name is first.
118 auto name = name_loc.getName().strref().split('@').first;
119 loc_names.push_back(name);
120 if (!name.empty()) names_is_nonempty = true;
121 continue;
122 } else if (auto call_loc = curr_loc.dyn_cast<mlir::CallSiteLoc>()) {
123 // Add name if CallSiteLoc's callee has a NameLoc (as should be the
124 // case if imported with DebugInfo).
125 if (auto name_loc = call_loc.getCallee().dyn_cast<mlir::NameLoc>()) {
126 auto name = name_loc.getName().strref().split('@').first;
127 loc_names.push_back(name);
128 if (!name.empty()) names_is_nonempty = true;
129 continue;
130 }
131 } else if (auto fused_loc = curr_loc.dyn_cast<mlir::FusedLoc>()) {
132 // Push all locations in FusedLoc in reverse order, so locations are
133 // visited based on order in FusedLoc.
134 auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations());
135 locs.append(reversed_fused_locs.begin(), reversed_fused_locs.end());
136 continue;
137 }
138
139 // Location is not a supported, so an empty StringRef is added.
140 loc_names.push_back(llvm::StringRef());
141 }
142
143 if (names_is_nonempty)
144 return llvm::join(loc_names.begin(), loc_names.end(), ";");
145
146 return "";
147 }
148 } // anonymous namespace
149
GetName(OpOrVal op_or_val)150 std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
151 if (auto* op = op_or_val.dyn_cast<mlir::Operation*>()) {
152 auto name_from_loc = GetNameFromLoc(op->getLoc());
153 if (!name_from_loc.empty()) return name_from_loc;
154 // If the location is none of the expected types, then simply use name
155 // generated using the op type.
156 return std::string(op->getName().getStringRef());
157 }
158 auto val = op_or_val.dyn_cast<mlir::Value>();
159 auto name_from_loc = GetNameFromLoc(val.getLoc());
160 if (!name_from_loc.empty()) return name_from_loc;
161 // If the location is none of the expected types, then simply use name
162 // generated using the op type. Follow TF convention and append the result
163 // index unless 0.
164 if (auto result = val.dyn_cast<mlir::OpResult>()) {
165 if (result.getResultNumber() > 0)
166 return llvm::formatv("{0}:{1}",
167 result.getOwner()->getName().getStringRef(),
168 result.getResultNumber());
169 return std::string(result.getOwner()->getName().getStringRef());
170 }
171 return "";
172 }
173
GetName(OpOrVal op_or_val)174 std::string OpOrArgStripNameMapper::GetName(OpOrVal op_or_val) {
175 return llvm::APInt(32, count_++).toString(/*Radix=*/36, /*Signed=*/false);
176 }
177
178 } // namespace tensorflow
179