1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/prng.h"
17
18 #include <cmath>
19
20 #include "absl/base/casts.h"
21 #include "tensorflow/compiler/xla/client/lib/constants.h"
22 #include "tensorflow/compiler/xla/client/lib/math.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/util.h"
25
26 namespace xla {
27 namespace {
28
29 // Rotates a 32-bit integer 'v' left by 'distance' bits.
RotateLeftU32(XlaOp v,int distance)30 XlaOp RotateLeftU32(XlaOp v, int distance) {
31 return (v << ConstantR0<uint32>(v.builder(), distance)) |
32 ShiftRightLogical(v, ConstantR0<uint32>(v.builder(), 32 - distance));
33 }
34
35 } // namespace
36
ThreeFry2x32(ThreeFry2x32State input,ThreeFry2x32State key)37 ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
38 XlaBuilder* builder = input[0].builder();
39 key[0] = BitcastConvertType(key[0], U32);
40 key[1] = BitcastConvertType(key[1], U32);
41
42 // Rotation distances specified by the Threefry2x32 algorithm.
43 constexpr std::array<int, 8> rotations = {13, 15, 26, 6, 17, 29, 16, 24};
44 ThreeFry2x32State x;
45
46 std::array<XlaOp, 3> ks;
47 // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
48 ks[2] = ConstantR0<uint32>(builder, 0x1BD11BDA);
49 for (int i = 0; i < 2; ++i) {
50 ks[i] = key[i];
51 x[i] = input[i];
52 ks[2] = ks[2] ^ key[i];
53 }
54
55 x[0] = x[0] + ks[0];
56 x[1] = x[1] + ks[1];
57
58 // Performs a single round of the Threefry2x32 algorithm, with a rotation
59 // amount 'rotation'.
60 auto round = [](ThreeFry2x32State v, int rotation) {
61 v[0] = v[0] + v[1];
62 v[1] = RotateLeftU32(v[1], rotation);
63 v[1] = v[0] ^ v[1];
64 return v;
65 };
66
67 // There are no known statistical flaws with 13 rounds of Threefry2x32.
68 // We are conservative and use 20 rounds.
69 x = round(x, rotations[0]);
70 x = round(x, rotations[1]);
71 x = round(x, rotations[2]);
72 x = round(x, rotations[3]);
73 x[0] = x[0] + ks[1];
74 x[1] = x[1] + ks[2] + ConstantR0<uint32>(builder, 1);
75
76 x = round(x, rotations[4]);
77 x = round(x, rotations[5]);
78 x = round(x, rotations[6]);
79 x = round(x, rotations[7]);
80 x[0] = x[0] + ks[2];
81 x[1] = x[1] + ks[0] + ConstantR0<uint32>(builder, 2);
82
83 x = round(x, rotations[0]);
84 x = round(x, rotations[1]);
85 x = round(x, rotations[2]);
86 x = round(x, rotations[3]);
87 x[0] = x[0] + ks[0];
88 x[1] = x[1] + ks[1] + ConstantR0<uint32>(builder, 3);
89
90 x = round(x, rotations[4]);
91 x = round(x, rotations[5]);
92 x = round(x, rotations[6]);
93 x = round(x, rotations[7]);
94 x[0] = x[0] + ks[1];
95 x[1] = x[1] + ks[2] + ConstantR0<uint32>(builder, 4);
96
97 x = round(x, rotations[0]);
98 x = round(x, rotations[1]);
99 x = round(x, rotations[2]);
100 x = round(x, rotations[3]);
101 x[0] = x[0] + ks[2];
102 x[1] = x[1] + ks[0] + ConstantR0<uint32>(builder, 5);
103
104 return x;
105 }
106
107 // Returns the inputs with unique counter values for ThreeFry2x32.
GetInputs(const int64 size,XlaBuilder * builder)108 ThreeFry2x32State GetInputs(const int64 size, XlaBuilder* builder) {
109 ThreeFry2x32State inputs;
110 inputs[0] = Iota(builder, U32, size);
111 inputs[1] = inputs[0] + ConstantR0<uint32>(builder, size);
112 return inputs;
113 }
114
StatelessRngUniformU32(std::array<XlaOp,2> key,const Shape & shape)115 XlaOp StatelessRngUniformU32(std::array<XlaOp, 2> key, const Shape& shape) {
116 XlaBuilder* builder = key[0].builder();
117 const int64 size = ShapeUtil::ElementsIn(shape);
118 const int64 half_size = CeilOfRatio<int64>(size, 2);
119 const bool size_is_odd = (half_size * 2 != size);
120 ThreeFry2x32State inputs = GetInputs(half_size, builder);
121 ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
122 if (size_is_odd) {
123 outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
124 }
125 auto result = ConcatInDim(builder, outputs, 0);
126 return Reshape(result, AsInt64Slice(shape.dimensions()));
127 }
128
Uint64ToUint32s(XlaOp u64)129 ThreeFry2x32State Uint64ToUint32s(XlaOp u64) {
130 auto builder = u64.builder();
131 auto const32 = ConstantR0WithType(builder, U64, 32);
132 auto fst = ConvertElementType(u64, U32);
133 auto snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
134 return {fst, snd};
135 }
136
Uint32sToUint64(ThreeFry2x32State u32s)137 XlaOp Uint32sToUint64(ThreeFry2x32State u32s) {
138 auto builder = u32s[0].builder();
139 return ConvertElementType(u32s[0], U64) |
140 ShiftLeft(ConvertElementType(u32s[1], U64),
141 ConstantR0WithType(builder, U64, 32));
142 }
143
StatelessRngUniformU64(std::array<XlaOp,2> key,const Shape & shape)144 XlaOp StatelessRngUniformU64(std::array<XlaOp, 2> key, const Shape& shape) {
145 XlaBuilder* builder = key[0].builder();
146 const int64 size = ShapeUtil::ElementsIn(shape);
147 ThreeFry2x32State inputs = GetInputs(size, builder);
148 ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
149 // low 32 bit: outputs[0], high 32 bit: outputs[1]
150 auto result = Uint32sToUint64(outputs);
151 return Reshape(result, AsInt64Slice(shape.dimensions()));
152 }
153
StatelessRngUniformF32(XlaOp bits,XlaOp minval,XlaOp maxval)154 XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
155 XlaBuilder* builder = bits.builder();
156
157 // Form 23 random mantissa bits, with a leading 1 bit. The leading 1 bit
158 // forces the random bits into the mantissa.
159 constexpr int kFloatBits = 32;
160 constexpr int kMantissaBits = 23;
161 bits = ShiftRightLogical(
162 bits, ConstantR0<uint32>(builder, kFloatBits - kMantissaBits)) |
163 ConstantR0<uint32>(builder, absl::bit_cast<uint32>(1.0f));
164 auto floats = BitcastConvertType(bits, F32);
165
166 // We have a floating point number in the range [1.0, 2.0).
167 // Subtract 1.0f to shift to the range [0.0, 1.0)
168 floats = floats - ConstantR0<float>(builder, 1.0f);
169 // Multiply and add to shift to the range [minval, maxval).
170 return floats * (maxval - minval) + minval;
171 }
172
StatelessRngUniformInt(XlaOp bits,XlaOp minval,XlaOp maxval,PrimitiveType type,PrimitiveType unsigned_type)173 XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
174 PrimitiveType type, PrimitiveType unsigned_type) {
175 XlaBuilder* builder = bits.builder();
176 auto range = BitcastConvertType(maxval, unsigned_type) -
177 BitcastConvertType(minval, unsigned_type);
178 auto dist = Rem(bits, range);
179 auto dist_div_2 =
180 ShiftRightLogical(dist, ConstantR0WithType(builder, unsigned_type, 1));
181
182 return minval + BitcastConvertType(dist_div_2, type) +
183 BitcastConvertType(dist - dist_div_2, type);
184 }
185
StatelessRngUniform(std::array<XlaOp,2> seeds,const Shape & shape,XlaOp minval,XlaOp maxval)186 XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
187 XlaOp minval, XlaOp maxval) {
188 XlaBuilder* builder = seeds[0].builder();
189 PrimitiveType type = shape.element_type();
190 switch (type) {
191 case F32: {
192 auto bits = StatelessRngUniformU32(seeds, shape);
193 return StatelessRngUniformF32(bits, minval, maxval);
194 }
195 case S32: {
196 auto bits = StatelessRngUniformU32(seeds, shape);
197 return StatelessRngUniformInt(bits, minval, maxval, type, U32);
198 }
199 case S64: {
200 auto bits = StatelessRngUniformU64(seeds, shape);
201 return StatelessRngUniformInt(bits, minval, maxval, type, U64);
202 }
203 default:
204 return builder->ReportError(Unimplemented(
205 "Types other than F32, S32 and S64 are not implemented by "
206 "StatelessRngUniform."));
207 }
208 }
209
210 } // namespace xla
211