1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "common/task-context.h"
18
19 #include <stdlib.h>
20
21 #include <string>
22
23 #include "util/base/integral_types.h"
24 #include "util/base/logging.h"
25 #include "util/strings/numbers.h"
26
27 namespace libtextclassifier {
28 namespace nlp_core {
29
30 namespace {
ParseInt32WithDefault(const std::string & s,int32 defval)31 int32 ParseInt32WithDefault(const std::string &s, int32 defval) {
32 int32 value = defval;
33 return ParseInt32(s.c_str(), &value) ? value : defval;
34 }
35
ParseInt64WithDefault(const std::string & s,int64 defval)36 int64 ParseInt64WithDefault(const std::string &s, int64 defval) {
37 int64 value = defval;
38 return ParseInt64(s.c_str(), &value) ? value : defval;
39 }
40
ParseDoubleWithDefault(const std::string & s,double defval)41 double ParseDoubleWithDefault(const std::string &s, double defval) {
42 double value = defval;
43 return ParseDouble(s.c_str(), &value) ? value : defval;
44 }
45 } // namespace
46
GetInput(const std::string & name)47 TaskInput *TaskContext::GetInput(const std::string &name) {
48 // Return existing input if it exists.
49 for (int i = 0; i < spec_.input_size(); ++i) {
50 if (spec_.input(i).name() == name) return spec_.mutable_input(i);
51 }
52
53 // Create new input.
54 TaskInput *input = spec_.add_input();
55 input->set_name(name);
56 return input;
57 }
58
GetInput(const std::string & name,const std::string & file_format,const std::string & record_format)59 TaskInput *TaskContext::GetInput(const std::string &name,
60 const std::string &file_format,
61 const std::string &record_format) {
62 TaskInput *input = GetInput(name);
63 if (!file_format.empty()) {
64 bool found = false;
65 for (int i = 0; i < input->file_format_size(); ++i) {
66 if (input->file_format(i) == file_format) found = true;
67 }
68 if (!found) input->add_file_format(file_format);
69 }
70 if (!record_format.empty()) {
71 bool found = false;
72 for (int i = 0; i < input->record_format_size(); ++i) {
73 if (input->record_format(i) == record_format) found = true;
74 }
75 if (!found) input->add_record_format(record_format);
76 }
77 return input;
78 }
79
SetParameter(const std::string & name,const std::string & value)80 void TaskContext::SetParameter(const std::string &name,
81 const std::string &value) {
82 TC_LOG(INFO) << "SetParameter(" << name << ", " << value << ")";
83
84 // If the parameter already exists update the value.
85 for (int i = 0; i < spec_.parameter_size(); ++i) {
86 if (spec_.parameter(i).name() == name) {
87 spec_.mutable_parameter(i)->set_value(value);
88 return;
89 }
90 }
91
92 // Add new parameter.
93 TaskSpec::Parameter *param = spec_.add_parameter();
94 param->set_name(name);
95 param->set_value(value);
96 }
97
GetParameter(const std::string & name) const98 std::string TaskContext::GetParameter(const std::string &name) const {
99 // First try to find parameter in task specification.
100 for (int i = 0; i < spec_.parameter_size(); ++i) {
101 if (spec_.parameter(i).name() == name) return spec_.parameter(i).value();
102 }
103
104 // Parameter not found, return empty std::string.
105 return "";
106 }
107
GetIntParameter(const std::string & name) const108 int TaskContext::GetIntParameter(const std::string &name) const {
109 std::string value = GetParameter(name);
110 return ParseInt32WithDefault(value, 0);
111 }
112
GetInt64Parameter(const std::string & name) const113 int64 TaskContext::GetInt64Parameter(const std::string &name) const {
114 std::string value = GetParameter(name);
115 return ParseInt64WithDefault(value, 0);
116 }
117
GetBoolParameter(const std::string & name) const118 bool TaskContext::GetBoolParameter(const std::string &name) const {
119 std::string value = GetParameter(name);
120 return value == "true";
121 }
122
GetFloatParameter(const std::string & name) const123 double TaskContext::GetFloatParameter(const std::string &name) const {
124 std::string value = GetParameter(name);
125 return ParseDoubleWithDefault(value, 0.0);
126 }
127
Get(const std::string & name,const char * defval) const128 std::string TaskContext::Get(const std::string &name,
129 const char *defval) const {
130 // First try to find parameter in task specification.
131 for (int i = 0; i < spec_.parameter_size(); ++i) {
132 if (spec_.parameter(i).name() == name) return spec_.parameter(i).value();
133 }
134
135 // Parameter not found, return default value.
136 return defval;
137 }
138
Get(const std::string & name,const std::string & defval) const139 std::string TaskContext::Get(const std::string &name,
140 const std::string &defval) const {
141 return Get(name, defval.c_str());
142 }
143
Get(const std::string & name,int defval) const144 int TaskContext::Get(const std::string &name, int defval) const {
145 std::string value = Get(name, "");
146 return ParseInt32WithDefault(value, defval);
147 }
148
Get(const std::string & name,int64 defval) const149 int64 TaskContext::Get(const std::string &name, int64 defval) const {
150 std::string value = Get(name, "");
151 return ParseInt64WithDefault(value, defval);
152 }
153
Get(const std::string & name,double defval) const154 double TaskContext::Get(const std::string &name, double defval) const {
155 std::string value = Get(name, "");
156 return ParseDoubleWithDefault(value, defval);
157 }
158
Get(const std::string & name,bool defval) const159 bool TaskContext::Get(const std::string &name, bool defval) const {
160 std::string value = Get(name, "");
161 return value.empty() ? defval : value == "true";
162 }
163
InputFile(const TaskInput & input)164 std::string TaskContext::InputFile(const TaskInput &input) {
165 if (input.part_size() == 0) {
166 TC_LOG(ERROR) << "No file for TaskInput " << input.name();
167 return "";
168 }
169 if (input.part_size() > 1) {
170 TC_LOG(ERROR) << "Ambiguous: multiple files for TaskInput " << input.name();
171 }
172 return input.part(0).file_pattern();
173 }
174
Supports(const TaskInput & input,const std::string & file_format,const std::string & record_format)175 bool TaskContext::Supports(const TaskInput &input,
176 const std::string &file_format,
177 const std::string &record_format) {
178 // Check file format.
179 if (input.file_format_size() > 0) {
180 bool found = false;
181 for (int i = 0; i < input.file_format_size(); ++i) {
182 if (input.file_format(i) == file_format) {
183 found = true;
184 break;
185 }
186 }
187 if (!found) return false;
188 }
189
190 // Check record format.
191 if (input.record_format_size() > 0) {
192 bool found = false;
193 for (int i = 0; i < input.record_format_size(); ++i) {
194 if (input.record_format(i) == record_format) {
195 found = true;
196 break;
197 }
198 }
199 if (!found) return false;
200 }
201
202 return true;
203 }
204
205 } // namespace nlp_core
206 } // namespace libtextclassifier
207