1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Range coder implementation, based on [1].
17 //
18 // [1] G. N. N. Martin, "Range coding: an algorithm for removing redundancy from
19 // a digitised message", presented to the Video & Data Recording Conference,
20 // held in Southampton, July 24-27, 1979.
21 //
22 #include "tensorflow/contrib/coder/kernels/range_coder.h"
23
24 #include <limits>
25 #include <string>
26
27 #include "tensorflow/core/lib/gtl/array_slice.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/platform/types.h"
30
31 namespace tensorflow {
RangeEncoder(int precision)32 RangeEncoder::RangeEncoder(int precision) : precision_(precision) {
33 CHECK_GT(precision, 0);
34 CHECK_LE(precision, 16);
35 }
36
Encode(int32 lower,int32 upper,string * sink)37 void RangeEncoder::Encode(int32 lower, int32 upper, string* sink) {
38 // Input requirement: 0 <= lower < upper <= 2^precision.
39 DCHECK_LE(0, lower);
40 DCHECK_LT(lower, upper);
41 DCHECK_LE(upper, 1 << precision_);
42
43 // `base` and `size` represent a half-open interval [base, base + size).
44 // Loop invariant: 2^16 <= size <= 2^32.
45 //
46 // Note that keeping size above 2^16 is important. Since the interval sizes
47 // are quantized to up to 16 bits, the smallest interval size the encode may
48 // handle is 2^-16. If size is smaller than 2^16, a small interval input may
49 // collapse the encoder range into an empty interval.
50 const uint64 size = static_cast<uint64>(size_minus1_) + 1;
51 DCHECK_NE(size >> 16, 0);
52
53 // For short notation, let u := lower and v := upper.
54 //
55 // The input u, v represents a half-open interval [u, v) / 2^precision.
56 // This narrows the current interval roughly to
57 // [base + (size * u) / 2^precision, base + (size * v) / 2^precision).
58 //
59 // TODO(sjhwang): Try rounding if it helps improve compression ratio, at the
60 // expense of more operations. In the test using Zipf distribution, the
61 // overhead over the theoretical compression ratio was ~0.01%.
62 // NOTE: The max value of `size` is 2^32 and size > 0. Therefore `size * u`
63 // can be rewritten as `(size - 1) * u + u` and all the computation can be
64 // done in 32-bit mode. If 32-bit multiply is faster, then rewrite.
65 const uint32 a = (size * static_cast<uint64>(lower)) >> precision_;
66 const uint32 b = ((size * static_cast<uint64>(upper)) >> precision_) - 1;
67 DCHECK_LE(a, b);
68
69 // Let's confirm the RHS of a, b fit in uint32 type.
70 // Recall that 0 <= u < 2^precision, and size <= 2^32. Therefore
71 // (size * u) / 2^precision < size <= 2^32,
72 // and the value of a fits in uint32 type. Similarly, since v <= 2^precision,
73 // (size * v) / 2^precision - 1 <= size - 1 < 2^32.
74 // For lower bound of b, note that 1 <= v, 2^16 <= size, and 16 <= precision.
75 // Therefore (size * v) / 2^precision - 1 >= 2^16 / 2^precision - 1 >= 0.
76
77 // The new interval is [base + a, base + b] = [base + a, base + b + 1).
78 base_ += a; // May overflow.
79 size_minus1_ = b - a;
80 const bool base_overflow = (base_ < a);
81
82 // The encoder has two states. Let's call them state 0 and state 1.
83 // State 0 is when base < base + size <= 2^32.
84 // State 1 is when base < 2^32 < base + size.
85 //
86 // The encoder initially starts in state 0, with base = 0, size = 2^32.
87 //
88 // TODO(sjhwang): Requires some profiling, but the encoder stays in state 0
89 // most of the time. Should optimize code for state 0.
90 //
91 // Each Encode() has up to two places where the interval changes:
92 // #1. Refine the interval. [base, base + size) -> [base + a, base + b + 1).
93 // #2. Expand interval if the new size is too small,
94 // and each change may cause a state transition.
95 //
96 // First, consider when the current state is 0.
97 //
98 // In this case, the next state after #1 is always state 0, since refining
99 // interval only shrinks the interval, therefore new_base + new_size <= 2^32.
100 //
101 // Let us explain #2.
102 //
103 // Recall that at the beginning of each Encode(), the encoder requires
104 // 2^16 < size <= 2^32. As precision <= 16, the new interval size can be as
105 // small as 1, but never zero.
106 //
107 // To keep size above 2^16, if new size is smaller than or equal to 2^16, the
108 // encoder would left-shift base and size by 16 bits: size' <- size * 2^16.
109 // Note that new size' is now in the range [2^16, 2^32].
110 //
111 // Since size is left-shifted, the same should be applied to base as well.
112 // However, after the left-shift, base will then contain 48 bits instead of 32
113 // bits. Therefore prior to the shift, The upper 16 bits in base should be
114 // stored somewhere else.
115 //
116 // If the upper 16 bits of all values in the interval were the same, i.e., if
117 // base[32:16] == (base + size - 1)[32:16], then base[32:16] can be written
118 // out to `output` string, since any further Encode() only narrows down the
119 // interval and that 16 bits would never change.
120 //
121 // If the upper 16 bits were not all the same, since this happens only when
122 // size <= 2^16, the upper 16 bits may differ only by one, i.e.,
123 // base[32:16] + 1 == (base + size - 1)[32:16]. At this stage, it is not
124 // determined yet whether base[32:16] should be written to the output or
125 // (base[32:16] + 1) should be written to the output. In this case,
126 // (base[32:16] + 1) is temporarily stored in `delay`, and base is
127 // left-shifted by 16 bits.
128 //
129 // In the latter case, the condition implies that (base // 2^16) and
130 // ((base + size - 1) // 2^16) were different. Therefore after left-shift by
131 // 16 bits, the new (base + size) is greater than 2^32, i.e., the encoder
132 // transition to state 1.
133 //
134 // ==== Summary ====
135 // To detect the current encoder state,
136 // state 0: delay == 0 iff (base mod 2^32) < (base + size) mod 2^32,
137 // state 1: delay != 0 iff (base + size) mod 2^32 <= base mod 2^32,
138 // because size <= 2^32.
139 //
140 // ==== Summary for state 0 ====
141 // 1. Interval refinement does not cause state transition.
142 // 2. Interval expansion may cause state transition, depending on the upper 16
143 // bits of base and base + size - 1.
144 //
145 // Now suppose the previous state was 1. This means that
146 // base <= 2^32 < base + size.
147 //
148 // When in state 1, an interval refinement may trigger state transition.
149 // After Encode() refines the interval, there are three possibilities:
150 // #1. base <= 2^32 < base + size (unchanged),
151 // #2. 2^32 <= base < base + size (base overflowed),
152 // #3. base < base + size <= 2^32 (base + size - 1 underflowed).
153 //
154 // In case #1, the encoder remains in state 1.
155 // In case #2 or #3, the encoder state changes to state 0.
156 //
157 // ==== State transition for interval refinement ====
158 // 1. state 0 -> state 0,
159 // 2. state 1 -> state 0 or state 1.
160 //
161 // Therefore if the new state is 1, then the previous state must have been
162 // state 1.
163 if (base_ + size_minus1_ < base_) {
164 // If statement checked if 2^32 < base + size. The new state is 1, hence the
165 // previous state was also state 1.
166 DCHECK_NE(((base_ - a) + size) >> 32, 0);
167 DCHECK_NE(delay_ & 0xFFFF, 0);
168
169 // Like in state 0, if the new size is <= 2^16, then base and size should
170 // be left-shifted by 16 bits. Combine the conditions
171 // base <= 2^32 < base + size and size <= 2^16 to conclude that
172 // base[32:16] >= 0xFFFF and (base + size - 1)[32:16] = 0x0000.
173 //
174 // Note that 2^32 - base < size, and since base is at least 0xFFFF0000,
175 // 2^16 - base[16:0] < size. Let base' and size' be the new base and size
176 // after the bit-shift. Then 2^32 - base' < size' => 2^32 < base' + size'.
177 // Therefore the encoder remains in state 1.
178 //
179 // Lastly, `delay` is modified. Conceptually, delay has to be changed to
180 // delay' <- delay * 2^16 + (base + size - 1)[32:16].
181 // Since we know above that (base + size - 1)[32:16] = 0x0000, there is no
182 // need to explicitly do the computation above, but rather store how many
183 // trailing zeros there were. For this reason, the lower 16 bits of
184 // `delay` stores the delayed value when state changed from 0 to 1, and
185 // delay[32:16] stores the # of trailing zeros (in bytes).
186 //
187 // ==== State transition for interval expansion ====
188 // 1. state 0 -> state 0 or state 1,
189 // 2. state 1 -> state 1.
190 if (size_minus1_ >> 16 == 0) {
191 DCHECK_EQ(base_ >> 16, 0xFFFF);
192 base_ <<= 16;
193 size_minus1_ <<= 16;
194 size_minus1_ |= 0xFFFF;
195 // TODO(sjhwang): It is possible that for very long input, delay
196 // overflow during below. If overflow is detected, this delay is too
197 // long the encoder should forcefully move to state 0. In such case,
198 // base can be raised to 2^32 (force case #2), or (base + size) can be
199 // lowered to 2^32 (force case #3), depending on which transition
200 // keeps size larger.
201 CHECK_LT(delay_, static_cast<uint64>(1) << 62);
202 delay_ += 0x20000; // Two more bytes of zeros. Check overflow?
203 }
204 return;
205 }
206
207 // If reached here, the current state is 0.
208 // First handle the case when the previous state was state 1.
209 if (delay_ != 0) {
210 // In case #2 or #3, the encoder state changes to state 0. Recall that when
211 // the encoder state changed from state 0 to state 1, the top 16 bits of
212 // (base + size - 1) was temporarily stored in `delay`, because the output
213 // could be either (delay - 1) or (delay).
214 //
215 // And from above, the delayed value encoded in `delay` is
216 // delay' <- delay[16:0] * 2^(8 * delay[MAX:16])
217 //
218 // In case #2, the interval moved below 2^32. So (delay' - 1) is the
219 // converged value after interval refinements. Write out
220 // (delay[16:0] - 1) and write (8 * delay[MAX:16]) bytes of 0xFF.
221 //
222 // In case #3, the interval moved above 2^32. So delay' is the converged
223 // value after interval refinement. Write out delay[16:0] and write
224 // (8 * delay[MAX:16]) bytes of 0x00.
225 if (base_overflow) {
226 // Case #2.
227 DCHECK_NE((static_cast<uint64>(base_ - a) + a) >> 32, 0);
228 sink->push_back(static_cast<char>(delay_ >> 8));
229 sink->push_back(static_cast<char>(delay_ >> 0));
230 sink->append(delay_ >> 16, static_cast<char>(0));
231 } else {
232 // Case #3.
233 DCHECK_EQ(static_cast<uint64>(base_ + size_minus1_) >> 32, 0);
234 --delay_;
235 sink->push_back(static_cast<char>(delay_ >> 8));
236 sink->push_back(static_cast<char>(delay_ >> 0));
237 sink->append(delay_ >> 16, static_cast<char>(0xFF));
238 }
239 // Reset to state 0.
240 delay_ = 0;
241 }
242
243 if (size_minus1_ >> 16 == 0) {
244 const uint32 top = base_ >> 16;
245
246 base_ <<= 16;
247 size_minus1_ <<= 16;
248 size_minus1_ |= 0xFFFF;
249
250 if (base_ <= base_ + size_minus1_) {
251 // Still in state 0. Write the top 16 bits.
252 sink->push_back(static_cast<char>(top >> 8));
253 sink->push_back(static_cast<char>(top));
254 } else {
255 // New state is 1.
256 DCHECK_LT(top, 0xFFFF);
257 delay_ = top + 1;
258 }
259 }
260 }
261
Finalize(string * sink)262 void RangeEncoder::Finalize(string* sink) {
263 // Finalize the encode by writing out any number in the interval
264 // [base, base + size).
265 //
266 // Trailing zeros are not explicitly written out as decoder can fill in zeros
267 // by default.
268 if (delay_ != 0) {
269 // The last state was state 1. Since base < 2^32 < base + size, pick 2^32
270 // (state 1, case #3).
271 // NOTE: It is a bit difficult to trigger this code path on purpose.
272 // TODO(sjhwang): Find a way to trigger this code path for test coverage.
273 sink->push_back(static_cast<char>(delay_ >> 8));
274 if ((delay_ & 0xFF) != 0) {
275 sink->push_back(static_cast<char>(delay_));
276 }
277 } else if (base_ != 0) {
278 // If base == 0, then pick 0 from [base, base + size) and no zeros are
279 // explicitly written.
280 //
281 // Otherwise, pick (base + (2^16 - base[16:0])), i.e., round up base to the
282 // next multiple of 2^16. As 2^16 < size, this value should be in the
283 // interval [base, base + size).
284 const uint32 mid = ((base_ - 1) >> 16) + 1;
285 DCHECK_EQ(mid & 0xFFFF, mid);
286 sink->push_back(static_cast<char>(mid >> 8));
287 if ((mid & 0xFF) != 0) {
288 sink->push_back(static_cast<char>(mid >> 0));
289 }
290 }
291
292 base_ = 0;
293 size_minus1_ = std::numeric_limits<uint32>::max();
294 delay_ = 0;
295 }
296
RangeDecoder(const string & source,int precision)297 RangeDecoder::RangeDecoder(const string& source, int precision)
298 : current_(source.begin()),
299 begin_(source.begin()),
300 end_(source.end()),
301 precision_(precision) {
302 CHECK_LE(precision, 16);
303
304 Read16BitValue();
305 Read16BitValue();
306 }
307
Decode(tensorflow::gtl::ArraySlice<int32> cdf)308 int32 RangeDecoder::Decode(tensorflow::gtl::ArraySlice<int32> cdf) {
309 const uint64 size = static_cast<uint64>(size_minus1_) + 1;
310 const uint64 offset =
311 ((static_cast<uint64>(value_ - base_) + 1) << precision_) - 1;
312
313 // This is similar to std::lower_range() with std::less_equal as comparison.
314 // After the binary search, `pv` points to the smallest number v that
315 // satisfies offset < (size * v) / 2^precision.
316
317 // Assumes that cdf[0] == 0. Therefore (size * cdf[0]) / 2^precision is always
318 // less than or equal to offset.
319 const int32* pv = cdf.data() + 1;
320 // `len` can be cdf.size() - 2 if there is guarantee that the last element of
321 // cdf is 2^precision.
322 auto len = cdf.size() - 1;
323 DCHECK_GT(len, 0);
324
325 do {
326 const auto half = len / 2;
327 const int32* mid = pv + half;
328 DCHECK_GE(*mid, 0);
329 DCHECK_LE(*mid, 1 << precision_);
330 if (size * static_cast<uint64>(*mid) <= offset) {
331 pv = mid + 1;
332 len -= half + 1;
333 } else {
334 len = half;
335 }
336 } while (len > 0);
337
338 // If (size * v) / 2^precision <= offset for all v in cdf, then pv points to
339 // one after the last element of cdf. That is a decoding error.
340 //
341 // TODO(sjhwang): Consider returning -1 to indicate error. Or start len =
342 // cdf.size() - 2 instead and give up detecting this error.
343 CHECK_LT(pv, cdf.data() + cdf.size());
344
345 const uint32 a = (size * static_cast<uint64>(*(pv - 1))) >> precision_;
346 const uint32 b = ((size * static_cast<uint64>(*pv)) >> precision_) - 1;
347 DCHECK_LE(a, offset >> precision_);
348 DCHECK_LE(offset >> precision_, b);
349
350 base_ += a;
351 size_minus1_ = b - a;
352
353 if (size_minus1_ >> 16 == 0) {
354 base_ <<= 16;
355 size_minus1_ <<= 16;
356 size_minus1_ |= 0xFFFF;
357
358 Read16BitValue();
359 }
360
361 return pv - cdf.data() - 1;
362 }
363
Read16BitValue()364 void RangeDecoder::Read16BitValue() {
365 value_ <<= 8;
366 if (current_ != end_) {
367 value_ |= static_cast<uint8>(*current_++);
368 }
369 value_ <<= 8;
370 if (current_ != end_) {
371 value_ |= static_cast<uint8>(*current_++);
372 }
373 }
374 } // namespace tensorflow
375