1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/task/arguments.h"
17
18 #include <algorithm>
19 #include <string>
20 #include <utility>
21
22 #include "absl/strings/ascii.h"
23 #include "absl/strings/match.h"
24 #include "absl/strings/str_cat.h"
25 #include "tensorflow/lite/delegates/gpu/common/status.h"
26
27 namespace tflite {
28 namespace gpu {
29 namespace {
IsWordSymbol(char symbol)30 bool IsWordSymbol(char symbol) {
31 return absl::ascii_isalnum(symbol) || symbol == '_';
32 }
33
GetNextWord(const std::string & code,size_t first_position)34 std::string GetNextWord(const std::string& code, size_t first_position) {
35 size_t pos = first_position;
36 char t = code[pos];
37 while (IsWordSymbol(t)) {
38 pos++;
39 t = code[pos];
40 }
41 return code.substr(first_position, pos - first_position);
42 }
43
HasWord(const std::string & word,const std::string & text)44 bool HasWord(const std::string& word, const std::string& text) {
45 size_t pos = text.find(word);
46 while (pos != std::string::npos) {
47 char prev = pos == 0 ? '.' : text[pos - 1];
48 char next = pos + word.size() < text.size() ? text[pos + word.size()] : '.';
49 if (!IsWordSymbol(prev) & !IsWordSymbol(next)) {
50 return true;
51 }
52 pos = text.find(word, pos + 1);
53 }
54 return false;
55 }
56
// Computes the merged name of the scalar argument `arg_name`.
//
// If `arg_name` has the form "<object>_<rest>" for some `<object>` in
// `object_names`, `postfix` is inserted right after the object part, giving
// "<object><postfix>_<rest>". The result therefore starts with the object's
// merged name (object + postfix). Otherwise `postfix` is appended:
// "<arg_name><postfix>".
//
// When several object names qualify (e.g. "src" and "src_tensor" for
// "src_tensor_width"), the longest one is used. This makes the result
// independent of the order of `object_names`.
std::string RenameArg(const std::vector<std::string>& object_names,
                      const std::string& postfix, const std::string& arg_name) {
  const std::string* best_match = nullptr;
  for (const auto& object_name : object_names) {
    const size_t prefix_len = object_name.size();
    const bool is_object_prefixed =
        arg_name.size() > prefix_len && arg_name[prefix_len] == '_' &&
        arg_name.compare(0, prefix_len, object_name) == 0;
    if (is_object_prefixed &&
        (best_match == nullptr || prefix_len > best_match->size())) {
      best_match = &object_name;
    }
  }
  if (best_match == nullptr) {
    return arg_name + postfix;
  }
  // substr(pos) keeps everything from the '_' separator onwards.
  return *best_match + postfix + arg_name.substr(best_match->size());
}
70
71 } // namespace
72
AddFloat(const std::string & name,float value)73 void Arguments::AddFloat(const std::string& name, float value) {
74 float_values_[name].value = value;
75 }
AddHalf(const std::string & name,half value)76 void Arguments::AddHalf(const std::string& name, half value) {
77 half_values_[name].value = value;
78 }
AddInt(const std::string & name,int value)79 void Arguments::AddInt(const std::string& name, int value) {
80 int_values_[name].value = value;
81 }
82
AddObjectRef(const std::string & name,AccessType access_type,GPUObjectDescriptorPtr && descriptor_ptr)83 void Arguments::AddObjectRef(const std::string& name, AccessType access_type,
84 GPUObjectDescriptorPtr&& descriptor_ptr) {
85 descriptor_ptr->SetAccess(access_type);
86 object_refs_[name] = {std::move(descriptor_ptr)};
87 }
88
AddObject(const std::string & name,GPUObjectDescriptorPtr && descriptor_ptr)89 void Arguments::AddObject(const std::string& name,
90 GPUObjectDescriptorPtr&& descriptor_ptr) {
91 descriptor_ptr->SetAccess(AccessType::READ);
92 objects_[name] = {std::move(descriptor_ptr)};
93 }
94
RenameArgs(const std::string & postfix,std::string * code) const95 void Arguments::RenameArgs(const std::string& postfix,
96 std::string* code) const {
97 static constexpr char kArgsPrefix[] = "args.";
98 size_t next_position = code->find(kArgsPrefix);
99 while (next_position != std::string::npos) {
100 size_t arg_pos = next_position + strlen(kArgsPrefix);
101 std::string arg_name = GetNextWord(*code, arg_pos);
102 code->replace(arg_pos, arg_name.size(), arg_name + postfix);
103 next_position = code->find(kArgsPrefix, arg_pos + arg_name.size());
104 }
105 }
106
Merge(Arguments && args,const std::string & postfix,const std::vector<std::string> & exception_names)107 absl::Status Arguments::Merge(Arguments&& args, const std::string& postfix,
108 const std::vector<std::string>& exception_names) {
109 std::vector<std::string> object_names;
110 object_names.reserve(args.object_refs_.size() + args.objects_.size());
111 for (auto& v : args.object_refs_) {
112 if (std::find(exception_names.begin(), exception_names.end(), v.first) !=
113 exception_names.end()) {
114 continue;
115 }
116 object_names.push_back(v.first);
117 const std::string name = v.first + postfix;
118 if (object_refs_.find(name) != object_refs_.end()) {
119 return absl::InvalidArgumentError(
120 absl::StrCat("Object reference name collision. Name - ", name));
121 }
122 object_refs_[name] = {std::move(v.second)};
123 }
124 for (auto& v : args.objects_) {
125 if (std::find(exception_names.begin(), exception_names.end(), v.first) !=
126 exception_names.end()) {
127 continue;
128 }
129 object_names.push_back(v.first);
130 const std::string name = v.first + postfix;
131 if (objects_.find(name) != objects_.end()) {
132 return absl::InvalidArgumentError(
133 absl::StrCat("Object name collision. Name - ", name));
134 }
135 objects_[name] = {std::move(v.second)};
136 }
137 for (const auto& v : args.int_values_) {
138 AddInt(RenameArg(object_names, postfix, v.first), v.second.value);
139 }
140 for (const auto& v : args.float_values_) {
141 AddFloat(RenameArg(object_names, postfix, v.first), v.second.value);
142 }
143 for (const auto& v : args.half_values_) {
144 AddHalf(RenameArg(object_names, postfix, v.first), v.second.value);
145 }
146 return absl::OkStatus();
147 }
148
GetDescriptor(const std::string & name,GPUObjectDescriptor ** descriptor) const149 absl::Status Arguments::GetDescriptor(const std::string& name,
150 GPUObjectDescriptor** descriptor) const {
151 auto it_ref = object_refs_.find(name);
152 if (it_ref != object_refs_.end()) {
153 *descriptor = it_ref->second.get();
154 return absl::OkStatus();
155 }
156 auto it = objects_.find(name);
157 if (it != objects_.end()) {
158 *descriptor = it->second.get();
159 return absl::OkStatus();
160 }
161 return absl::NotFoundError(absl::StrCat("No GPU object with name - ", name));
162 }
163
ReleaseCPURepresentation()164 void Arguments::ReleaseCPURepresentation() {
165 for (auto& t : objects_) {
166 t.second->Release();
167 }
168 }
169
GetActiveArguments(const std::string & args_prefix,const std::string & code)170 void Arguments::GetActiveArguments(const std::string& args_prefix,
171 const std::string& code) {
172 for (auto& float_val : float_values_) {
173 float_val.second.active = HasWord(args_prefix + float_val.first, code);
174 }
175 for (auto& int_val : int_values_) {
176 int_val.second.active = HasWord(args_prefix + int_val.first, code);
177 }
178 for (auto& half_val : half_values_) {
179 half_val.second.active = HasWord(args_prefix + half_val.first, code);
180 }
181 }
182
GetReadTexturesCount(const GpuInfo & gpu_info) const183 int Arguments::GetReadTexturesCount(const GpuInfo& gpu_info) const {
184 int counter = 0;
185 for (auto& t : objects_) {
186 counter += t.second->GetGPUResources(gpu_info).GetReadImagesCount();
187 }
188 for (auto& t : object_refs_) {
189 counter += t.second->GetGPUResources(gpu_info).GetReadImagesCount();
190 }
191 return counter;
192 }
193
GetWriteTexturesCount(const GpuInfo & gpu_info) const194 int Arguments::GetWriteTexturesCount(const GpuInfo& gpu_info) const {
195 int counter = 0;
196 for (auto& t : objects_) {
197 counter += t.second->GetGPUResources(gpu_info).GetWriteImagesCount();
198 }
199 for (auto& t : object_refs_) {
200 counter += t.second->GetGPUResources(gpu_info).GetWriteImagesCount();
201 }
202 return counter;
203 }
204
SetStateValueForAllObjects(const std::string & key,const std::string & value)205 void Arguments::SetStateValueForAllObjects(const std::string& key,
206 const std::string& value) {
207 for (auto& obj : object_refs_) {
208 obj.second->SetStateVar(key, value);
209 }
210 for (auto& obj : objects_) {
211 obj.second->SetStateVar(key, value);
212 }
213 }
214
215 } // namespace gpu
216 } // namespace tflite
217