• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
17 
18 #include <string>
19 
20 #include "absl/strings/string_view.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallString.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "mlir/IR/Location.h"  // from @llvm-project
29 #include "mlir/IR/Operation.h"  // from @llvm-project
30 #include "mlir/IR/Value.h"  // from @llvm-project
31 #include "tensorflow/compiler/mlir/utils/name_utils.h"
32 
StringRefToView(llvm::StringRef ref)33 static inline absl::string_view StringRefToView(llvm::StringRef ref) {
34   return absl::string_view(ref.data(), ref.size());
35 }
36 
StringViewToRef(absl::string_view view)37 static inline llvm::StringRef StringViewToRef(absl::string_view view) {
38   return llvm::StringRef(view.data(), view.size());
39 }
40 
41 namespace tensorflow {
42 
~OpOrArgNameMapper()43 OpOrArgNameMapper::~OpOrArgNameMapper() {}
44 
GetUniqueName(llvm::StringRef prefix)45 llvm::StringRef OpOrArgNameMapper::GetUniqueName(llvm::StringRef prefix) {
46   // Insert/find if prefix is unique.
47   auto prefix_it = name_to_count_.try_emplace(prefix, 0);
48   if (prefix_it.second && IsUnique(prefix)) {
49     // Name is unique, increment count and return string name backed by
50     // `name_to_count_`.
51     ++prefix_it.first->second;
52     return prefix_it.first->first();
53   }
54 
55   // Add increasing number (count) to end of prefix until it is determined
56   // to be unique.
57   auto& val = prefix_it.first->second;
58   llvm::SmallString<64> probe_name(prefix);
59   probe_name.append(GetSuffixSeparator());
60   const int probe_prefix_size = probe_name.size();
61   while (true) {
62     probe_name.resize(probe_prefix_size);
63     // TODO(jpienaar): Subtract one so that the initial suffix is 0 instead
64     // of 1.
65     // TODO(jpienaar): Switch to radix 36 and update tests.
66     llvm::APInt(32, val++).toString(probe_name, /*Radix=*/10, /*Signed=*/false);
67     if (IsUnique(probe_name)) {
68       // Insert/find if prefix with appended number is unique.
69       auto probe_name_it = name_to_count_.try_emplace(probe_name, 1);
70       if (probe_name_it.second) {
71         // Name is unique, return string name backed by `name_to_count_`.
72         return probe_name_it.first->first();
73       }
74     }
75   }
76 }
77 
GetUniqueName(OpOrVal op_or_val)78 llvm::StringRef OpOrArgNameMapper::GetUniqueName(OpOrVal op_or_val) {
79   auto& name = op_or_val_to_name_[op_or_val];
80   if (!name.empty()) return StringViewToRef(name);
81   // Update the value in the map with unique name.
82   llvm::StringRef ref = GetUniqueName(GetName(op_or_val));
83   name = StringRefToView(ref);
84   return ref;
85 }
86 
GetUniqueNameView(OpOrVal op_or_val)87 absl::string_view OpOrArgNameMapper::GetUniqueNameView(OpOrVal op_or_val) {
88   auto& name = op_or_val_to_name_[op_or_val];
89   if (!name.empty()) return name;
90   // Update the value in the map with unique name.
91   name = StringRefToView(GetUniqueName(GetName(op_or_val)));
92   return name;
93 }
94 
InitOpName(OpOrVal op_or_val,llvm::StringRef name)95 int OpOrArgNameMapper::InitOpName(OpOrVal op_or_val, llvm::StringRef name) {
96   auto it = name_to_count_.try_emplace(name, 0);
97   auto inserted = op_or_val_to_name_.try_emplace(
98       op_or_val, StringRefToView(it.first->first()));
99   (void)inserted;
100   // TODO(jpienaar): Debug cases where we expect this behavior.
101   // assert(inserted.second && "op_or_val already initialized");
102   return it.first->second++;
103 }
104 
IsUnique(llvm::StringRef name)105 bool OpOrArgNameMapper::IsUnique(llvm::StringRef name) { return true; }
106 
GetName(OpOrVal op_or_val)107 std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
108   if (auto* op = op_or_val.dyn_cast<mlir::Operation*>()) {
109     auto name_from_loc = mlir::GetNameFromLoc(op->getLoc());
110     if (!name_from_loc.empty()) return name_from_loc;
111     // If the location is none of the expected types, then simply use name
112     // generated using the op type.
113     return std::string(op->getName().getStringRef());
114   }
115   auto val = op_or_val.dyn_cast<mlir::Value>();
116   auto name_from_loc = mlir::GetNameFromLoc(val.getLoc());
117   if (!name_from_loc.empty()) return name_from_loc;
118   // If the location is none of the expected types, then simply use name
119   // generated using the op type. Follow TF convention and append the result
120   // index unless 0.
121   if (auto result = val.dyn_cast<mlir::OpResult>()) {
122     if (result.getResultNumber() > 0)
123       return llvm::formatv("{0}:{1}",
124                            result.getOwner()->getName().getStringRef(),
125                            result.getResultNumber());
126     return std::string(result.getOwner()->getName().getStringRef());
127   }
128   // Use the ASM syntax for BlockArgument
129   if (auto arg = val.dyn_cast<mlir::BlockArgument>()) {
130     return "arg" + std::to_string(arg.getArgNumber());
131   }
132   return "";
133 }
134 
GetName(OpOrVal op_or_val)135 std::string OpOrArgStripNameMapper::GetName(OpOrVal op_or_val) {
136   return llvm::APInt(32, count_++).toString(/*Radix=*/36, /*Signed=*/false);
137 }
138 
139 }  // namespace tensorflow
140