• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
17 
18 #include <string>
19 
20 #include "absl/strings/string_view.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallString.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "mlir/IR/Location.h"  // TF:llvm-project
29 #include "mlir/IR/Operation.h"  // TF:llvm-project
30 #include "mlir/IR/Value.h"  // TF:llvm-project
31 
StringRefToView(llvm::StringRef ref)32 static inline absl::string_view StringRefToView(llvm::StringRef ref) {
33   return absl::string_view(ref.data(), ref.size());
34 }
35 
StringViewToRef(absl::string_view view)36 static inline llvm::StringRef StringViewToRef(absl::string_view view) {
37   return llvm::StringRef(view.data(), view.size());
38 }
39 
40 namespace tensorflow {
41 
~OpOrArgNameMapper()42 OpOrArgNameMapper::~OpOrArgNameMapper() {}
43 
GetUniqueName(llvm::StringRef prefix)44 llvm::StringRef OpOrArgNameMapper::GetUniqueName(llvm::StringRef prefix) {
45   // Insert/find if prefix is unique.
46   auto prefix_it = name_to_count_.try_emplace(prefix, 0);
47   if (prefix_it.second && IsUnique(prefix)) {
48     // Name is unique, increment count and return string name backed by
49     // `name_to_count_`.
50     ++prefix_it.first->second;
51     return prefix_it.first->first();
52   }
53 
54   // Add increasing number (count) to end of prefix until it is determined
55   // to be unique.
56   auto& val = prefix_it.first->second;
57   llvm::SmallString<64> probe_name(prefix);
58   while (true) {
59     probe_name.resize(prefix.size());
60     // TODO(jpienaar): Subtract one so that the initial suffix is 0 instead
61     // of 1.
62     // TODO(jpienaar): Switch to radix 36 and update tests.
63     llvm::APInt(32, val++).toString(probe_name, /*Radix=*/10, /*Signed=*/false);
64     if (IsUnique(probe_name)) {
65       // Insert/find if prefix with appended number is unique.
66       auto probe_name_it = name_to_count_.try_emplace(probe_name, 1);
67       if (probe_name_it.second) {
68         // Name is unique, return string name backed by `name_to_count_`.
69         return probe_name_it.first->first();
70       }
71     }
72   }
73 }
74 
GetUniqueName(OpOrVal op_or_val)75 llvm::StringRef OpOrArgNameMapper::GetUniqueName(OpOrVal op_or_val) {
76   auto& name = op_or_val_to_name_[op_or_val];
77   if (!name.empty()) return StringViewToRef(name);
78   // Update the value in the map with unique name.
79   llvm::StringRef ref = GetUniqueName(GetName(op_or_val));
80   name = StringRefToView(ref);
81   return ref;
82 }
83 
GetUniqueNameView(OpOrVal op_or_val)84 absl::string_view OpOrArgNameMapper::GetUniqueNameView(OpOrVal op_or_val) {
85   auto& name = op_or_val_to_name_[op_or_val];
86   if (!name.empty()) return name;
87   // Update the value in the map with unique name.
88   name = StringRefToView(GetUniqueName(GetName(op_or_val)));
89   return name;
90 }
91 
InitOpName(OpOrVal op_or_val,llvm::StringRef name)92 int OpOrArgNameMapper::InitOpName(OpOrVal op_or_val, llvm::StringRef name) {
93   auto it = name_to_count_.try_emplace(name, 0);
94   auto inserted = op_or_val_to_name_.try_emplace(
95       op_or_val, StringRefToView(it.first->first()));
96   (void)inserted;
97   // TODO(jpienaar): Debug cases where we expect this behavior.
98   // assert(inserted.second && "op_or_val already initialized");
99   return it.first->second++;
100 }
101 
IsUnique(llvm::StringRef name)102 bool OpOrArgNameMapper::IsUnique(llvm::StringRef name) { return true; }
103 
104 namespace {
105 // Derives name from location.
GetNameFromLoc(mlir::Location loc)106 std::string GetNameFromLoc(mlir::Location loc) {
107   llvm::SmallVector<llvm::StringRef, 8> loc_names;
108   llvm::SmallVector<mlir::Location, 8> locs;
109   locs.push_back(loc);
110   bool names_is_nonempty = false;
111 
112   while (!locs.empty()) {
113     mlir::Location curr_loc = locs.pop_back_val();
114 
115     if (auto name_loc = curr_loc.dyn_cast<mlir::NameLoc>()) {
116       // Add name in NameLoc. For NameLoc we also account for names due to ops
117       // in functions where the op's name is first.
118       auto name = name_loc.getName().strref().split('@').first;
119       loc_names.push_back(name);
120       if (!name.empty()) names_is_nonempty = true;
121       continue;
122     } else if (auto call_loc = curr_loc.dyn_cast<mlir::CallSiteLoc>()) {
123       // Add name if CallSiteLoc's callee has a NameLoc (as should be the
124       // case if imported with DebugInfo).
125       if (auto name_loc = call_loc.getCallee().dyn_cast<mlir::NameLoc>()) {
126         auto name = name_loc.getName().strref().split('@').first;
127         loc_names.push_back(name);
128         if (!name.empty()) names_is_nonempty = true;
129         continue;
130       }
131     } else if (auto fused_loc = curr_loc.dyn_cast<mlir::FusedLoc>()) {
132       // Push all locations in FusedLoc in reverse order, so locations are
133       // visited based on order in FusedLoc.
134       auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations());
135       locs.append(reversed_fused_locs.begin(), reversed_fused_locs.end());
136       continue;
137     }
138 
139     // Location is not a supported, so an empty StringRef is added.
140     loc_names.push_back(llvm::StringRef());
141   }
142 
143   if (names_is_nonempty)
144     return llvm::join(loc_names.begin(), loc_names.end(), ";");
145 
146   return "";
147 }
148 }  // anonymous namespace
149 
GetName(OpOrVal op_or_val)150 std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
151   if (auto* op = op_or_val.dyn_cast<mlir::Operation*>()) {
152     auto name_from_loc = GetNameFromLoc(op->getLoc());
153     if (!name_from_loc.empty()) return name_from_loc;
154     // If the location is none of the expected types, then simply use name
155     // generated using the op type.
156     return std::string(op->getName().getStringRef());
157   }
158   auto val = op_or_val.dyn_cast<mlir::Value>();
159   auto name_from_loc = GetNameFromLoc(val.getLoc());
160   if (!name_from_loc.empty()) return name_from_loc;
161   // If the location is none of the expected types, then simply use name
162   // generated using the op type. Follow TF convention and append the result
163   // index unless 0.
164   if (auto result = val.dyn_cast<mlir::OpResult>()) {
165     if (result.getResultNumber() > 0)
166       return llvm::formatv("{0}:{1}",
167                            result.getOwner()->getName().getStringRef(),
168                            result.getResultNumber());
169     return std::string(result.getOwner()->getName().getStringRef());
170   }
171   return "";
172 }
173 
GetName(OpOrVal op_or_val)174 std::string OpOrArgStripNameMapper::GetName(OpOrVal op_or_val) {
175   return llvm::APInt(32, count_++).toString(/*Radix=*/36, /*Signed=*/false);
176 }
177 
178 }  // namespace tensorflow
179