1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/regex-match.h"
18
19 #include <memory>
20
21 #include "annotator/types.h"
22
23 #ifndef TC3_DISABLE_LUA
24 #include "utils/lua-utils.h"
25 #ifdef __cplusplus
26 extern "C" {
27 #endif
28 #include "lauxlib.h"
29 #include "lualib.h"
30 #ifdef __cplusplus
31 }
32 #endif
33 #endif
34
35 namespace libtextclassifier3 {
36 namespace {
37
38 #ifndef TC3_DISABLE_LUA
39 // Provide a lua environment for running regex match post verification.
40 // It sets up and exposes the match data as well as the context.
41 class LuaVerifier : public LuaEnvironment {
42 public:
43 static std::unique_ptr<LuaVerifier> Create(
44 const std::string& context, const std::string& verifier_code,
45 const UniLib::RegexMatcher* matcher);
46
47 bool Verify(bool* result);
48
49 private:
LuaVerifier(const std::string & context,const std::string & verifier_code,const UniLib::RegexMatcher * matcher)50 explicit LuaVerifier(const std::string& context,
51 const std::string& verifier_code,
52 const UniLib::RegexMatcher* matcher)
53 : context_(context), verifier_code_(verifier_code), matcher_(matcher) {}
54 bool Initialize();
55
56 // Provides details of a capturing group to lua.
57 int GetCapturingGroup();
58
59 const std::string& context_;
60 const std::string& verifier_code_;
61 const UniLib::RegexMatcher* matcher_;
62 };
63
Initialize()64 bool LuaVerifier::Initialize() {
65 // Run protected to not lua panic in case of setup failure.
66 return RunProtected([this] {
67 LoadDefaultLibraries();
68
69 // Expose context of the match as `context` global variable.
70 PushString(context_);
71 lua_setglobal(state_, "context");
72
73 // Expose match array as `match` global variable.
74 // Each entry `match[i]` exposes the ith capturing group as:
75 // * `begin`: span start
76 // * `end`: span end
77 // * `text`: the text
78 PushLazyObject(&LuaVerifier::GetCapturingGroup);
79 lua_setglobal(state_, "match");
80 return LUA_OK;
81 }) == LUA_OK;
82 }
83
Create(const std::string & context,const std::string & verifier_code,const UniLib::RegexMatcher * matcher)84 std::unique_ptr<LuaVerifier> LuaVerifier::Create(
85 const std::string& context, const std::string& verifier_code,
86 const UniLib::RegexMatcher* matcher) {
87 auto verifier = std::unique_ptr<LuaVerifier>(
88 new LuaVerifier(context, verifier_code, matcher));
89 if (!verifier->Initialize()) {
90 TC3_LOG(ERROR) << "Could not initialize lua environment.";
91 return nullptr;
92 }
93 return verifier;
94 }
95
GetCapturingGroup()96 int LuaVerifier::GetCapturingGroup() {
97 if (lua_type(state_, /*idx=*/-1) != LUA_TNUMBER) {
98 TC3_LOG(ERROR) << "Unexpected type for match group lookup: "
99 << lua_type(state_, /*idx=*/-1);
100 lua_error(state_);
101 return 0;
102 }
103 const int group_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
104 int status = UniLib::RegexMatcher::kNoError;
105 const CodepointSpan span = {matcher_->Start(group_id, &status),
106 matcher_->End(group_id, &status)};
107 std::string text = matcher_->Group(group_id, &status).ToUTF8String();
108 if (status != UniLib::RegexMatcher::kNoError) {
109 TC3_LOG(ERROR) << "Could not extract span from capturing group.";
110 lua_error(state_);
111 return 0;
112 }
113 lua_newtable(state_);
114 lua_pushinteger(state_, span.first);
115 lua_setfield(state_, /*idx=*/-2, "begin");
116 lua_pushinteger(state_, span.second);
117 lua_setfield(state_, /*idx=*/-2, "end");
118 PushString(text);
119 lua_setfield(state_, /*idx=*/-2, "text");
120 return 1;
121 }
122
Verify(bool * result)123 bool LuaVerifier::Verify(bool* result) {
124 if (luaL_loadbuffer(state_, verifier_code_.data(), verifier_code_.size(),
125 /*name=*/nullptr) != LUA_OK) {
126 TC3_LOG(ERROR) << "Could not load verifier snippet.";
127 return false;
128 }
129
130 if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
131 TC3_LOG(ERROR) << "Could not run verifier snippet.";
132 return false;
133 }
134
135 if (RunProtected(
136 [this, result] {
137 if (lua_type(state_, /*idx=*/-1) != LUA_TBOOLEAN) {
138 TC3_LOG(ERROR) << "Unexpected verification result type: "
139 << lua_type(state_, /*idx=*/-1);
140 lua_error(state_);
141 return LUA_ERRRUN;
142 }
143 *result = lua_toboolean(state_, /*idx=*/-1);
144 return LUA_OK;
145 },
146 /*num_args=*/1) != LUA_OK) {
147 TC3_LOG(ERROR) << "Could not read lua result.";
148 return false;
149 }
150 return true;
151 }
152 #endif // TC3_DISABLE_LUA
153
154 } // namespace
155
GetCapturingGroupText(const UniLib::RegexMatcher * matcher,const int group_id)156 Optional<std::string> GetCapturingGroupText(const UniLib::RegexMatcher* matcher,
157 const int group_id) {
158 int status = UniLib::RegexMatcher::kNoError;
159 std::string group_text = matcher->Group(group_id, &status).ToUTF8String();
160 if (status != UniLib::RegexMatcher::kNoError || group_text.empty()) {
161 return Optional<std::string>();
162 }
163 return Optional<std::string>(group_text);
164 }
165
VerifyMatch(const std::string & context,const UniLib::RegexMatcher * matcher,const std::string & lua_verifier_code)166 bool VerifyMatch(const std::string& context,
167 const UniLib::RegexMatcher* matcher,
168 const std::string& lua_verifier_code) {
169 bool status = false;
170 #ifndef TC3_DISABLE_LUA
171 auto verifier = LuaVerifier::Create(context, lua_verifier_code, matcher);
172 if (verifier == nullptr) {
173 TC3_LOG(ERROR) << "Could not create verifier.";
174 return false;
175 }
176 if (!verifier->Verify(&status)) {
177 TC3_LOG(ERROR) << "Could not create verifier.";
178 return false;
179 }
180 #endif // TC3_DISABLE_LUA
181 return status;
182 }
183
184 } // namespace libtextclassifier3
185