1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow_lite_support/codegen/code_generator.h"
17
18 #include <cctype>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include <utility>
22
23 #include "tensorflow_lite_support/codegen/utils.h"
24 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
25
26 namespace tflite {
27 namespace support {
28 namespace codegen {
29
30 namespace {
31
// Makes every name in `*names_ptr` distinct by appending a 1-based occurrence
// index to each name that appears more than once. Names that occur only once
// are left untouched.
//
//   {"a", "b", "a"}  ->  {"a1", "b", "a2"}
//
// A generated name must not clash with a name already in the vector, or with
// a name generated earlier. In that case the index is increased until it is
// free:
//
//   {"a", "a", "a1"} ->  {"a2", "a3", "a1"}
//
// Args:
//   names_ptr: vector of names, rewritten in place. Must not be null.
void ResolveConflictedNamesByAddingIndex(std::vector<std::string>* names_ptr) {
  auto& names = *names_ptr;

  // Count occurrences to learn which names need an index suffix.
  std::unordered_map<std::string, int> occurrences;
  for (const auto& name : names) {
    ++occurrences[name];
  }

  // Names that stay unchanged are reserved up front, so that no generated
  // name can collide with them.
  std::unordered_set<std::string> taken;
  for (const auto& entry : occurrences) {
    if (entry.second == 1) {
      taken.insert(entry.first);
    }
  }

  // Last index handed out for each duplicated base name (0 = none yet).
  std::unordered_map<std::string, int> last_index;
  for (auto& name : names) {
    if (occurrences[name] == 1) {
      continue;
    }
    int& index = last_index[name];
    std::string candidate;
    // insert() reports false when the candidate is already in use; in that
    // case try the next index.
    do {
      ++index;
      candidate = name + std::to_string(index);
    } while (!taken.insert(candidate).second);
    name = std::move(candidate);
  }
}
53
54 } // namespace
55
CodeGenerator()56 CodeGenerator::CodeGenerator() {}
57
VerifyMetadata(const ModelMetadata * metadata,ErrorReporter * err)58 bool CodeGenerator::VerifyMetadata(const ModelMetadata* metadata,
59 ErrorReporter* err) {
60 if (metadata == nullptr) {
61 err->Error("Loading nullptr is not allowed");
62 return false;
63 }
64 if (metadata->subgraph_metadata()->size() != 1) {
65 err->Error("Only exact 1 subgraph is supported");
66 return false;
67 }
68 return true;
69 }
70
71 std::pair<std::vector<std::string>, std::vector<std::string>>
NameInputsAndOutputs(const TensorMetadataList * inputs,const TensorMetadataList * outputs)72 CodeGenerator::NameInputsAndOutputs(const TensorMetadataList* inputs,
73 const TensorMetadataList* outputs) {
74 std::vector<std::string> input_names;
75 std::vector<std::string> output_names;
76 if (inputs != nullptr) {
77 input_names.reserve(inputs->size());
78 for (const auto* tensor : *inputs) {
79 input_names.push_back(NameTensor(*tensor, "input"));
80 }
81 }
82 if (outputs != nullptr) {
83 output_names.reserve(outputs->size());
84 for (const auto* tensor : *outputs) {
85 output_names.push_back(NameTensor(*tensor, "output"));
86 }
87 }
88 // Solve conflict
89 ResolveConflictedInputAndOutputNames(&input_names, &output_names);
90 return std::make_pair(input_names, output_names);
91 }
92
ConvertToValidName(const std::string & name)93 std::string CodeGenerator::ConvertToValidName(const std::string& name) {
94 // lowercase all
95 std::string result = name;
96 for (int i = 0; i < result.size(); i++) {
97 result[i] = std::tolower(result[i]);
98 }
99 // replace all non-alpha or non-numeric with underscores, except underscore
100 // itself
101 for (int i = 0; i < result.size(); i++) {
102 if (result[i] != '_' && !std::isalnum(result[i])) {
103 result[i] = '_';
104 }
105 }
106 // remove leading underscores
107 int leading_underscores = 0;
108 while (leading_underscores < result.size() &&
109 result[leading_underscores] == '_') {
110 leading_underscores++;
111 }
112 result.erase(0, leading_underscores);
113 if (result.empty()) {
114 return "";
115 }
116 // first char should be alpha
117 if (std::isalpha(result[0])) {
118 return result;
119 }
120 return "tensor_" + result;
121 }
122
NameTensor(const TensorMetadata & tensor,const std::string & default_name)123 std::string CodeGenerator::NameTensor(const TensorMetadata& tensor,
124 const std::string& default_name) {
125 if (tensor.name() != nullptr && tensor.name()->size() > 0) {
126 // TODO(b/141225157) Validate tensor name. It should be in lower case.
127 auto suggested_name = ConvertToValidName(tensor.name()->str());
128 if (!suggested_name.empty()) {
129 return suggested_name;
130 }
131 }
132 auto* content = tensor.content();
133 if (content == nullptr || content->content_properties() == nullptr) {
134 return default_name;
135 }
136 switch (content->content_properties_type()) {
137 case ContentProperties_ImageProperties:
138 return "image";
139 case ContentProperties_FeatureProperties:
140 return "feature";
141 default:
142 return default_name;
143 }
144 }
145
ResolveConflictedInputAndOutputNames(std::vector<std::string> * inputs,std::vector<std::string> * outputs)146 void CodeGenerator::ResolveConflictedInputAndOutputNames(
147 std::vector<std::string>* inputs, std::vector<std::string>* outputs) {
148 std::unordered_set<std::string> io_conflict;
149 auto& input_names = *inputs;
150 auto& output_names = *outputs;
151 for (const auto& input : input_names) {
152 if (io_conflict.find(input) != io_conflict.end()) {
153 continue;
154 }
155 for (const auto& output : output_names) {
156 if (input == output) {
157 io_conflict.insert(input);
158 break;
159 }
160 }
161 }
162 for (int i = 0; i < input_names.size(); i++) {
163 if (io_conflict.find(input_names[i]) != io_conflict.end()) {
164 input_names[i] = "input_" + input_names[i];
165 }
166 }
167 for (int i = 0; i < output_names.size(); i++) {
168 if (io_conflict.find(output_names[i]) != io_conflict.end()) {
169 output_names[i] = "output_" + output_names[i];
170 }
171 }
172 // 2. Second, add index if input[i] == input[j]
173 ResolveConflictedNamesByAddingIndex(&input_names);
174 ResolveConflictedNamesByAddingIndex(&output_names);
175 }
176
177 } // namespace codegen
178 } // namespace support
179 } // namespace tflite
180