1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/task/arguments.h"
17
18 #include "absl/strings/ascii.h"
19 #include "absl/strings/match.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/lite/delegates/gpu/common/status.h"
22
23 namespace tflite {
24 namespace gpu {
25 namespace {
IsWordSymbol(char symbol)26 bool IsWordSymbol(char symbol) {
27 return absl::ascii_isalnum(symbol) || symbol == '_';
28 }
29
GetNextWord(const std::string & code,size_t first_position)30 std::string GetNextWord(const std::string& code, size_t first_position) {
31 size_t pos = first_position;
32 char t = code[pos];
33 while (IsWordSymbol(t)) {
34 pos++;
35 t = code[pos];
36 }
37 return code.substr(first_position, pos - first_position);
38 }
39
HasWord(const std::string & word,const std::string & text)40 bool HasWord(const std::string& word, const std::string& text) {
41 size_t pos = text.find(word);
42 while (pos != std::string::npos) {
43 char prev = pos == 0 ? '.' : text[pos - 1];
44 char next = pos + word.size() < text.size() ? text[pos + word.size()] : '.';
45 if (!IsWordSymbol(prev) & !IsWordSymbol(next)) {
46 return true;
47 }
48 pos = text.find(word, pos + 1);
49 }
50 return false;
51 }
52
// Builds the new name of argument `arg_name` after `postfix` is applied.
//
// If `arg_name` has the form "<object>_<rest>" for some entry of
// `object_names` (the first matching entry in iteration order wins), the
// postfix is inserted right after the object name: "<object><postfix>_<rest>".
// Otherwise the postfix is appended to the whole name.
std::string RenameArg(const std::vector<std::string>& object_names,
                      const std::string& postfix, const std::string& arg_name) {
  for (const std::string& object_name : object_names) {
    const size_t prefix_length = object_name.size();
    // The argument must be strictly longer than the object name, so that an
    // underscore separator can follow it.
    if (arg_name.size() <= prefix_length) {
      continue;
    }
    if (arg_name[prefix_length] != '_') {
      continue;
    }
    // Prefix check: the leading `prefix_length` characters equal object_name.
    if (arg_name.compare(0, prefix_length, object_name) != 0) {
      continue;
    }
    const std::string remainder = arg_name.substr(prefix_length);
    return object_name + postfix + remainder;
  }
  return arg_name + postfix;
}
66
67 } // namespace
68
AddFloat(const std::string & name,float value)69 void Arguments::AddFloat(const std::string& name, float value) {
70 float_values_[name].value = value;
71 }
AddHalf(const std::string & name,half value)72 void Arguments::AddHalf(const std::string& name, half value) {
73 half_values_[name].value = value;
74 }
AddInt(const std::string & name,int value)75 void Arguments::AddInt(const std::string& name, int value) {
76 int_values_[name].value = value;
77 }
78
AddObjectRef(const std::string & name,AccessType access_type,GPUObjectDescriptorPtr && descriptor_ptr)79 void Arguments::AddObjectRef(const std::string& name, AccessType access_type,
80 GPUObjectDescriptorPtr&& descriptor_ptr) {
81 descriptor_ptr->SetAccess(access_type);
82 object_refs_[name] = {std::move(descriptor_ptr)};
83 }
84
AddObject(const std::string & name,GPUObjectDescriptorPtr && descriptor_ptr)85 void Arguments::AddObject(const std::string& name,
86 GPUObjectDescriptorPtr&& descriptor_ptr) {
87 descriptor_ptr->SetAccess(AccessType::READ);
88 objects_[name] = {std::move(descriptor_ptr)};
89 }
90
RenameArgs(const std::string & postfix,std::string * code) const91 void Arguments::RenameArgs(const std::string& postfix,
92 std::string* code) const {
93 static constexpr char kArgsPrefix[] = "args.";
94 size_t next_position = code->find(kArgsPrefix);
95 while (next_position != std::string::npos) {
96 size_t arg_pos = next_position + strlen(kArgsPrefix);
97 std::string arg_name = GetNextWord(*code, arg_pos);
98 code->replace(arg_pos, arg_name.size(), arg_name + postfix);
99 next_position = code->find(kArgsPrefix, arg_pos + arg_name.size());
100 }
101 }
102
Merge(Arguments && args,const std::string & postfix)103 absl::Status Arguments::Merge(Arguments&& args, const std::string& postfix) {
104 std::vector<std::string> object_names;
105 object_names.reserve(args.object_refs_.size() + args.objects_.size());
106 for (auto& v : args.object_refs_) {
107 object_names.push_back(v.first);
108 const std::string name = v.first + postfix;
109 if (object_refs_.find(name) != object_refs_.end()) {
110 return absl::InvalidArgumentError(
111 absl::StrCat("Object reference name collision. Name - ", name));
112 }
113 object_refs_[name] = {std::move(v.second)};
114 }
115 for (auto& v : args.objects_) {
116 object_names.push_back(v.first);
117 const std::string name = v.first + postfix;
118 if (objects_.find(name) != objects_.end()) {
119 return absl::InvalidArgumentError(
120 absl::StrCat("Object name collision. Name - ", name));
121 }
122 objects_[name] = {std::move(v.second)};
123 }
124 for (const auto& v : args.int_values_) {
125 AddInt(RenameArg(object_names, postfix, v.first), v.second.value);
126 }
127 for (const auto& v : args.float_values_) {
128 AddFloat(RenameArg(object_names, postfix, v.first), v.second.value);
129 }
130 for (const auto& v : args.half_values_) {
131 AddHalf(RenameArg(object_names, postfix, v.first), v.second.value);
132 }
133 return absl::OkStatus();
134 }
135
ReleaseCPURepresentation()136 void Arguments::ReleaseCPURepresentation() {
137 for (auto& t : objects_) {
138 t.second->Release();
139 }
140 }
141
GetActiveArguments(const std::string & args_prefix,const std::string & code)142 void Arguments::GetActiveArguments(const std::string& args_prefix,
143 const std::string& code) {
144 for (auto& float_val : float_values_) {
145 float_val.second.active = HasWord(args_prefix + float_val.first, code);
146 }
147 for (auto& int_val : int_values_) {
148 int_val.second.active = HasWord(args_prefix + int_val.first, code);
149 }
150 for (auto& half_val : half_values_) {
151 half_val.second.active = HasWord(args_prefix + half_val.first, code);
152 }
153 }
154
155 } // namespace gpu
156 } // namespace tflite
157