• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/regex-match.h"
18 
19 #include <memory>
20 
21 #include "annotator/types.h"
22 
23 #ifndef TC3_DISABLE_LUA
24 #include "utils/lua-utils.h"
25 #ifdef __cplusplus
26 extern "C" {
27 #endif
28 #include "lauxlib.h"
29 #include "lualib.h"
30 #ifdef __cplusplus
31 }
32 #endif
33 #endif
34 
35 namespace libtextclassifier3 {
36 namespace {
37 
38 #ifndef TC3_DISABLE_LUA
39 // Provide a lua environment for running regex match post verification.
40 // It sets up and exposes the match data as well as the context.
41 class LuaVerifier : public LuaEnvironment {
42  public:
43   static std::unique_ptr<LuaVerifier> Create(
44       const std::string& context, const std::string& verifier_code,
45       const UniLib::RegexMatcher* matcher);
46 
47   bool Verify(bool* result);
48 
49  private:
LuaVerifier(const std::string & context,const std::string & verifier_code,const UniLib::RegexMatcher * matcher)50   explicit LuaVerifier(const std::string& context,
51                        const std::string& verifier_code,
52                        const UniLib::RegexMatcher* matcher)
53       : context_(context), verifier_code_(verifier_code), matcher_(matcher) {}
54   bool Initialize();
55 
56   // Provides details of a capturing group to lua.
57   int GetCapturingGroup();
58 
59   const std::string& context_;
60   const std::string& verifier_code_;
61   const UniLib::RegexMatcher* matcher_;
62 };
63 
Initialize()64 bool LuaVerifier::Initialize() {
65   // Run protected to not lua panic in case of setup failure.
66   return RunProtected([this] {
67            LoadDefaultLibraries();
68 
69            // Expose context of the match as `context` global variable.
70            PushString(context_);
71            lua_setglobal(state_, "context");
72 
73            // Expose match array as `match` global variable.
74            // Each entry `match[i]` exposes the ith capturing group as:
75            //   * `begin`: span start
76            //   * `end`: span end
77            //   * `text`: the text
78            PushLazyObject(&LuaVerifier::GetCapturingGroup);
79            lua_setglobal(state_, "match");
80            return LUA_OK;
81          }) == LUA_OK;
82 }
83 
Create(const std::string & context,const std::string & verifier_code,const UniLib::RegexMatcher * matcher)84 std::unique_ptr<LuaVerifier> LuaVerifier::Create(
85     const std::string& context, const std::string& verifier_code,
86     const UniLib::RegexMatcher* matcher) {
87   auto verifier = std::unique_ptr<LuaVerifier>(
88       new LuaVerifier(context, verifier_code, matcher));
89   if (!verifier->Initialize()) {
90     TC3_LOG(ERROR) << "Could not initialize lua environment.";
91     return nullptr;
92   }
93   return verifier;
94 }
95 
GetCapturingGroup()96 int LuaVerifier::GetCapturingGroup() {
97   if (lua_type(state_, /*idx=*/-1) != LUA_TNUMBER) {
98     TC3_LOG(ERROR) << "Unexpected type for match group lookup: "
99                    << lua_type(state_, /*idx=*/-1);
100     lua_error(state_);
101     return 0;
102   }
103   const int group_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
104   int status = UniLib::RegexMatcher::kNoError;
105   const CodepointSpan span = {matcher_->Start(group_id, &status),
106                               matcher_->End(group_id, &status)};
107   std::string text = matcher_->Group(group_id, &status).ToUTF8String();
108   if (status != UniLib::RegexMatcher::kNoError) {
109     TC3_LOG(ERROR) << "Could not extract span from capturing group.";
110     lua_error(state_);
111     return 0;
112   }
113   lua_newtable(state_);
114   lua_pushinteger(state_, span.first);
115   lua_setfield(state_, /*idx=*/-2, "begin");
116   lua_pushinteger(state_, span.second);
117   lua_setfield(state_, /*idx=*/-2, "end");
118   PushString(text);
119   lua_setfield(state_, /*idx=*/-2, "text");
120   return 1;
121 }
122 
Verify(bool * result)123 bool LuaVerifier::Verify(bool* result) {
124   if (luaL_loadbuffer(state_, verifier_code_.data(), verifier_code_.size(),
125                       /*name=*/nullptr) != LUA_OK) {
126     TC3_LOG(ERROR) << "Could not load verifier snippet.";
127     return false;
128   }
129 
130   if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
131     TC3_LOG(ERROR) << "Could not run verifier snippet.";
132     return false;
133   }
134 
135   if (RunProtected(
136           [this, result] {
137             if (lua_type(state_, /*idx=*/-1) != LUA_TBOOLEAN) {
138               TC3_LOG(ERROR) << "Unexpected verification result type: "
139                              << lua_type(state_, /*idx=*/-1);
140               lua_error(state_);
141               return LUA_ERRRUN;
142             }
143             *result = lua_toboolean(state_, /*idx=*/-1);
144             return LUA_OK;
145           },
146           /*num_args=*/1) != LUA_OK) {
147     TC3_LOG(ERROR) << "Could not read lua result.";
148     return false;
149   }
150   return true;
151 }
152 #endif  // TC3_DISABLE_LUA
153 
154 }  // namespace
155 
GetCapturingGroupText(const UniLib::RegexMatcher * matcher,const int group_id)156 Optional<std::string> GetCapturingGroupText(const UniLib::RegexMatcher* matcher,
157                                             const int group_id) {
158   int status = UniLib::RegexMatcher::kNoError;
159   std::string group_text = matcher->Group(group_id, &status).ToUTF8String();
160   if (status != UniLib::RegexMatcher::kNoError || group_text.empty()) {
161     return Optional<std::string>();
162   }
163   return Optional<std::string>(group_text);
164 }
165 
VerifyMatch(const std::string & context,const UniLib::RegexMatcher * matcher,const std::string & lua_verifier_code)166 bool VerifyMatch(const std::string& context,
167                  const UniLib::RegexMatcher* matcher,
168                  const std::string& lua_verifier_code) {
169   bool status = false;
170 #ifndef TC3_DISABLE_LUA
171   auto verifier = LuaVerifier::Create(context, lua_verifier_code, matcher);
172   if (verifier == nullptr) {
173     TC3_LOG(ERROR) << "Could not create verifier.";
174     return false;
175   }
176   if (!verifier->Verify(&status)) {
177     TC3_LOG(ERROR) << "Could not create verifier.";
178     return false;
179   }
180 #endif  // TC3_DISABLE_LUA
181   return status;
182 }
183 
184 }  // namespace libtextclassifier3
185