1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "annotator/zlib-utils.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>

#include "utils/base/logging.h"
#include "utils/intents/zlib-utils.h"
#include "utils/resources.h"
#include "utils/zlib/zlib.h"
25
26 namespace libtextclassifier3 {
27
28 // Compress rule fields in the model.
CompressModel(ModelT * model)29 bool CompressModel(ModelT* model) {
30 std::unique_ptr<ZlibCompressor> zlib_compressor = ZlibCompressor::Instance();
31 if (!zlib_compressor) {
32 TC3_LOG(ERROR) << "Cannot compress model.";
33 return false;
34 }
35
36 // Compress regex rules.
37 if (model->regex_model != nullptr) {
38 for (int i = 0; i < model->regex_model->patterns.size(); i++) {
39 RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get();
40 pattern->compressed_pattern.reset(new CompressedBufferT);
41 zlib_compressor->Compress(pattern->pattern,
42 pattern->compressed_pattern.get());
43 pattern->pattern.clear();
44 }
45 }
46
47 // Compress date-time rules.
48 if (model->datetime_model != nullptr) {
49 for (int i = 0; i < model->datetime_model->patterns.size(); i++) {
50 DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get();
51 for (int j = 0; j < pattern->regexes.size(); j++) {
52 DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get();
53 regex->compressed_pattern.reset(new CompressedBufferT);
54 zlib_compressor->Compress(regex->pattern,
55 regex->compressed_pattern.get());
56 regex->pattern.clear();
57 }
58 }
59 for (int i = 0; i < model->datetime_model->extractors.size(); i++) {
60 DatetimeModelExtractorT* extractor =
61 model->datetime_model->extractors[i].get();
62 extractor->compressed_pattern.reset(new CompressedBufferT);
63 zlib_compressor->Compress(extractor->pattern,
64 extractor->compressed_pattern.get());
65 extractor->pattern.clear();
66 }
67 }
68
69 // Compress resources.
70 if (model->resources != nullptr) {
71 CompressResources(model->resources.get());
72 }
73
74 // Compress intent generator.
75 if (model->intent_options != nullptr) {
76 CompressIntentModel(model->intent_options.get());
77 }
78
79 return true;
80 }
81
DecompressModel(ModelT * model)82 bool DecompressModel(ModelT* model) {
83 std::unique_ptr<ZlibDecompressor> zlib_decompressor =
84 ZlibDecompressor::Instance();
85 if (!zlib_decompressor) {
86 TC3_LOG(ERROR) << "Cannot initialize decompressor.";
87 return false;
88 }
89
90 // Decompress regex rules.
91 if (model->regex_model != nullptr) {
92 for (int i = 0; i < model->regex_model->patterns.size(); i++) {
93 RegexModel_::PatternT* pattern = model->regex_model->patterns[i].get();
94 if (!zlib_decompressor->MaybeDecompress(pattern->compressed_pattern.get(),
95 &pattern->pattern)) {
96 TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
97 return false;
98 }
99 pattern->compressed_pattern.reset(nullptr);
100 }
101 }
102
103 // Decompress date-time rules.
104 if (model->datetime_model != nullptr) {
105 for (int i = 0; i < model->datetime_model->patterns.size(); i++) {
106 DatetimeModelPatternT* pattern = model->datetime_model->patterns[i].get();
107 for (int j = 0; j < pattern->regexes.size(); j++) {
108 DatetimeModelPattern_::RegexT* regex = pattern->regexes[j].get();
109 if (!zlib_decompressor->MaybeDecompress(regex->compressed_pattern.get(),
110 ®ex->pattern)) {
111 TC3_LOG(ERROR) << "Cannot decompress pattern: " << i << " " << j;
112 return false;
113 }
114 regex->compressed_pattern.reset(nullptr);
115 }
116 }
117 for (int i = 0; i < model->datetime_model->extractors.size(); i++) {
118 DatetimeModelExtractorT* extractor =
119 model->datetime_model->extractors[i].get();
120 if (!zlib_decompressor->MaybeDecompress(
121 extractor->compressed_pattern.get(), &extractor->pattern)) {
122 TC3_LOG(ERROR) << "Cannot decompress pattern: " << i;
123 return false;
124 }
125 extractor->compressed_pattern.reset(nullptr);
126 }
127 }
128 return true;
129 }
130
CompressSerializedModel(const std::string & model)131 std::string CompressSerializedModel(const std::string& model) {
132 std::unique_ptr<ModelT> unpacked_model = UnPackModel(model.c_str());
133 TC3_CHECK(unpacked_model != nullptr);
134 TC3_CHECK(CompressModel(unpacked_model.get()));
135 flatbuffers::FlatBufferBuilder builder;
136 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
137 return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
138 builder.GetSize());
139 }
140
141 } // namespace libtextclassifier3
142