• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Range coder implementation, based on [1].
17 //
18 // [1] G. N. N. Martin, "Range coding: an algorithm for removing redundancy from
19 // a digitised message", presented to the Video & Data Recording Conference,
20 // held in Southampton, July 24-27, 1979.
21 //
22 #include "tensorflow/contrib/coder/kernels/range_coder.h"
23 
24 #include <limits>
25 #include <string>
26 
27 #include "tensorflow/core/lib/gtl/array_slice.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/platform/types.h"
30 
31 namespace tensorflow {
RangeEncoder(int precision)32 RangeEncoder::RangeEncoder(int precision) : precision_(precision) {
33   CHECK_GT(precision, 0);
34   CHECK_LE(precision, 16);
35 }
36 
Encode(int32 lower,int32 upper,string * sink)37 void RangeEncoder::Encode(int32 lower, int32 upper, string* sink) {
38   // Input requirement: 0 <= lower < upper <= 2^precision.
39   DCHECK_LE(0, lower);
40   DCHECK_LT(lower, upper);
41   DCHECK_LE(upper, 1 << precision_);
42 
43   // `base` and `size` represent a half-open interval [base, base + size).
44   // Loop invariant: 2^16 <= size <= 2^32.
45   //
46   // Note that keeping size above 2^16 is important. Since the interval sizes
47   // are quantized to up to 16 bits, the smallest interval size the encode may
48   // handle is 2^-16. If size is smaller than 2^16, a small interval input may
49   // collapse the encoder range into an empty interval.
50   const uint64 size = static_cast<uint64>(size_minus1_) + 1;
51   DCHECK_NE(size >> 16, 0);
52 
53   // For short notation, let u := lower and v := upper.
54   //
55   // The input u, v represents a half-open interval [u, v) / 2^precision.
56   // This narrows the current interval roughly to
57   // [base + (size * u) / 2^precision, base + (size * v) / 2^precision).
58   //
59   // TODO(sjhwang): Try rounding if it helps improve compression ratio, at the
60   // expense of more operations. In the test using Zipf distribution, the
61   // overhead over the theoretical compression ratio was ~0.01%.
62   // NOTE: The max value of `size` is 2^32 and size > 0. Therefore `size * u`
63   // can be rewritten as `(size - 1) * u + u` and all the computation can be
64   // done in 32-bit mode. If 32-bit multiply is faster, then rewrite.
65   const uint32 a = (size * static_cast<uint64>(lower)) >> precision_;
66   const uint32 b = ((size * static_cast<uint64>(upper)) >> precision_) - 1;
67   DCHECK_LE(a, b);
68 
69   // Let's confirm the RHS of a, b fit in uint32 type.
70   // Recall that 0 <= u < 2^precision, and size <= 2^32. Therefore
71   //   (size * u) / 2^precision < size <= 2^32,
72   // and the value of a fits in uint32 type. Similarly, since v <= 2^precision,
73   //   (size * v) / 2^precision - 1 <= size - 1 < 2^32.
74   // For lower bound of b, note that 1 <= v, 2^16 <= size, and 16 <= precision.
75   // Therefore (size * v) / 2^precision - 1 >= 2^16 / 2^precision - 1 >= 0.
76 
77   // The new interval is [base + a, base + b] = [base + a, base + b + 1).
78   base_ += a;  // May overflow.
79   size_minus1_ = b - a;
80   const bool base_overflow = (base_ < a);
81 
82   // The encoder has two states. Let's call them state 0 and state 1.
83   // State 0 is when base < base + size <= 2^32.
84   // State 1 is when base < 2^32 < base + size.
85   //
86   // The encoder initially starts in state 0, with base = 0, size = 2^32.
87   //
88   // TODO(sjhwang): Requires some profiling, but the encoder stays in state 0
89   // most of the time. Should optimize code for state 0.
90   //
91   // Each Encode() has up to two places where the interval changes:
92   //   #1. Refine the interval. [base, base + size) -> [base + a, base + b + 1).
93   //   #2. Expand interval if the new size is too small,
94   // and each change may cause a state transition.
95   //
96   // First, consider when the current state is 0.
97   //
98   // In this case, the next state after #1 is always state 0, since refining
99   // interval only shrinks the interval, therefore new_base + new_size <= 2^32.
100   //
101   // Let us explain #2.
102   //
103   // Recall that at the beginning of each Encode(), the encoder requires
104   // 2^16 < size <= 2^32. As precision <= 16, the new interval size can be as
105   // small as 1, but never zero.
106   //
107   // To keep size above 2^16, if new size is smaller than or equal to 2^16, the
108   // encoder would left-shift base and size by 16 bits: size' <- size * 2^16.
109   // Note that new size' is now in the range [2^16, 2^32].
110   //
111   // Since size is left-shifted, the same should be applied to base as well.
112   // However, after the left-shift, base will then contain 48 bits instead of 32
113   // bits. Therefore prior to the shift, The upper 16 bits in base should be
114   // stored somewhere else.
115   //
116   // If the upper 16 bits of all values in the interval were the same, i.e., if
117   // base[32:16] == (base + size - 1)[32:16], then base[32:16] can be written
118   // out to `output` string, since any further Encode() only narrows down the
119   // interval and that 16 bits would never change.
120   //
121   // If the upper 16 bits were not all the same, since this happens only when
122   // size <= 2^16, the upper 16 bits may differ only by one, i.e.,
123   // base[32:16] + 1 == (base + size - 1)[32:16]. At this stage, it is not
124   // determined yet whether base[32:16] should be written to the output  or
125   // (base[32:16] + 1) should be written to the output. In this case,
126   // (base[32:16] + 1) is temporarily stored in `delay`, and base is
127   // left-shifted by 16 bits.
128   //
129   // In the latter case, the condition implies that (base // 2^16) and
130   // ((base + size - 1) // 2^16) were different. Therefore after left-shift by
131   // 16 bits, the new (base + size) is greater than 2^32, i.e., the encoder
132   // transition to state 1.
133   //
134   // ==== Summary ====
135   // To detect the current encoder state,
136   //   state 0: delay == 0 iff (base mod 2^32) < (base + size) mod 2^32,
137   //   state 1: delay != 0 iff (base + size) mod 2^32 <= base mod 2^32,
138   // because size <= 2^32.
139   //
140   // ==== Summary for state 0 ====
141   // 1. Interval refinement does not cause state transition.
142   // 2. Interval expansion may cause state transition, depending on the upper 16
143   // bits of base and base + size - 1.
144   //
145   // Now suppose the previous state was 1. This means that
146   // base <= 2^32 < base + size.
147   //
148   // When in state 1, an interval refinement may trigger state transition.
149   // After Encode() refines the interval, there are three possibilities:
150   //   #1. base <= 2^32 < base + size (unchanged),
151   //   #2. 2^32 <= base < base + size (base overflowed),
152   //   #3. base < base + size <= 2^32 (base + size - 1 underflowed).
153   //
154   // In case #1, the encoder remains in state 1.
155   // In case #2 or #3, the encoder state changes to state 0.
156   //
157   // ==== State transition for interval refinement ====
158   // 1. state 0 -> state 0,
159   // 2. state 1 -> state 0 or state 1.
160   //
161   // Therefore if the new state is 1, then the previous state must have been
162   // state 1.
163   if (base_ + size_minus1_ < base_) {
164     // If statement checked if 2^32 < base + size. The new state is 1, hence the
165     // previous state was also state 1.
166     DCHECK_NE(((base_ - a) + size) >> 32, 0);
167     DCHECK_NE(delay_ & 0xFFFF, 0);
168 
169     // Like in state 0, if the new size is <= 2^16, then base and size should
170     // be left-shifted by 16 bits. Combine the conditions
171     // base <= 2^32 < base + size and size <= 2^16 to conclude that
172     // base[32:16] >= 0xFFFF and (base + size - 1)[32:16] = 0x0000.
173     //
174     // Note that 2^32 - base < size, and since base is at least 0xFFFF0000,
175     // 2^16 - base[16:0] < size. Let base' and size' be the new base and size
176     // after the bit-shift. Then 2^32 - base' < size' => 2^32 < base' + size'.
177     // Therefore the encoder remains in state 1.
178     //
179     // Lastly, `delay` is modified. Conceptually, delay has to be changed to
180     //   delay' <- delay * 2^16 + (base + size - 1)[32:16].
181     // Since we know above that (base + size - 1)[32:16] = 0x0000, there is no
182     // need to explicitly do the computation above, but rather store how many
183     // trailing zeros there were. For this reason, the lower 16 bits of
184     // `delay` stores the delayed value when state changed from 0 to 1, and
185     // delay[32:16] stores the # of trailing zeros (in bytes).
186     //
187     // ==== State transition for interval expansion ====
188     // 1. state 0 -> state 0 or state 1,
189     // 2. state 1 -> state 1.
190     if (size_minus1_ >> 16 == 0) {
191       DCHECK_EQ(base_ >> 16, 0xFFFF);
192       base_ <<= 16;
193       size_minus1_ <<= 16;
194       size_minus1_ |= 0xFFFF;
195       // TODO(sjhwang): It is possible that for very long input, delay
196       // overflow during below. If overflow is detected, this delay is too
197       // long the encoder should forcefully move to state 0. In such case,
198       // base can be raised to 2^32 (force case #2), or (base + size) can be
199       // lowered to 2^32 (force case #3), depending on which transition
200       // keeps size larger.
201       CHECK_LT(delay_, static_cast<uint64>(1) << 62);
202       delay_ += 0x20000;  // Two more bytes of zeros. Check overflow?
203     }
204     return;
205   }
206 
207   // If reached here, the current state is 0.
208   // First handle the case when the previous state was state 1.
209   if (delay_ != 0) {
210     // In case #2 or #3, the encoder state changes to state 0. Recall that when
211     // the encoder state changed from state 0 to state 1, the top 16 bits of
212     // (base + size - 1) was temporarily stored in `delay`, because the output
213     // could be either (delay - 1) or (delay).
214     //
215     // And from above, the delayed value encoded in `delay` is
216     //   delay' <- delay[16:0] * 2^(8 * delay[MAX:16])
217     //
218     // In case #2, the interval moved below 2^32. So (delay' - 1) is the
219     // converged value after interval refinements. Write out
220     // (delay[16:0] - 1) and write (8 * delay[MAX:16]) bytes of 0xFF.
221     //
222     // In case #3, the interval moved above 2^32. So delay' is the converged
223     // value after interval refinement. Write out delay[16:0] and write
224     // (8 * delay[MAX:16]) bytes of 0x00.
225     if (base_overflow) {
226       // Case #2.
227       DCHECK_NE((static_cast<uint64>(base_ - a) + a) >> 32, 0);
228       sink->push_back(static_cast<char>(delay_ >> 8));
229       sink->push_back(static_cast<char>(delay_ >> 0));
230       sink->append(delay_ >> 16, static_cast<char>(0));
231     } else {
232       // Case #3.
233       DCHECK_EQ(static_cast<uint64>(base_ + size_minus1_) >> 32, 0);
234       --delay_;
235       sink->push_back(static_cast<char>(delay_ >> 8));
236       sink->push_back(static_cast<char>(delay_ >> 0));
237       sink->append(delay_ >> 16, static_cast<char>(0xFF));
238     }
239     // Reset to state 0.
240     delay_ = 0;
241   }
242 
243   if (size_minus1_ >> 16 == 0) {
244     const uint32 top = base_ >> 16;
245 
246     base_ <<= 16;
247     size_minus1_ <<= 16;
248     size_minus1_ |= 0xFFFF;
249 
250     if (base_ <= base_ + size_minus1_) {
251       // Still in state 0. Write the top 16 bits.
252       sink->push_back(static_cast<char>(top >> 8));
253       sink->push_back(static_cast<char>(top));
254     } else {
255       // New state is 1.
256       DCHECK_LT(top, 0xFFFF);
257       delay_ = top + 1;
258     }
259   }
260 }
261 
Finalize(string * sink)262 void RangeEncoder::Finalize(string* sink) {
263   // Finalize the encode by writing out any number in the interval
264   // [base, base + size).
265   //
266   // Trailing zeros are not explicitly written out as decoder can fill in zeros
267   // by default.
268   if (delay_ != 0) {
269     // The last state was state 1. Since base < 2^32 < base + size, pick 2^32
270     // (state 1, case #3).
271     // NOTE: It is a bit difficult to trigger this code path on purpose.
272     // TODO(sjhwang): Find a way to trigger this code path for test coverage.
273     sink->push_back(static_cast<char>(delay_ >> 8));
274     if ((delay_ & 0xFF) != 0) {
275       sink->push_back(static_cast<char>(delay_));
276     }
277   } else if (base_ != 0) {
278     // If base == 0, then pick 0 from [base, base + size) and no zeros are
279     // explicitly written.
280     //
281     // Otherwise, pick (base + (2^16 - base[16:0])), i.e., round up base to the
282     // next multiple of 2^16. As 2^16 < size, this value should be in the
283     // interval [base, base + size).
284     const uint32 mid = ((base_ - 1) >> 16) + 1;
285     DCHECK_EQ(mid & 0xFFFF, mid);
286     sink->push_back(static_cast<char>(mid >> 8));
287     if ((mid & 0xFF) != 0) {
288       sink->push_back(static_cast<char>(mid >> 0));
289     }
290   }
291 
292   base_ = 0;
293   size_minus1_ = std::numeric_limits<uint32>::max();
294   delay_ = 0;
295 }
296 
RangeDecoder(const string & source,int precision)297 RangeDecoder::RangeDecoder(const string& source, int precision)
298     : current_(source.begin()),
299       begin_(source.begin()),
300       end_(source.end()),
301       precision_(precision) {
302   CHECK_LE(precision, 16);
303 
304   Read16BitValue();
305   Read16BitValue();
306 }
307 
Decode(tensorflow::gtl::ArraySlice<int32> cdf)308 int32 RangeDecoder::Decode(tensorflow::gtl::ArraySlice<int32> cdf) {
309   const uint64 size = static_cast<uint64>(size_minus1_) + 1;
310   const uint64 offset =
311       ((static_cast<uint64>(value_ - base_) + 1) << precision_) - 1;
312 
313   // This is similar to std::lower_range() with std::less_equal as comparison.
314   // After the binary search, `pv` points to the smallest number v that
315   // satisfies offset < (size * v) / 2^precision.
316 
317   // Assumes that cdf[0] == 0. Therefore (size * cdf[0]) / 2^precision is always
318   // less than or equal to offset.
319   const int32* pv = cdf.data() + 1;
320   // `len` can be cdf.size() - 2 if there is guarantee that the last element of
321   // cdf is 2^precision.
322   auto len = cdf.size() - 1;
323   DCHECK_GT(len, 0);
324 
325   do {
326     const auto half = len / 2;
327     const int32* mid = pv + half;
328     DCHECK_GE(*mid, 0);
329     DCHECK_LE(*mid, 1 << precision_);
330     if (size * static_cast<uint64>(*mid) <= offset) {
331       pv = mid + 1;
332       len -= half + 1;
333     } else {
334       len = half;
335     }
336   } while (len > 0);
337 
338   // If (size * v) / 2^precision <= offset for all v in cdf, then pv points to
339   // one after the last element of cdf. That is a decoding error.
340   //
341   // TODO(sjhwang): Consider returning -1 to indicate error. Or start len =
342   // cdf.size() - 2 instead and give up detecting this error.
343   CHECK_LT(pv, cdf.data() + cdf.size());
344 
345   const uint32 a = (size * static_cast<uint64>(*(pv - 1))) >> precision_;
346   const uint32 b = ((size * static_cast<uint64>(*pv)) >> precision_) - 1;
347   DCHECK_LE(a, offset >> precision_);
348   DCHECK_LE(offset >> precision_, b);
349 
350   base_ += a;
351   size_minus1_ = b - a;
352 
353   if (size_minus1_ >> 16 == 0) {
354     base_ <<= 16;
355     size_minus1_ <<= 16;
356     size_minus1_ |= 0xFFFF;
357 
358     Read16BitValue();
359   }
360 
361   return pv - cdf.data() - 1;
362 }
363 
Read16BitValue()364 void RangeDecoder::Read16BitValue() {
365   value_ <<= 8;
366   if (current_ != end_) {
367     value_ |= static_cast<uint8>(*current_++);
368   }
369   value_ <<= 8;
370   if (current_ != end_) {
371     value_ |= static_cast<uint8>(*current_++);
372   }
373 }
374 }  // namespace tensorflow
375