• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_
17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_
18 
19 #include <limits>
20 #include <numeric>
21 #include <vector>
22 
23 #include "tensorflow/compiler/xla/client/lib/constants.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/compiler/xla/util.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/lib/bfloat16/bfloat16.h"
29 
30 namespace xla {
31 
32 constexpr int64 kBitsOfByte = 8;
33 
34 // Represents the range used for quantization
35 struct QuantizedRange {
36   QuantizedRange() = default;
QuantizedRangeQuantizedRange37   QuantizedRange(float min_in, float max_in) : min(min_in), max(max_in) {}
38 
39   bool operator==(const QuantizedRange& rhs) const {
40     return this->min == rhs.min && this->max == rhs.max;
41   }
42 
43   bool operator!=(const QuantizedRange& rhs) const { return !(*this == rhs); }
44 
45   tensorflow::bfloat16 min = tensorflow::bfloat16(0.0f);
46   tensorflow::bfloat16 max = tensorflow::bfloat16(0.0f);
47 };
48 
49 template <typename T>
PackToUint32(absl::Span<const T> input)50 inline std::vector<uint32> PackToUint32(absl::Span<const T> input) {
51   const int64 kElementsPerPack = sizeof(uint32) / sizeof(T);
52   const int64 input_size = input.size();
53   const int64 output_size = CeilOfRatio(input_size, kElementsPerPack);
54 
55   std::vector<uint32> output_vec;
56   constexpr int64 kShiftBits = sizeof(T) / sizeof(uint8) * kBitsOfByte;
57 
58   for (int64 i = 0; i < output_size; i++) {
59     uint32 result = 0;
60     for (int64 p = 0; p < kElementsPerPack; p++) {
61       int64 index = i * kElementsPerPack + p;
62       if (index < input_size) {
63         int64 total_shift_bits = kShiftBits * (kElementsPerPack - p - 1);
64         result |= (input[index] << total_shift_bits);
65       }
66     }
67     output_vec.push_back(result);
68   }
69 
70   return output_vec;
71 }
72 
73 // Dequantize the quantized input of packed uint32 to bfloat16.
74 // Only uint8 or uint16 is supported for the original unpacked input.
75 // Returns a tensor of shape [d0,..., dn * unpack_size] if
76 // input shape is [d0, ..., dn], where unpack_size = sizeof(unit32) / sizeof(T).
77 // If transpose_output is true, will return a tensor of shape
78 // [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster when
79 // input's rank higher than 1. The input needs to be transposed to use
80 // transpose_output feature.
81 template <typename T>
82 inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range,
83                         absl::string_view mode_string = "MIN_COMBINED",
84                         bool transpose_output = false) {
85   XlaBuilder* const builder = input.builder();
86   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
87     float half_range =
88         !std::is_signed<T>::value
89             ? 0.0f
90             : (static_cast<float>(std::numeric_limits<T>::max()) -
91                std::numeric_limits<T>::min() + 1) /
92                   2.0f;
93     const int64 unpack_size = sizeof(uint32) / sizeof(T);
94     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(input));
95 
96     auto element_type = shape.element_type();
97     if (element_type != U32) {
98       return InvalidArgument(
99           "Only U32 is supported for input type of xla::Dequantize Op.");
100     }
101 
102     // Broadcast the input to [unpack_size, d0, ..., dn] if input size is
103     // [d0, ..., dn].
104     auto broadcast_input = Broadcast(input, {unpack_size});
105 
106     XlaOp iota_r1 = Iota(builder, U32, unpack_size);
107     // Highest significant bytes needs to shift more bytes than lower
108     // significant bytes.
109     XlaOp shift_bytes =
110         xla::ConstantR0<uint32>(builder, unpack_size - 1) - iota_r1;
111 
112     const int bytes_of_type = sizeof(T) / sizeof(uint8);
113     std::vector<uint32> shift_vec(unpack_size, kBitsOfByte * bytes_of_type);
114     XlaOp shift_bits =
115         shift_bytes * xla::ConstantR1<uint32>(builder, shift_vec);
116 
117     // Make bit_mask for different data type T.
118     uint32 bit_mask = 0x00000000;
119     for (int i = 0; i < bytes_of_type; i++) {
120       bit_mask <<= kBitsOfByte;
121       bit_mask |= 0x000000ff;
122     }
123 
124     std::vector<int64> shift_transpose_dimensions(shape.dimensions_size());
125     std::iota(shift_transpose_dimensions.begin(),
126               shift_transpose_dimensions.end(), 0);
127     shift_transpose_dimensions.insert(shift_transpose_dimensions.begin(), 1,
128                                       shape.dimensions_size());
129 
130     // Shift the input by sizeof(T) bytes and apply bit_mask to unpack.
131     XlaOp shifted_input = ShiftRightLogical(
132         broadcast_input, Transpose(Broadcast(shift_bits, shape.dimensions()),
133                                    shift_transpose_dimensions));
134     XlaOp unpack_input =
135         And(shifted_input, xla::ConstantR0<uint32>(builder, bit_mask));
136 
137     XlaOp result;
138 
139     if (mode_string == "MIN_COMBINED") {
140       const tensorflow::bfloat16 scale_factor =
141           (range.max - range.min) /
142           (static_cast<tensorflow::bfloat16>(std::numeric_limits<T>::max() -
143                                              std::numeric_limits<T>::min()));
144       // result = bfloat16(input + half_range) * scale_factor + range.min
145       XlaOp unpack_input_bf16 = ConvertElementType(unpack_input, BF16);
146       XlaOp half_range_bf16 = xla::ConstantR0<tensorflow::bfloat16>(
147           builder, static_cast<bfloat16>(half_range));
148       XlaOp sum = unpack_input_bf16 + half_range_bf16;
149 
150       result =
151           sum * xla::ConstantR0<tensorflow::bfloat16>(builder, scale_factor) +
152           xla::ConstantR0<tensorflow::bfloat16>(builder, range.min);
153     } else {
154       // TODO(wangtao): support other modes.
155       return InvalidArgument(
156           "Only MIN_COMBINED mode is supported in xla::Dequantize Op.");
157     }
158 
159     std::vector<int64> transpose_dimensions(shape.dimensions_size());
160     std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 1);
161     std::reverse(transpose_dimensions.begin(), transpose_dimensions.end());
162     transpose_dimensions.insert(transpose_dimensions.begin() + 1, 1, 0);
163 
164     // Transpose the result to be [dn, unpack_size, dn-1, ..., d1, d0].
165     XlaOp transposed_result = Transpose(result, transpose_dimensions);
166 
167     // Reshape to be [dn * unpack_size, dn-1, ..., d1, d0].
168     XlaOp reshaped_result = Collapse(transposed_result, {0, 1});
169 
170     // Return the transpose result if transpose_output is true.
171     if (transpose_output) {
172       return reshaped_result;
173     }
174 
175     // Transpose the result to be [d0, d1, ..., dn-1, dn * unpack_size].
176     std::vector<int64> result_dimensions(shape.dimensions_size());
177     std::iota(result_dimensions.begin(), result_dimensions.end(), 0);
178     std::reverse(result_dimensions.begin(), result_dimensions.end());
179 
180     return Transpose(reshaped_result, result_dimensions);
181   });
182 }
183 
184 }  // namespace xla
185 
186 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_
187