1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/utils/name_utils.h"
17
18 #include <cctype>
19
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "mlir/IR/Identifier.h" // from @llvm-project
24
25 namespace mlir {
26
27 namespace {
// Checks whether `c` is a legal character for a TensorFlow node name.
//
// ASCII letters, ASCII digits, '.' and '_' are accepted at any position.
// '/' and '-' are accepted only when `first_char` is false, i.e. when `c` is
// not the first character of the name. Everything else is rejected.
//
// The classification is done with explicit ASCII range checks rather than
// <cctype>'s isalpha/isdigit: those have undefined behavior when handed a
// negative `char` (any byte >= 0x80 where `char` is signed) and their result
// depends on the current locale.
bool IsLegalChar(char c, bool first_char) {
  const bool is_ascii_letter =
      (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
  const bool is_ascii_digit = c >= '0' && c <= '9';
  if (is_ascii_letter || is_ascii_digit || c == '.' || c == '_') return true;

  // First character of a node name can only be a letter, digit, dot or
  // underscore.
  if (first_char) return false;

  return c == '/' || c == '-';
}
45 } // anonymous namespace
46
LegalizeNodeName(std::string & name)47 void LegalizeNodeName(std::string& name) {
48 if (name.empty()) return;
49
50 if (!IsLegalChar(name[0], /*first_char=*/true)) name[0] = '.';
51
52 for (char& c : llvm::drop_begin(name, 1))
53 if (!IsLegalChar(c, /*first_char=*/false)) c = '.';
54 }
55
GetNameFromLoc(Location loc)56 std::string GetNameFromLoc(Location loc) {
57 llvm::SmallVector<llvm::StringRef, 8> loc_names;
58 llvm::SmallVector<Location, 8> locs;
59 locs.push_back(loc);
60 bool names_is_nonempty = false;
61
62 while (!locs.empty()) {
63 Location curr_loc = locs.pop_back_val();
64
65 if (auto name_loc = curr_loc.dyn_cast<NameLoc>()) {
66 // Add name in NameLoc. For NameLoc we also account for names due to ops
67 // in functions where the op's name is first.
68 auto name = name_loc.getName().strref().split('@').first;
69 loc_names.push_back(name);
70 if (!name.empty()) names_is_nonempty = true;
71 continue;
72 } else if (auto call_loc = curr_loc.dyn_cast<CallSiteLoc>()) {
73 // Use location of the Callee to generate the name.
74 locs.push_back(call_loc.getCallee());
75 continue;
76 } else if (auto fused_loc = curr_loc.dyn_cast<FusedLoc>()) {
77 // Push all locations in FusedLoc in reverse order, so locations are
78 // visited based on order in FusedLoc.
79 auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations());
80 locs.append(reversed_fused_locs.begin(), reversed_fused_locs.end());
81 continue;
82 }
83
84 // Location is not a supported, so an empty StringRef is added.
85 loc_names.push_back(llvm::StringRef());
86 }
87
88 if (names_is_nonempty)
89 return llvm::join(loc_names.begin(), loc_names.end(), ";");
90
91 return "";
92 }
93
94 } // namespace mlir
95