• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/task/arguments.h"
17 
18 #include <algorithm>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/strings/ascii.h"
23 #include "absl/strings/match.h"
24 #include "absl/strings/str_cat.h"
25 #include "tensorflow/lite/delegates/gpu/common/status.h"
26 
27 namespace tflite {
28 namespace gpu {
29 namespace {
IsWordSymbol(char symbol)30 bool IsWordSymbol(char symbol) {
31   return absl::ascii_isalnum(symbol) || symbol == '_';
32 }
33 
GetNextWord(const std::string & code,size_t first_position)34 std::string GetNextWord(const std::string& code, size_t first_position) {
35   size_t pos = first_position;
36   char t = code[pos];
37   while (IsWordSymbol(t)) {
38     pos++;
39     t = code[pos];
40   }
41   return code.substr(first_position, pos - first_position);
42 }
43 
HasWord(const std::string & word,const std::string & text)44 bool HasWord(const std::string& word, const std::string& text) {
45   size_t pos = text.find(word);
46   while (pos != std::string::npos) {
47     char prev = pos == 0 ? '.' : text[pos - 1];
48     char next = pos + word.size() < text.size() ? text[pos + word.size()] : '.';
49     if (!IsWordSymbol(prev) & !IsWordSymbol(next)) {
50       return true;
51     }
52     pos = text.find(word, pos + 1);
53   }
54   return false;
55 }
56 
// Builds the merged name of the scalar argument `arg_name`.
//
// If `arg_name` has the form "<object>_<rest>" for some <object> in
// `object_names` (the first match in list order wins), `postfix` is inserted
// right after the object name, giving "<object><postfix>_<rest>". The scalar
// thus keeps the renamed object's name ("<object><postfix>") as its prefix.
// Otherwise `postfix` is simply appended to `arg_name`.
std::string RenameArg(const std::vector<std::string>& object_names,
                      const std::string& postfix, const std::string& arg_name) {
  for (const std::string& object : object_names) {
    const size_t prefix_len = object.size();
    const bool belongs_to_object =
        arg_name.size() > prefix_len &&
        arg_name.compare(0, prefix_len, object) == 0 &&
        arg_name[prefix_len] == '_';
    if (belongs_to_object) {
      return object + postfix + arg_name.substr(prefix_len);
    }
  }
  return arg_name + postfix;
}
70 
71 }  // namespace
72 
AddFloat(const std::string & name,float value)73 void Arguments::AddFloat(const std::string& name, float value) {
74   float_values_[name].value = value;
75 }
AddHalf(const std::string & name,half value)76 void Arguments::AddHalf(const std::string& name, half value) {
77   half_values_[name].value = value;
78 }
AddInt(const std::string & name,int value)79 void Arguments::AddInt(const std::string& name, int value) {
80   int_values_[name].value = value;
81 }
82 
AddObjectRef(const std::string & name,AccessType access_type,GPUObjectDescriptorPtr && descriptor_ptr)83 void Arguments::AddObjectRef(const std::string& name, AccessType access_type,
84                              GPUObjectDescriptorPtr&& descriptor_ptr) {
85   descriptor_ptr->SetAccess(access_type);
86   object_refs_[name] = {std::move(descriptor_ptr)};
87 }
88 
AddObject(const std::string & name,GPUObjectDescriptorPtr && descriptor_ptr)89 void Arguments::AddObject(const std::string& name,
90                           GPUObjectDescriptorPtr&& descriptor_ptr) {
91   descriptor_ptr->SetAccess(AccessType::READ);
92   objects_[name] = {std::move(descriptor_ptr)};
93 }
94 
RenameArgs(const std::string & postfix,std::string * code) const95 void Arguments::RenameArgs(const std::string& postfix,
96                            std::string* code) const {
97   static constexpr char kArgsPrefix[] = "args.";
98   size_t next_position = code->find(kArgsPrefix);
99   while (next_position != std::string::npos) {
100     size_t arg_pos = next_position + strlen(kArgsPrefix);
101     std::string arg_name = GetNextWord(*code, arg_pos);
102     code->replace(arg_pos, arg_name.size(), arg_name + postfix);
103     next_position = code->find(kArgsPrefix, arg_pos + arg_name.size());
104   }
105 }
106 
Merge(Arguments && args,const std::string & postfix,const std::vector<std::string> & exception_names)107 absl::Status Arguments::Merge(Arguments&& args, const std::string& postfix,
108                               const std::vector<std::string>& exception_names) {
109   std::vector<std::string> object_names;
110   object_names.reserve(args.object_refs_.size() + args.objects_.size());
111   for (auto& v : args.object_refs_) {
112     if (std::find(exception_names.begin(), exception_names.end(), v.first) !=
113         exception_names.end()) {
114       continue;
115     }
116     object_names.push_back(v.first);
117     const std::string name = v.first + postfix;
118     if (object_refs_.find(name) != object_refs_.end()) {
119       return absl::InvalidArgumentError(
120           absl::StrCat("Object reference name collision. Name - ", name));
121     }
122     object_refs_[name] = {std::move(v.second)};
123   }
124   for (auto& v : args.objects_) {
125     if (std::find(exception_names.begin(), exception_names.end(), v.first) !=
126         exception_names.end()) {
127       continue;
128     }
129     object_names.push_back(v.first);
130     const std::string name = v.first + postfix;
131     if (objects_.find(name) != objects_.end()) {
132       return absl::InvalidArgumentError(
133           absl::StrCat("Object name collision. Name - ", name));
134     }
135     objects_[name] = {std::move(v.second)};
136   }
137   for (const auto& v : args.int_values_) {
138     AddInt(RenameArg(object_names, postfix, v.first), v.second.value);
139   }
140   for (const auto& v : args.float_values_) {
141     AddFloat(RenameArg(object_names, postfix, v.first), v.second.value);
142   }
143   for (const auto& v : args.half_values_) {
144     AddHalf(RenameArg(object_names, postfix, v.first), v.second.value);
145   }
146   return absl::OkStatus();
147 }
148 
GetDescriptor(const std::string & name,GPUObjectDescriptor ** descriptor) const149 absl::Status Arguments::GetDescriptor(const std::string& name,
150                                       GPUObjectDescriptor** descriptor) const {
151   auto it_ref = object_refs_.find(name);
152   if (it_ref != object_refs_.end()) {
153     *descriptor = it_ref->second.get();
154     return absl::OkStatus();
155   }
156   auto it = objects_.find(name);
157   if (it != objects_.end()) {
158     *descriptor = it->second.get();
159     return absl::OkStatus();
160   }
161   return absl::NotFoundError(absl::StrCat("No GPU object with name - ", name));
162 }
163 
ReleaseCPURepresentation()164 void Arguments::ReleaseCPURepresentation() {
165   for (auto& t : objects_) {
166     t.second->Release();
167   }
168 }
169 
GetActiveArguments(const std::string & args_prefix,const std::string & code)170 void Arguments::GetActiveArguments(const std::string& args_prefix,
171                                    const std::string& code) {
172   for (auto& float_val : float_values_) {
173     float_val.second.active = HasWord(args_prefix + float_val.first, code);
174   }
175   for (auto& int_val : int_values_) {
176     int_val.second.active = HasWord(args_prefix + int_val.first, code);
177   }
178   for (auto& half_val : half_values_) {
179     half_val.second.active = HasWord(args_prefix + half_val.first, code);
180   }
181 }
182 
GetReadTexturesCount(const GpuInfo & gpu_info) const183 int Arguments::GetReadTexturesCount(const GpuInfo& gpu_info) const {
184   int counter = 0;
185   for (auto& t : objects_) {
186     counter += t.second->GetGPUResources(gpu_info).GetReadImagesCount();
187   }
188   for (auto& t : object_refs_) {
189     counter += t.second->GetGPUResources(gpu_info).GetReadImagesCount();
190   }
191   return counter;
192 }
193 
GetWriteTexturesCount(const GpuInfo & gpu_info) const194 int Arguments::GetWriteTexturesCount(const GpuInfo& gpu_info) const {
195   int counter = 0;
196   for (auto& t : objects_) {
197     counter += t.second->GetGPUResources(gpu_info).GetWriteImagesCount();
198   }
199   for (auto& t : object_refs_) {
200     counter += t.second->GetGPUResources(gpu_info).GetWriteImagesCount();
201   }
202   return counter;
203 }
204 
SetStateValueForAllObjects(const std::string & key,const std::string & value)205 void Arguments::SetStateValueForAllObjects(const std::string& key,
206                                            const std::string& value) {
207   for (auto& obj : object_refs_) {
208     obj.second->SetStateVar(key, value);
209   }
210   for (auto& obj : objects_) {
211     obj.second->SetStateVar(key, value);
212   }
213 }
214 
215 }  // namespace gpu
216 }  // namespace tflite
217