• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/grammar/semantics/evaluators/compose-eval.h"
18 
19 #include "utils/base/status_macros.h"
20 #include "utils/strings/stringpiece.h"
21 
22 namespace libtextclassifier3::grammar {
23 namespace {
24 
25 // Tries setting a singular field.
26 template <typename T>
TrySetField(const reflection::Field * field,const SemanticValue * value,MutableFlatbuffer * result)27 Status TrySetField(const reflection::Field* field, const SemanticValue* value,
28                    MutableFlatbuffer* result) {
29   if (!result->Set<T>(field, value->Value<T>())) {
30     return Status(StatusCode::INVALID_ARGUMENT, "Could not set field.");
31   }
32   return Status::OK;
33 }
34 
35 template <>
TrySetField(const reflection::Field * field,const SemanticValue * value,MutableFlatbuffer * result)36 Status TrySetField<flatbuffers::Table>(const reflection::Field* field,
37                                        const SemanticValue* value,
38                                        MutableFlatbuffer* result) {
39   if (!result->Mutable(field)->MergeFrom(value->Table())) {
40     return Status(StatusCode::INVALID_ARGUMENT,
41                   "Could not set sub-field in result.");
42   }
43   return Status::OK;
44 }
45 
46 // Tries adding a value to a repeated field.
47 template <typename T>
TryAddField(const reflection::Field * field,const SemanticValue * value,MutableFlatbuffer * result)48 Status TryAddField(const reflection::Field* field, const SemanticValue* value,
49                    MutableFlatbuffer* result) {
50   if (!result->Repeated(field)->Add(value->Value<T>())) {
51     return Status(StatusCode::INVALID_ARGUMENT, "Could not add field.");
52   }
53   return Status::OK;
54 }
55 
56 template <>
TryAddField(const reflection::Field * field,const SemanticValue * value,MutableFlatbuffer * result)57 Status TryAddField<flatbuffers::Table>(const reflection::Field* field,
58                                        const SemanticValue* value,
59                                        MutableFlatbuffer* result) {
60   if (!result->Repeated(field)->Add()->MergeFrom(value->Table())) {
61     return Status(StatusCode::INVALID_ARGUMENT,
62                   "Could not add message to repeated field.");
63   }
64   return Status::OK;
65 }
66 
67 // Tries adding or setting a value for a field.
68 template <typename T>
TrySetOrAddValue(const FlatbufferFieldPath * field_path,const SemanticValue * value,MutableFlatbuffer * result)69 Status TrySetOrAddValue(const FlatbufferFieldPath* field_path,
70                         const SemanticValue* value, MutableFlatbuffer* result) {
71   MutableFlatbuffer* parent;
72   const reflection::Field* field;
73   if (!result->GetFieldWithParent(field_path, &parent, &field)) {
74     return Status(StatusCode::INVALID_ARGUMENT, "Could not get field.");
75   }
76   if (field->type()->base_type() == reflection::Vector) {
77     return TryAddField<T>(field, value, parent);
78   } else {
79     return TrySetField<T>(field, value, parent);
80   }
81 }
82 
83 }  // namespace
84 
Apply(const EvalContext & context,const SemanticExpression * expression,UnsafeArena * arena) const85 StatusOr<const SemanticValue*> ComposeEvaluator::Apply(
86     const EvalContext& context, const SemanticExpression* expression,
87     UnsafeArena* arena) const {
88   const ComposeExpression* compose_expression =
89       expression->expression_as_ComposeExpression();
90   std::unique_ptr<MutableFlatbuffer> result =
91       semantic_value_builder_.NewTable(compose_expression->type());
92 
93   if (result == nullptr) {
94     return Status(StatusCode::INVALID_ARGUMENT, "Invalid result type.");
95   }
96 
97   // Evaluate and set fields.
98   if (compose_expression->fields() != nullptr) {
99     for (const ComposeExpression_::Field* field :
100          *compose_expression->fields()) {
101       // Evaluate argument.
102       TC3_ASSIGN_OR_RETURN(const SemanticValue* value,
103                            composer_->Apply(context, field->value(), arena));
104       if (value == nullptr) {
105         continue;
106       }
107 
108       switch (value->base_type()) {
109         case reflection::BaseType::Bool: {
110           TC3_RETURN_IF_ERROR(
111               TrySetOrAddValue<bool>(field->path(), value, result.get()));
112           break;
113         }
114         case reflection::BaseType::Byte: {
115           TC3_RETURN_IF_ERROR(
116               TrySetOrAddValue<int8>(field->path(), value, result.get()));
117           break;
118         }
119         case reflection::BaseType::UByte: {
120           TC3_RETURN_IF_ERROR(
121               TrySetOrAddValue<uint8>(field->path(), value, result.get()));
122           break;
123         }
124         case reflection::BaseType::Short: {
125           TC3_RETURN_IF_ERROR(
126               TrySetOrAddValue<int16>(field->path(), value, result.get()));
127           break;
128         }
129         case reflection::BaseType::UShort: {
130           TC3_RETURN_IF_ERROR(
131               TrySetOrAddValue<uint16>(field->path(), value, result.get()));
132           break;
133         }
134         case reflection::BaseType::Int: {
135           TC3_RETURN_IF_ERROR(
136               TrySetOrAddValue<int32>(field->path(), value, result.get()));
137           break;
138         }
139         case reflection::BaseType::UInt: {
140           TC3_RETURN_IF_ERROR(
141               TrySetOrAddValue<uint32>(field->path(), value, result.get()));
142           break;
143         }
144         case reflection::BaseType::Long: {
145           TC3_RETURN_IF_ERROR(
146               TrySetOrAddValue<int64>(field->path(), value, result.get()));
147           break;
148         }
149         case reflection::BaseType::ULong: {
150           TC3_RETURN_IF_ERROR(
151               TrySetOrAddValue<uint64>(field->path(), value, result.get()));
152           break;
153         }
154         case reflection::BaseType::Float: {
155           TC3_RETURN_IF_ERROR(
156               TrySetOrAddValue<float>(field->path(), value, result.get()));
157           break;
158         }
159         case reflection::BaseType::Double: {
160           TC3_RETURN_IF_ERROR(
161               TrySetOrAddValue<double>(field->path(), value, result.get()));
162           break;
163         }
164         case reflection::BaseType::String: {
165           TC3_RETURN_IF_ERROR(TrySetOrAddValue<StringPiece>(
166               field->path(), value, result.get()));
167           break;
168         }
169         case reflection::BaseType::Obj: {
170           TC3_RETURN_IF_ERROR(TrySetOrAddValue<flatbuffers::Table>(
171               field->path(), value, result.get()));
172           break;
173         }
174         default:
175           return Status(StatusCode::INVALID_ARGUMENT, "Unhandled type.");
176       }
177     }
178   }
179 
180   return SemanticValue::Create<const MutableFlatbuffer*>(result.get(), arena);
181 }
182 
183 }  // namespace libtextclassifier3::grammar
184