• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_VALUE_H_
18 #define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_VALUE_H_
19 
#include <cstring>
#include <string>
#include <type_traits>

#include "utils/base/arena.h"
#include "utils/base/logging.h"
#include "utils/flatbuffers/mutable.h"
#include "utils/flatbuffers/reflection.h"
#include "utils/strings/stringpiece.h"
#include "utils/utf8/unicodetext.h"
#include "flatbuffers/base.h"
#include "flatbuffers/reflection.h"
28 
29 namespace libtextclassifier3::grammar {
30 
31 // A semantic value as a typed, arena-allocated flatbuffer.
32 // This denotes the possible results of the evaluation of a semantic expression.
33 class SemanticValue {
34  public:
35   // Creates an arena allocated semantic value.
36   template <typename T>
Create(const T value,UnsafeArena * arena)37   static const SemanticValue* Create(const T value, UnsafeArena* arena) {
38     static_assert(!std::is_pointer<T>() && std::is_scalar<T>());
39     if (char* buffer = reinterpret_cast<char*>(
40             arena->AllocAligned(sizeof(T), alignof(T)))) {
41       flatbuffers::WriteScalar<T>(buffer, value);
42       return arena->AllocAndInit<SemanticValue>(
43           libtextclassifier3::flatbuffers_base_type<T>::value,
44           StringPiece(buffer, sizeof(T)));
45     }
46     return nullptr;
47   }
48 
49   template <>
Create(const StringPiece value,UnsafeArena * arena)50   const SemanticValue* Create(const StringPiece value, UnsafeArena* arena) {
51     return arena->AllocAndInit<SemanticValue>(reflection::BaseType::String,
52                                               value);
53   }
54 
55   template <>
Create(const UnicodeText value,UnsafeArena * arena)56   const SemanticValue* Create(const UnicodeText value, UnsafeArena* arena) {
57     return arena->AllocAndInit<SemanticValue>(
58         reflection::BaseType::String,
59         StringPiece(value.data(), value.size_bytes()));
60   }
61 
62   template <>
Create(const MutableFlatbuffer * value,UnsafeArena * arena)63   const SemanticValue* Create(const MutableFlatbuffer* value,
64                               UnsafeArena* arena) {
65     const std::string buffer = value->Serialize();
66     return Create(
67         value->type(),
68         StringPiece(arena->Memdup(buffer.data(), buffer.size()), buffer.size()),
69         arena);
70   }
71 
Create(const reflection::Object * type,const StringPiece data,UnsafeArena * arena)72   static const SemanticValue* Create(const reflection::Object* type,
73                                      const StringPiece data,
74                                      UnsafeArena* arena) {
75     return arena->AllocAndInit<SemanticValue>(type, data);
76   }
77 
Create(const reflection::BaseType base_type,const StringPiece data,UnsafeArena * arena)78   static const SemanticValue* Create(const reflection::BaseType base_type,
79                                      const StringPiece data,
80                                      UnsafeArena* arena) {
81     return arena->AllocAndInit<SemanticValue>(base_type, data);
82   }
83 
84   template <typename T>
Create(const reflection::BaseType base_type,const T value,UnsafeArena * arena)85   static const SemanticValue* Create(const reflection::BaseType base_type,
86                                      const T value, UnsafeArena* arena) {
87     switch (base_type) {
88       case reflection::BaseType::Bool:
89         return Create(
90             static_cast<
91                 flatbuffers_cpp_type<reflection::BaseType::Bool>::value>(value),
92             arena);
93       case reflection::BaseType::Byte:
94         return Create(
95             static_cast<
96                 flatbuffers_cpp_type<reflection::BaseType::Byte>::value>(value),
97             arena);
98       case reflection::BaseType::UByte:
99         return Create(
100             static_cast<
101                 flatbuffers_cpp_type<reflection::BaseType::UByte>::value>(
102                 value),
103             arena);
104       case reflection::BaseType::Short:
105         return Create(
106             static_cast<
107                 flatbuffers_cpp_type<reflection::BaseType::Short>::value>(
108                 value),
109             arena);
110       case reflection::BaseType::UShort:
111         return Create(
112             static_cast<
113                 flatbuffers_cpp_type<reflection::BaseType::UShort>::value>(
114                 value),
115             arena);
116       case reflection::BaseType::Int:
117         return Create(
118             static_cast<flatbuffers_cpp_type<reflection::BaseType::Int>::value>(
119                 value),
120             arena);
121       case reflection::BaseType::UInt:
122         return Create(
123             static_cast<
124                 flatbuffers_cpp_type<reflection::BaseType::UInt>::value>(value),
125             arena);
126       case reflection::BaseType::Long:
127         return Create(
128             static_cast<
129                 flatbuffers_cpp_type<reflection::BaseType::Long>::value>(value),
130             arena);
131       case reflection::BaseType::ULong:
132         return Create(
133             static_cast<
134                 flatbuffers_cpp_type<reflection::BaseType::ULong>::value>(
135                 value),
136             arena);
137       case reflection::BaseType::Float:
138         return Create(
139             static_cast<
140                 flatbuffers_cpp_type<reflection::BaseType::Float>::value>(
141                 value),
142             arena);
143       case reflection::BaseType::Double:
144         return Create(
145             static_cast<
146                 flatbuffers_cpp_type<reflection::BaseType::Double>::value>(
147                 value),
148             arena);
149       default: {
150         TC3_LOG(ERROR) << "Unhandled type: " << base_type;
151         return nullptr;
152       }
153     }
154   }
155 
SemanticValue(const reflection::BaseType base_type,const StringPiece data)156   explicit SemanticValue(const reflection::BaseType base_type,
157                          const StringPiece data)
158       : base_type_(base_type), type_(nullptr), data_(data) {}
SemanticValue(const reflection::Object * type,const StringPiece data)159   explicit SemanticValue(const reflection::Object* type, const StringPiece data)
160       : base_type_(reflection::BaseType::Obj), type_(type), data_(data) {}
161 
162   template <typename T>
Has()163   bool Has() const {
164     return base_type_ == libtextclassifier3::flatbuffers_base_type<T>::value;
165   }
166 
167   template <>
168   bool Has<flatbuffers::Table>() const {
169     return base_type_ == reflection::BaseType::Obj;
170   }
171 
172   template <typename T = flatbuffers::Table>
Table()173   const T* Table() const {
174     TC3_CHECK(Has<flatbuffers::Table>());
175     return flatbuffers::GetRoot<T>(
176         reinterpret_cast<const unsigned char*>(data_.data()));
177   }
178 
179   template <typename T>
Value()180   const T Value() const {
181     TC3_CHECK(Has<T>());
182     return flatbuffers::ReadScalar<T>(data_.data());
183   }
184 
185   template <>
186   const StringPiece Value<StringPiece>() const {
187     TC3_CHECK(Has<StringPiece>());
188     return data_;
189   }
190 
191   template <>
192   const std::string Value<std::string>() const {
193     TC3_CHECK(Has<StringPiece>());
194     return data_.ToString();
195   }
196 
197   template <>
198   const UnicodeText Value<UnicodeText>() const {
199     TC3_CHECK(Has<StringPiece>());
200     return UTF8ToUnicodeText(data_, /*do_copy=*/false);
201   }
202 
base_type()203   const reflection::BaseType base_type() const { return base_type_; }
type()204   const reflection::Object* type() const { return type_; }
205 
206  private:
207   // The base type.
208   const reflection::BaseType base_type_;
209 
210   // The object type of the value.
211   const reflection::Object* type_;
212 
213   StringPiece data_;
214 };
215 
216 }  // namespace libtextclassifier3::grammar
217 
218 #endif  // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_VALUE_H_
219