1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_VALUE_H_ 18 #define LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_VALUE_H_ 19 20 #include "utils/base/arena.h" 21 #include "utils/base/logging.h" 22 #include "utils/flatbuffers/mutable.h" 23 #include "utils/flatbuffers/reflection.h" 24 #include "utils/strings/stringpiece.h" 25 #include "utils/utf8/unicodetext.h" 26 #include "flatbuffers/base.h" 27 #include "flatbuffers/reflection.h" 28 29 namespace libtextclassifier3::grammar { 30 31 // A semantic value as a typed, arena-allocated flatbuffer. 32 // This denotes the possible results of the evaluation of a semantic expression. 33 class SemanticValue { 34 public: 35 // Creates an arena allocated semantic value. 36 template <typename T> Create(const T value,UnsafeArena * arena)37 static const SemanticValue* Create(const T value, UnsafeArena* arena) { 38 static_assert(!std::is_pointer<T>() && std::is_scalar<T>()); 39 if (char* buffer = reinterpret_cast<char*>( 40 arena->AllocAligned(sizeof(T), alignof(T)))) { 41 flatbuffers::WriteScalar<T>(buffer, value); 42 return arena->AllocAndInit<SemanticValue>( 43 libtextclassifier3::flatbuffers_base_type<T>::value, 44 StringPiece(buffer, sizeof(T))); 45 } 46 return nullptr; 47 } 48 49 template <> Create(const StringPiece value,UnsafeArena * arena)50 const SemanticValue* Create(const StringPiece value, UnsafeArena* arena) { 51 return arena->AllocAndInit<SemanticValue>(reflection::BaseType::String, 52 value); 53 } 54 55 template <> Create(const UnicodeText value,UnsafeArena * arena)56 const SemanticValue* Create(const UnicodeText value, UnsafeArena* arena) { 57 return arena->AllocAndInit<SemanticValue>( 58 reflection::BaseType::String, 59 StringPiece(value.data(), value.size_bytes())); 60 } 61 62 template <> Create(const MutableFlatbuffer * value,UnsafeArena * arena)63 const SemanticValue* Create(const MutableFlatbuffer* value, 64 UnsafeArena* arena) { 65 const std::string buffer = value->Serialize(); 66 return Create( 67 value->type(), 68 StringPiece(arena->Memdup(buffer.data(), buffer.size()), buffer.size()), 69 arena); 70 } 71 Create(const reflection::Object * type,const StringPiece data,UnsafeArena * arena)72 static const SemanticValue* Create(const reflection::Object* type, 73 const StringPiece data, 74 UnsafeArena* arena) { 75 return arena->AllocAndInit<SemanticValue>(type, data); 76 } 77 Create(const reflection::BaseType base_type,const StringPiece data,UnsafeArena * arena)78 static const SemanticValue* Create(const reflection::BaseType base_type, 79 const StringPiece data, 80 UnsafeArena* arena) { 81 return arena->AllocAndInit<SemanticValue>(base_type, data); 82 } 83 84 template <typename T> Create(const reflection::BaseType base_type,const T value,UnsafeArena * arena)85 static const SemanticValue* Create(const reflection::BaseType base_type, 86 const T value, UnsafeArena* arena) { 87 switch (base_type) { 88 case reflection::BaseType::Bool: 89 return Create( 90 static_cast< 91 flatbuffers_cpp_type<reflection::BaseType::Bool>::value>(value), 92 arena); 93 case reflection::BaseType::Byte: 94 return Create( 95 static_cast< 96 flatbuffers_cpp_type<reflection::BaseType::Byte>::value>(value), 97 arena); 98 case reflection::BaseType::UByte: 99 return Create( 100 static_cast< 101 flatbuffers_cpp_type<reflection::BaseType::UByte>::value>( 102 value), 103 arena); 104 case reflection::BaseType::Short: 105 return Create( 106 static_cast< 107 flatbuffers_cpp_type<reflection::BaseType::Short>::value>( 108 value), 109 arena); 110 case reflection::BaseType::UShort: 111 return Create( 112 static_cast< 113 flatbuffers_cpp_type<reflection::BaseType::UShort>::value>( 114 value), 115 arena); 116 case reflection::BaseType::Int: 117 return Create( 118 static_cast<flatbuffers_cpp_type<reflection::BaseType::Int>::value>( 119 value), 120 arena); 121 case reflection::BaseType::UInt: 122 return Create( 123 static_cast< 124 flatbuffers_cpp_type<reflection::BaseType::UInt>::value>(value), 125 arena); 126 case reflection::BaseType::Long: 127 return Create( 128 static_cast< 129 flatbuffers_cpp_type<reflection::BaseType::Long>::value>(value), 130 arena); 131 case reflection::BaseType::ULong: 132 return Create( 133 static_cast< 134 flatbuffers_cpp_type<reflection::BaseType::ULong>::value>( 135 value), 136 arena); 137 case reflection::BaseType::Float: 138 return Create( 139 static_cast< 140 flatbuffers_cpp_type<reflection::BaseType::Float>::value>( 141 value), 142 arena); 143 case reflection::BaseType::Double: 144 return Create( 145 static_cast< 146 flatbuffers_cpp_type<reflection::BaseType::Double>::value>( 147 value), 148 arena); 149 default: { 150 TC3_LOG(ERROR) << "Unhandled type: " << base_type; 151 return nullptr; 152 } 153 } 154 } 155 SemanticValue(const reflection::BaseType base_type,const StringPiece data)156 explicit SemanticValue(const reflection::BaseType base_type, 157 const StringPiece data) 158 : base_type_(base_type), type_(nullptr), data_(data) {} SemanticValue(const reflection::Object * type,const StringPiece data)159 explicit SemanticValue(const reflection::Object* type, const StringPiece data) 160 : base_type_(reflection::BaseType::Obj), type_(type), data_(data) {} 161 162 template <typename T> Has()163 bool Has() const { 164 return base_type_ == libtextclassifier3::flatbuffers_base_type<T>::value; 165 } 166 167 template <> 168 bool Has<flatbuffers::Table>() const { 169 return base_type_ == reflection::BaseType::Obj; 170 } 171 172 template <typename T = flatbuffers::Table> Table()173 const T* Table() const { 174 TC3_CHECK(Has<flatbuffers::Table>()); 175 return flatbuffers::GetRoot<T>( 176 reinterpret_cast<const unsigned char*>(data_.data())); 177 } 178 179 template <typename T> Value()180 const T Value() const { 181 TC3_CHECK(Has<T>()); 182 return flatbuffers::ReadScalar<T>(data_.data()); 183 } 184 185 template <> 186 const StringPiece Value<StringPiece>() const { 187 TC3_CHECK(Has<StringPiece>()); 188 return data_; 189 } 190 191 template <> 192 const std::string Value<std::string>() const { 193 TC3_CHECK(Has<StringPiece>()); 194 return data_.ToString(); 195 } 196 197 template <> 198 const UnicodeText Value<UnicodeText>() const { 199 TC3_CHECK(Has<StringPiece>()); 200 return UTF8ToUnicodeText(data_, /*do_copy=*/false); 201 } 202 base_type()203 const reflection::BaseType base_type() const { return base_type_; } type()204 const reflection::Object* type() const { return type_; } 205 206 private: 207 // The base type. 208 const reflection::BaseType base_type_; 209 210 // The object type of the value. 211 const reflection::Object* type_; 212 213 StringPiece data_; 214 }; 215 216 } // namespace libtextclassifier3::grammar 217 218 #endif // LIBTEXTCLASSIFIER_UTILS_GRAMMAR_SEMANTICS_VALUE_H_ 219