1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/grammar/semantics/evaluators/compose-eval.h"
18
19 #include "utils/base/status_macros.h"
20 #include "utils/strings/stringpiece.h"
21
22 namespace libtextclassifier3::grammar {
23 namespace {
24
25 // Tries setting a singular field.
26 template <typename T>
TrySetField(const reflection::Field * field,const SemanticValue * value,MutableFlatbuffer * result)27 Status TrySetField(const reflection::Field* field, const SemanticValue* value,
28 MutableFlatbuffer* result) {
29 if (!result->Set<T>(field, value->Value<T>())) {
30 return Status(StatusCode::INVALID_ARGUMENT, "Could not set field.");
31 }
32 return Status::OK;
33 }
34
35 template <>
TrySetField(const reflection::Field * field,const SemanticValue * value,MutableFlatbuffer * result)36 Status TrySetField<flatbuffers::Table>(const reflection::Field* field,
37 const SemanticValue* value,
38 MutableFlatbuffer* result) {
39 if (!result->Mutable(field)->MergeFrom(value->Table())) {
40 return Status(StatusCode::INVALID_ARGUMENT,
41 "Could not set sub-field in result.");
42 }
43 return Status::OK;
44 }
45
46 // Tries adding a value to a repeated field.
47 template <typename T>
TryAddField(const reflection::Field * field,const SemanticValue * value,MutableFlatbuffer * result)48 Status TryAddField(const reflection::Field* field, const SemanticValue* value,
49 MutableFlatbuffer* result) {
50 if (!result->Repeated(field)->Add(value->Value<T>())) {
51 return Status(StatusCode::INVALID_ARGUMENT, "Could not add field.");
52 }
53 return Status::OK;
54 }
55
56 template <>
TryAddField(const reflection::Field * field,const SemanticValue * value,MutableFlatbuffer * result)57 Status TryAddField<flatbuffers::Table>(const reflection::Field* field,
58 const SemanticValue* value,
59 MutableFlatbuffer* result) {
60 if (!result->Repeated(field)->Add()->MergeFrom(value->Table())) {
61 return Status(StatusCode::INVALID_ARGUMENT,
62 "Could not add message to repeated field.");
63 }
64 return Status::OK;
65 }
66
67 // Tries adding or setting a value for a field.
68 template <typename T>
TrySetOrAddValue(const FlatbufferFieldPath * field_path,const SemanticValue * value,MutableFlatbuffer * result)69 Status TrySetOrAddValue(const FlatbufferFieldPath* field_path,
70 const SemanticValue* value, MutableFlatbuffer* result) {
71 MutableFlatbuffer* parent;
72 const reflection::Field* field;
73 if (!result->GetFieldWithParent(field_path, &parent, &field)) {
74 return Status(StatusCode::INVALID_ARGUMENT, "Could not get field.");
75 }
76 if (field->type()->base_type() == reflection::Vector) {
77 return TryAddField<T>(field, value, parent);
78 } else {
79 return TrySetField<T>(field, value, parent);
80 }
81 }
82
83 } // namespace
84
Apply(const EvalContext & context,const SemanticExpression * expression,UnsafeArena * arena) const85 StatusOr<const SemanticValue*> ComposeEvaluator::Apply(
86 const EvalContext& context, const SemanticExpression* expression,
87 UnsafeArena* arena) const {
88 const ComposeExpression* compose_expression =
89 expression->expression_as_ComposeExpression();
90 std::unique_ptr<MutableFlatbuffer> result =
91 semantic_value_builder_.NewTable(compose_expression->type());
92
93 if (result == nullptr) {
94 return Status(StatusCode::INVALID_ARGUMENT, "Invalid result type.");
95 }
96
97 // Evaluate and set fields.
98 if (compose_expression->fields() != nullptr) {
99 for (const ComposeExpression_::Field* field :
100 *compose_expression->fields()) {
101 // Evaluate argument.
102 TC3_ASSIGN_OR_RETURN(const SemanticValue* value,
103 composer_->Apply(context, field->value(), arena));
104 if (value == nullptr) {
105 continue;
106 }
107
108 switch (value->base_type()) {
109 case reflection::BaseType::Bool: {
110 TC3_RETURN_IF_ERROR(
111 TrySetOrAddValue<bool>(field->path(), value, result.get()));
112 break;
113 }
114 case reflection::BaseType::Byte: {
115 TC3_RETURN_IF_ERROR(
116 TrySetOrAddValue<int8>(field->path(), value, result.get()));
117 break;
118 }
119 case reflection::BaseType::UByte: {
120 TC3_RETURN_IF_ERROR(
121 TrySetOrAddValue<uint8>(field->path(), value, result.get()));
122 break;
123 }
124 case reflection::BaseType::Short: {
125 TC3_RETURN_IF_ERROR(
126 TrySetOrAddValue<int16>(field->path(), value, result.get()));
127 break;
128 }
129 case reflection::BaseType::UShort: {
130 TC3_RETURN_IF_ERROR(
131 TrySetOrAddValue<uint16>(field->path(), value, result.get()));
132 break;
133 }
134 case reflection::BaseType::Int: {
135 TC3_RETURN_IF_ERROR(
136 TrySetOrAddValue<int32>(field->path(), value, result.get()));
137 break;
138 }
139 case reflection::BaseType::UInt: {
140 TC3_RETURN_IF_ERROR(
141 TrySetOrAddValue<uint32>(field->path(), value, result.get()));
142 break;
143 }
144 case reflection::BaseType::Long: {
145 TC3_RETURN_IF_ERROR(
146 TrySetOrAddValue<int64>(field->path(), value, result.get()));
147 break;
148 }
149 case reflection::BaseType::ULong: {
150 TC3_RETURN_IF_ERROR(
151 TrySetOrAddValue<uint64>(field->path(), value, result.get()));
152 break;
153 }
154 case reflection::BaseType::Float: {
155 TC3_RETURN_IF_ERROR(
156 TrySetOrAddValue<float>(field->path(), value, result.get()));
157 break;
158 }
159 case reflection::BaseType::Double: {
160 TC3_RETURN_IF_ERROR(
161 TrySetOrAddValue<double>(field->path(), value, result.get()));
162 break;
163 }
164 case reflection::BaseType::String: {
165 TC3_RETURN_IF_ERROR(TrySetOrAddValue<StringPiece>(
166 field->path(), value, result.get()));
167 break;
168 }
169 case reflection::BaseType::Obj: {
170 TC3_RETURN_IF_ERROR(TrySetOrAddValue<flatbuffers::Table>(
171 field->path(), value, result.get()));
172 break;
173 }
174 default:
175 return Status(StatusCode::INVALID_ARGUMENT, "Unhandled type.");
176 }
177 }
178 }
179
180 return SemanticValue::Create<const MutableFlatbuffer*>(result.get(), arena);
181 }
182
183 } // namespace libtextclassifier3::grammar
184