/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "actions/lua-utils.h" namespace libtextclassifier3 { namespace { static constexpr const char* kTextKey = "text"; static constexpr const char* kTimeUsecKey = "parsed_time_ms_utc"; static constexpr const char* kGranularityKey = "granularity"; static constexpr const char* kCollectionKey = "collection"; static constexpr const char* kNameKey = "name"; static constexpr const char* kScoreKey = "score"; static constexpr const char* kPriorityScoreKey = "priority_score"; static constexpr const char* kTypeKey = "type"; static constexpr const char* kResponseTextKey = "response_text"; static constexpr const char* kAnnotationKey = "annotation"; static constexpr const char* kSpanKey = "span"; static constexpr const char* kMessageKey = "message"; static constexpr const char* kBeginKey = "begin"; static constexpr const char* kEndKey = "end"; static constexpr const char* kClassificationKey = "classification"; static constexpr const char* kSerializedEntity = "serialized_entity"; static constexpr const char* kEntityKey = "entity"; } // namespace template <> int AnnotationIterator::Item( const std::vector* annotations, StringPiece key, lua_State* state) const { // Lookup annotation by collection. for (const ClassificationResult& annotation : *annotations) { if (key.Equals(annotation.collection)) { PushAnnotation(annotation, entity_data_schema_, env_); return 1; } } TC3_LOG(ERROR) << "No annotation with collection: " << key.ToString() << " found."; lua_error(state); return 0; } template <> int AnnotationIterator::Item( const std::vector* annotations, StringPiece key, lua_State* state) const { // Lookup annotation by name. for (const ActionSuggestionAnnotation& annotation : *annotations) { if (key.Equals(annotation.name)) { PushAnnotation(annotation, entity_data_schema_, env_); return 1; } } TC3_LOG(ERROR) << "No annotation with name: " << key.ToString() << " found."; lua_error(state); return 0; } void PushAnnotation(const ClassificationResult& classification, const reflection::Schema* entity_data_schema, LuaEnvironment* env) { if (entity_data_schema == nullptr || classification.serialized_entity_data.empty()) { // Empty table. lua_newtable(env->state()); } else { env->PushFlatbuffer(entity_data_schema, flatbuffers::GetRoot( classification.serialized_entity_data.data())); } lua_pushinteger(env->state(), classification.datetime_parse_result.time_ms_utc); lua_setfield(env->state(), /*idx=*/-2, kTimeUsecKey); lua_pushinteger(env->state(), classification.datetime_parse_result.granularity); lua_setfield(env->state(), /*idx=*/-2, kGranularityKey); env->PushString(classification.collection); lua_setfield(env->state(), /*idx=*/-2, kCollectionKey); lua_pushnumber(env->state(), classification.score); lua_setfield(env->state(), /*idx=*/-2, kScoreKey); env->PushString(classification.serialized_entity_data); lua_setfield(env->state(), /*idx=*/-2, kSerializedEntity); } void PushAnnotation(const ClassificationResult& classification, StringPiece text, const reflection::Schema* entity_data_schema, LuaEnvironment* env) { PushAnnotation(classification, entity_data_schema, env); env->PushString(text); lua_setfield(env->state(), /*idx=*/-2, kTextKey); } void PushAnnotatedSpan( const AnnotatedSpan& annotated_span, const AnnotationIterator& annotation_iterator, LuaEnvironment* env) { lua_newtable(env->state()); { lua_newtable(env->state()); lua_pushinteger(env->state(), annotated_span.span.first); lua_setfield(env->state(), /*idx=*/-2, kBeginKey); lua_pushinteger(env->state(), annotated_span.span.second); lua_setfield(env->state(), /*idx=*/-2, kEndKey); } lua_setfield(env->state(), /*idx=*/-2, kSpanKey); annotation_iterator.NewIterator(kClassificationKey, &annotated_span.classification, env->state()); lua_setfield(env->state(), /*idx=*/-2, kClassificationKey); } MessageTextSpan ReadSpan(LuaEnvironment* env) { MessageTextSpan span; lua_pushnil(env->state()); while (lua_next(env->state(), /*idx=*/-2)) { const StringPiece key = env->ReadString(/*index=*/-2); if (key.Equals(kMessageKey)) { span.message_index = static_cast(lua_tonumber(env->state(), /*idx=*/-1)); } else if (key.Equals(kBeginKey)) { span.span.first = static_cast(lua_tonumber(env->state(), /*idx=*/-1)); } else if (key.Equals(kEndKey)) { span.span.second = static_cast(lua_tonumber(env->state(), /*idx=*/-1)); } else if (key.Equals(kTextKey)) { span.text = env->ReadString(/*index=*/-1).ToString(); } else { TC3_LOG(INFO) << "Unknown span field: " << key.ToString(); } lua_pop(env->state(), 1); } return span; } int ReadAnnotations(const reflection::Schema* entity_data_schema, LuaEnvironment* env, std::vector* annotations) { if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) { TC3_LOG(ERROR) << "Expected annotations table, got: " << lua_type(env->state(), /*idx=*/-1); lua_pop(env->state(), 1); lua_error(env->state()); return LUA_ERRRUN; } // Read actions. lua_pushnil(env->state()); while (lua_next(env->state(), /*idx=*/-2)) { if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) { TC3_LOG(ERROR) << "Expected annotation table, got: " << lua_type(env->state(), /*idx=*/-1); lua_pop(env->state(), 1); continue; } annotations->push_back(ReadAnnotation(entity_data_schema, env)); lua_pop(env->state(), 1); } return LUA_OK; } ActionSuggestionAnnotation ReadAnnotation( const reflection::Schema* entity_data_schema, LuaEnvironment* env) { ActionSuggestionAnnotation annotation; lua_pushnil(env->state()); while (lua_next(env->state(), /*idx=*/-2)) { const StringPiece key = env->ReadString(/*index=*/-2); if (key.Equals(kNameKey)) { annotation.name = env->ReadString(/*index=*/-1).ToString(); } else if (key.Equals(kSpanKey)) { annotation.span = ReadSpan(env); } else if (key.Equals(kEntityKey)) { annotation.entity = ReadClassificationResult(entity_data_schema, env); } else { TC3_LOG(ERROR) << "Unknown annotation field: " << key.ToString(); } lua_pop(env->state(), 1); } return annotation; } ClassificationResult ReadClassificationResult( const reflection::Schema* entity_data_schema, LuaEnvironment* env) { ClassificationResult classification; lua_pushnil(env->state()); while (lua_next(env->state(), /*idx=*/-2)) { const StringPiece key = env->ReadString(/*index=*/-2); if (key.Equals(kCollectionKey)) { classification.collection = env->ReadString(/*index=*/-1).ToString(); } else if (key.Equals(kScoreKey)) { classification.score = static_cast(lua_tonumber(env->state(), /*idx=*/-1)); } else if (key.Equals(kTimeUsecKey)) { classification.datetime_parse_result.time_ms_utc = static_cast(lua_tonumber(env->state(), /*idx=*/-1)); } else if (key.Equals(kGranularityKey)) { classification.datetime_parse_result.granularity = static_cast( lua_tonumber(env->state(), /*idx=*/-1)); } else if (key.Equals(kSerializedEntity)) { classification.serialized_entity_data = env->ReadString(/*index=*/-1).ToString(); } else if (key.Equals(kEntityKey)) { auto buffer = ReflectiveFlatbufferBuilder(entity_data_schema).NewRoot(); env->ReadFlatbuffer(buffer.get()); classification.serialized_entity_data = buffer->Serialize(); } else { TC3_LOG(INFO) << "Unknown classification result field: " << key.ToString(); } lua_pop(env->state(), 1); } return classification; } void PushAnnotation(const ActionSuggestionAnnotation& annotation, const reflection::Schema* entity_data_schema, LuaEnvironment* env) { PushAnnotation(annotation.entity, annotation.span.text, entity_data_schema, env); env->PushString(annotation.name); lua_setfield(env->state(), /*idx=*/-2, kNameKey); { lua_newtable(env->state()); lua_pushinteger(env->state(), annotation.span.message_index); lua_setfield(env->state(), /*idx=*/-2, kMessageKey); lua_pushinteger(env->state(), annotation.span.span.first); lua_setfield(env->state(), /*idx=*/-2, kBeginKey); lua_pushinteger(env->state(), annotation.span.span.second); lua_setfield(env->state(), /*idx=*/-2, kEndKey); } lua_setfield(env->state(), /*idx=*/-2, kSpanKey); } void PushAction( const ActionSuggestion& action, const reflection::Schema* entity_data_schema, const AnnotationIterator& annotation_iterator, LuaEnvironment* env) { if (entity_data_schema == nullptr || action.serialized_entity_data.empty()) { // Empty table. lua_newtable(env->state()); } else { env->PushFlatbuffer(entity_data_schema, flatbuffers::GetRoot( action.serialized_entity_data.data())); } env->PushString(action.type); lua_setfield(env->state(), /*idx=*/-2, kTypeKey); env->PushString(action.response_text); lua_setfield(env->state(), /*idx=*/-2, kResponseTextKey); lua_pushnumber(env->state(), action.score); lua_setfield(env->state(), /*idx=*/-2, kScoreKey); lua_pushnumber(env->state(), action.priority_score); lua_setfield(env->state(), /*idx=*/-2, kPriorityScoreKey); annotation_iterator.NewIterator(kAnnotationKey, &action.annotations, env->state()); lua_setfield(env->state(), /*idx=*/-2, kAnnotationKey); } ActionSuggestion ReadAction( const reflection::Schema* actions_entity_data_schema, const reflection::Schema* annotations_entity_data_schema, LuaEnvironment* env) { ActionSuggestion action; lua_pushnil(env->state()); while (lua_next(env->state(), /*idx=*/-2)) { const StringPiece key = env->ReadString(/*index=*/-2); if (key.Equals(kResponseTextKey)) { action.response_text = env->ReadString(/*index=*/-1).ToString(); } else if (key.Equals(kTypeKey)) { action.type = env->ReadString(/*index=*/-1).ToString(); } else if (key.Equals(kScoreKey)) { action.score = static_cast(lua_tonumber(env->state(), /*idx=*/-1)); } else if (key.Equals(kPriorityScoreKey)) { action.priority_score = static_cast(lua_tonumber(env->state(), /*idx=*/-1)); } else if (key.Equals(kAnnotationKey)) { ReadAnnotations(actions_entity_data_schema, env, &action.annotations); } else if (key.Equals(kEntityKey)) { auto buffer = ReflectiveFlatbufferBuilder(actions_entity_data_schema).NewRoot(); env->ReadFlatbuffer(buffer.get()); action.serialized_entity_data = buffer->Serialize(); } else { TC3_LOG(INFO) << "Unknown action field: " << key.ToString(); } lua_pop(env->state(), 1); } return action; } int ReadActions(const reflection::Schema* actions_entity_data_schema, const reflection::Schema* annotations_entity_data_schema, LuaEnvironment* env, std::vector* actions) { if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) { TC3_LOG(ERROR) << "Expected actions table, got: " << lua_type(env->state(), /*idx=*/-1); lua_pop(env->state(), 1); lua_error(env->state()); return LUA_ERRRUN; } // Read actions. lua_pushnil(env->state()); while (lua_next(env->state(), /*idx=*/-2)) { if (lua_type(env->state(), /*idx=*/-1) != LUA_TTABLE) { TC3_LOG(ERROR) << "Expected action table, got: " << lua_type(env->state(), /*idx=*/-1); lua_pop(env->state(), 1); continue; } actions->push_back(ReadAction(actions_entity_data_schema, annotations_entity_data_schema, env)); lua_pop(env->state(), /*n=1*/ 1); } lua_pop(env->state(), /*n=*/1); return LUA_OK; } int ConversationIterator::Item(const std::vector* messages, const int64 pos, lua_State* state) const { const ConversationMessage& message = (*messages)[pos]; lua_newtable(state); lua_pushinteger(state, message.user_id); lua_setfield(state, /*idx=*/-2, "user_id"); env_->PushString(message.text); lua_setfield(state, /*idx=*/-2, "text"); lua_pushinteger(state, message.reference_time_ms_utc); lua_setfield(state, /*idx=*/-2, "time_ms_utc"); env_->PushString(message.reference_timezone); lua_setfield(state, /*idx=*/-2, "timezone"); annotated_span_iterator_.NewIterator("annotation", &message.annotations, state); lua_setfield(state, /*idx=*/-2, "annotation"); return 1; } } // namespace libtextclassifier3