1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/zlib-utils.h"
18
19 #include <memory>
20
21 #include "utils/base/logging.h"
22 #include "utils/intents/zlib-utils.h"
23 #include "utils/zlib/zlib.h"
24
25 namespace libtextclassifier3 {
26
27 // Compress rule fields in the model.
CompressModel(ModelT * model)28 bool CompressModel(ModelT* model) {
29 std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance();
30 if (!zlib_compressor) {
31 TC3_LOG(ERROR) << "Cannot compress model.";
32 return false;
33 }
34
35 // Compress regex rules.
36 if (model->regex_model != nullptr) {
37 for (int i = 0; i < model->regex_model->patterns.size(); i++) {
38 RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get();
39 pattern->compressed_pattern.reset(new CompressedBufferT);
40 zlib_compressor->Compress(pattern->pattern,
41 pattern->compressed_pattern.get());
42 pattern->pattern.clear();
43 }
44 }
45
46 // Compress date-time rules.
47 if (model->datetime_model != nullptr) {
48 for (int i = 0; i < model->datetime_model->patterns.size(); i++) {
49 DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get();
50 for (int j = 0; j < pattern->regexes.size(); j++) {
51 DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get();
52 regex->compressed_pattern.reset(new CompressedBufferT);
53 zlib_compressor->Compress(regex->pattern,
54 regex->compressed_pattern.get());
55 regex->pattern.clear();
56 }
57 }
58 for (int i = 0; i < model->datetime_model->extractors.size(); i++) {
59 DatetimeModelExtractorT* extractor =
60 model->datetime_model->extractors[i].get();
61 extractor->compressed_pattern.reset(new CompressedBufferT);
62 zlib_compressor->Compress(extractor->pattern,
63 extractor->compressed_pattern.get());
64 extractor->pattern.clear();
65 }
66 }
67
68 // Compress intent generator.
69 if (model->intent_options != nullptr) {
70 CompressIntentModel(model->intent_options.get());
71 }
72
73 return true;
74 }
75
DecompressModel(ModelT * model)76 bool DecompressModel(ModelT* model) {
77 std::unique_ptr<ZlibDecompressor> zlib_decompressor =
78 ZlibDecompressor::Instance();
79 if (!zlib_decompressor) {
80 TC3_LOG(ERROR) << "Cannot initialize decompressor.";
81 return false;
82 }
83
84 // Decompress regex rules.
85 if (model->regex_model != nullptr) {
86 for (int i = 0; i < model->regex_model->patterns.size(); i++) {
87 RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get();
88 if (!zlib_decompressor->MaybeDecompress(pattern->compressed_pattern.get(),
89 &pattern->pattern)) {
90 TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
91 return false;
92 }
93 pattern->compressed_pattern.reset(nullptr);
94 }
95 }
96
97 // Decompress date-time rules.
98 if (model->datetime_model != nullptr) {
99 for (int i = 0; i < model->datetime_model->patterns.size(); i++) {
100 DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get();
101 for (int j = 0; j < pattern->regexes.size(); j++) {
102 DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get();
103 if (!zlib_decompressor->MaybeDecompress(regex->compressed_pattern.get(),
104 ®ex->pattern)) {
105 TC3_LOG(ERROR) << "Cannot decompress pattern: " << i << " " << j;
106 return false;
107 }
108 regex->compressed_pattern.reset(nullptr);
109 }
110 }
111 for (int i = 0; i < model->datetime_model->extractors.size(); i++) {
112 DatetimeModelExtractorT* extractor =
113 model->datetime_model->extractors[i].get();
114 if (!zlib_decompressor->MaybeDecompress(
115 extractor->compressed_pattern.get(), &extractor->pattern)) {
116 TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
117 return false;
118 }
119 extractor->compressed_pattern.reset(nullptr);
120 }
121 }
122
123 if (model->intent_options != nullptr) {
124 DecompressIntentModel(model->intent_options.get());
125 }
126
127 return true;
128 }
129
CompressSerializedModel(const std::string & model)130 std::string CompressSerializedModel(const std::string& model) {
131 std::unique_ptr<ModelT> unpacked_model = UnPackModel(model.c_str());
132 TC3_CHECK(unpacked_model != nullptr);
133 TC3_CHECK(CompressModel(unpacked_model.get()));
134 flatbuffers::FlatBufferBuilder builder;
135 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
136 return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
137 builder.GetSize());
138 }
139
140 } // namespace libtextclassifier3
141