/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_
17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_
18
#include <algorithm>
#include <limits>
#include <numeric>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
29
30 namespace xla {
31
32 constexpr int64 kBitsOfByte = 8;
33
34 // Represents the range used for quantization
35 struct QuantizedRange {
36 QuantizedRange() = default;
QuantizedRangeQuantizedRange37 QuantizedRange(float min_in, float max_in) : min(min_in), max(max_in) {}
38
39 bool operator==(const QuantizedRange& rhs) const {
40 return this->min == rhs.min && this->max == rhs.max;
41 }
42
43 bool operator!=(const QuantizedRange& rhs) const { return !(*this == rhs); }
44
45 tensorflow::bfloat16 min = tensorflow::bfloat16(0.0f);
46 tensorflow::bfloat16 max = tensorflow::bfloat16(0.0f);
47 };
48
49 template <typename T>
PackToUint32(absl::Span<const T> input)50 inline std::vector<uint32> PackToUint32(absl::Span<const T> input) {
51 const int64 kElementsPerPack = sizeof(uint32) / sizeof(T);
52 const int64 input_size = input.size();
53 const int64 output_size = CeilOfRatio(input_size, kElementsPerPack);
54
55 std::vector<uint32> output_vec;
56 constexpr int64 kShiftBits = sizeof(T) / sizeof(uint8) * kBitsOfByte;
57
58 for (int64 i = 0; i < output_size; i++) {
59 uint32 result = 0;
60 for (int64 p = 0; p < kElementsPerPack; p++) {
61 int64 index = i * kElementsPerPack + p;
62 if (index < input_size) {
63 int64 total_shift_bits = kShiftBits * (kElementsPerPack - p - 1);
64 result |= (input[index] << total_shift_bits);
65 }
66 }
67 output_vec.push_back(result);
68 }
69
70 return output_vec;
71 }
72
73 // Dequantize the quantized input of packed uint32 to bfloat16.
74 // Only uint8 or uint16 is supported for the original unpacked input.
75 // Returns a tensor of shape [d0,..., dn * unpack_size] if
76 // input shape is [d0, ..., dn], where unpack_size = sizeof(unit32) / sizeof(T).
77 // If transpose_output is true, will return a tensor of shape
78 // [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster when
79 // input's rank higher than 1. The input needs to be transposed to use
80 // transpose_output feature.
81 template <typename T>
82 inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range,
83 absl::string_view mode_string = "MIN_COMBINED",
84 bool transpose_output = false) {
85 XlaBuilder* const builder = input.builder();
86 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
87 float half_range =
88 !std::is_signed<T>::value
89 ? 0.0f
90 : (static_cast<float>(std::numeric_limits<T>::max()) -
91 std::numeric_limits<T>::min() + 1) /
92 2.0f;
93 const int64 unpack_size = sizeof(uint32) / sizeof(T);
94 TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(input));
95
96 auto element_type = shape.element_type();
97 if (element_type != U32) {
98 return InvalidArgument(
99 "Only U32 is supported for input type of xla::Dequantize Op.");
100 }
101
102 // Broadcast the input to [unpack_size, d0, ..., dn] if input size is
103 // [d0, ..., dn].
104 auto broadcast_input = Broadcast(input, {unpack_size});
105
106 XlaOp iota_r1 = Iota(builder, U32, unpack_size);
107 // Highest significant bytes needs to shift more bytes than lower
108 // significant bytes.
109 XlaOp shift_bytes =
110 xla::ConstantR0<uint32>(builder, unpack_size - 1) - iota_r1;
111
112 const int bytes_of_type = sizeof(T) / sizeof(uint8);
113 std::vector<uint32> shift_vec(unpack_size, kBitsOfByte * bytes_of_type);
114 XlaOp shift_bits =
115 shift_bytes * xla::ConstantR1<uint32>(builder, shift_vec);
116
117 // Make bit_mask for different data type T.
118 uint32 bit_mask = 0x00000000;
119 for (int i = 0; i < bytes_of_type; i++) {
120 bit_mask <<= kBitsOfByte;
121 bit_mask |= 0x000000ff;
122 }
123
124 std::vector<int64> shift_transpose_dimensions(shape.dimensions_size());
125 std::iota(shift_transpose_dimensions.begin(),
126 shift_transpose_dimensions.end(), 0);
127 shift_transpose_dimensions.insert(shift_transpose_dimensions.begin(), 1,
128 shape.dimensions_size());
129
130 // Shift the input by sizeof(T) bytes and apply bit_mask to unpack.
131 XlaOp shifted_input = ShiftRightLogical(
132 broadcast_input, Transpose(Broadcast(shift_bits, shape.dimensions()),
133 shift_transpose_dimensions));
134 XlaOp unpack_input =
135 And(shifted_input, xla::ConstantR0<uint32>(builder, bit_mask));
136
137 XlaOp result;
138
139 if (mode_string == "MIN_COMBINED") {
140 const tensorflow::bfloat16 scale_factor =
141 (range.max - range.min) /
142 (static_cast<tensorflow::bfloat16>(std::numeric_limits<T>::max() -
143 std::numeric_limits<T>::min()));
144 // result = bfloat16(input + half_range) * scale_factor + range.min
145 XlaOp unpack_input_bf16 = ConvertElementType(unpack_input, BF16);
146 XlaOp half_range_bf16 = xla::ConstantR0<tensorflow::bfloat16>(
147 builder, static_cast<bfloat16>(half_range));
148 XlaOp sum = unpack_input_bf16 + half_range_bf16;
149
150 result =
151 sum * xla::ConstantR0<tensorflow::bfloat16>(builder, scale_factor) +
152 xla::ConstantR0<tensorflow::bfloat16>(builder, range.min);
153 } else {
154 // TODO(wangtao): support other modes.
155 return InvalidArgument(
156 "Only MIN_COMBINED mode is supported in xla::Dequantize Op.");
157 }
158
159 std::vector<int64> transpose_dimensions(shape.dimensions_size());
160 std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 1);
161 std::reverse(transpose_dimensions.begin(), transpose_dimensions.end());
162 transpose_dimensions.insert(transpose_dimensions.begin() + 1, 1, 0);
163
164 // Transpose the result to be [dn, unpack_size, dn-1, ..., d1, d0].
165 XlaOp transposed_result = Transpose(result, transpose_dimensions);
166
167 // Reshape to be [dn * unpack_size, dn-1, ..., d1, d0].
168 XlaOp reshaped_result = Collapse(transposed_result, {0, 1});
169
170 // Return the transpose result if transpose_output is true.
171 if (transpose_output) {
172 return reshaped_result;
173 }
174
175 // Transpose the result to be [d0, d1, ..., dn-1, dn * unpack_size].
176 std::vector<int64> result_dimensions(shape.dimensions_size());
177 std::iota(result_dimensions.begin(), result_dimensions.end(), 0);
178 std::reverse(result_dimensions.begin(), result_dimensions.end());
179
180 return Transpose(reshaped_result, result_dimensions);
181 });
182 }
183
184 } // namespace xla
185
186 #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_
187