• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_lite_support/codegen/code_generator.h"
17 
18 #include <cctype>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include <utility>
22 
23 #include "tensorflow_lite_support/codegen/utils.h"
24 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
25 
26 namespace tflite {
27 namespace support {
28 namespace codegen {
29 
30 namespace {
31 
// Makes every name in `*names_ptr` unique by appending a 1-based occurrence
// index to each name that appears more than once, in order of appearance:
//   {"a", "b", "a"} -> {"a1", "b", "a2"}
// Names that occur exactly once are left untouched.
//
// If a suffixed candidate would clash with another name in the list (e.g.
// {"a", "a", "a2"}), that index is skipped and the next one is tried, so the
// result never contains duplicates:
//   {"a", "a", "a2"} -> {"a1", "a3", "a2"}
void ResolveConflictedNamesByAddingIndex(std::vector<std::string>* names_ptr) {
  auto& names = *names_ptr;

  // How often each name occurs; only names occurring more than once get a
  // suffix.
  std::unordered_map<std::string, int> counts;
  for (const auto& name : names) {
    ++counts[name];
  }

  // Every name that is (or will be) present in the list and therefore must
  // not be produced again as a suffixed candidate.
  std::unordered_set<std::string> taken(names.begin(), names.end());

  // Last index handed out per duplicated base name (0 = none yet).
  std::unordered_map<std::string, int> last_index;
  for (auto& name : names) {
    if (counts.find(name)->second <= 1) {
      continue;
    }
    int& index = last_index[name];
    std::string candidate;
    do {
      candidate = name + std::to_string(++index);
    } while (taken.count(candidate) != 0);
    taken.insert(candidate);
    name = std::move(candidate);
  }
}
53 
54 }  // namespace
55 
CodeGenerator()56 CodeGenerator::CodeGenerator() {}
57 
VerifyMetadata(const ModelMetadata * metadata,ErrorReporter * err)58 bool CodeGenerator::VerifyMetadata(const ModelMetadata* metadata,
59                                    ErrorReporter* err) {
60   if (metadata == nullptr) {
61     err->Error("Loading nullptr is not allowed");
62     return false;
63   }
64   if (metadata->subgraph_metadata()->size() != 1) {
65     err->Error("Only exact 1 subgraph is supported");
66     return false;
67   }
68   return true;
69 }
70 
71 std::pair<std::vector<std::string>, std::vector<std::string>>
NameInputsAndOutputs(const TensorMetadataList * inputs,const TensorMetadataList * outputs)72 CodeGenerator::NameInputsAndOutputs(const TensorMetadataList* inputs,
73                                     const TensorMetadataList* outputs) {
74   std::vector<std::string> input_names;
75   std::vector<std::string> output_names;
76   if (inputs != nullptr) {
77     input_names.reserve(inputs->size());
78     for (const auto* tensor : *inputs) {
79       input_names.push_back(NameTensor(*tensor, "input"));
80     }
81   }
82   if (outputs != nullptr) {
83     output_names.reserve(outputs->size());
84     for (const auto* tensor : *outputs) {
85       output_names.push_back(NameTensor(*tensor, "output"));
86     }
87   }
88   // Solve conflict
89   ResolveConflictedInputAndOutputNames(&input_names, &output_names);
90   return std::make_pair(input_names, output_names);
91 }
92 
ConvertToValidName(const std::string & name)93 std::string CodeGenerator::ConvertToValidName(const std::string& name) {
94   // lowercase all
95   std::string result = name;
96   for (int i = 0; i < result.size(); i++) {
97     result[i] = std::tolower(result[i]);
98   }
99   // replace all non-alpha or non-numeric with underscores, except underscore
100   // itself
101   for (int i = 0; i < result.size(); i++) {
102     if (result[i] != '_' && !std::isalnum(result[i])) {
103       result[i] = '_';
104     }
105   }
106   // remove leading underscores
107   int leading_underscores = 0;
108   while (leading_underscores < result.size() &&
109          result[leading_underscores] == '_') {
110     leading_underscores++;
111   }
112   result.erase(0, leading_underscores);
113   if (result.empty()) {
114     return "";
115   }
116   // first char should be alpha
117   if (std::isalpha(result[0])) {
118     return result;
119   }
120   return "tensor_" + result;
121 }
122 
NameTensor(const TensorMetadata & tensor,const std::string & default_name)123 std::string CodeGenerator::NameTensor(const TensorMetadata& tensor,
124                                       const std::string& default_name) {
125   if (tensor.name() != nullptr && tensor.name()->size() > 0) {
126     // TODO(b/141225157) Validate tensor name. It should be in lower case.
127     auto suggested_name = ConvertToValidName(tensor.name()->str());
128     if (!suggested_name.empty()) {
129       return suggested_name;
130     }
131   }
132   auto* content = tensor.content();
133   if (content == nullptr || content->content_properties() == nullptr) {
134     return default_name;
135   }
136   switch (content->content_properties_type()) {
137     case ContentProperties_ImageProperties:
138       return "image";
139     case ContentProperties_FeatureProperties:
140       return "feature";
141     default:
142       return default_name;
143   }
144 }
145 
ResolveConflictedInputAndOutputNames(std::vector<std::string> * inputs,std::vector<std::string> * outputs)146 void CodeGenerator::ResolveConflictedInputAndOutputNames(
147     std::vector<std::string>* inputs, std::vector<std::string>* outputs) {
148   std::unordered_set<std::string> io_conflict;
149   auto& input_names = *inputs;
150   auto& output_names = *outputs;
151   for (const auto& input : input_names) {
152     if (io_conflict.find(input) != io_conflict.end()) {
153       continue;
154     }
155     for (const auto& output : output_names) {
156       if (input == output) {
157         io_conflict.insert(input);
158         break;
159       }
160     }
161   }
162   for (int i = 0; i < input_names.size(); i++) {
163     if (io_conflict.find(input_names[i]) != io_conflict.end()) {
164       input_names[i] = "input_" + input_names[i];
165     }
166   }
167   for (int i = 0; i < output_names.size(); i++) {
168     if (io_conflict.find(output_names[i]) != io_conflict.end()) {
169       output_names[i] = "output_" + output_names[i];
170     }
171   }
172   // 2. Second, add index if input[i] == input[j]
173   ResolveConflictedNamesByAddingIndex(&input_names);
174   ResolveConflictedNamesByAddingIndex(&output_names);
175 }
176 
177 }  // namespace codegen
178 }  // namespace support
179 }  // namespace tflite
180