• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/prng.h"
17 
18 #include <cmath>
19 
20 #include "absl/base/casts.h"
21 #include "tensorflow/compiler/xla/client/lib/constants.h"
22 #include "tensorflow/compiler/xla/client/lib/math.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/util.h"
25 
26 namespace xla {
27 namespace {
28 
29 // Rotates a 32-bit integer 'v' left by 'distance' bits.
RotateLeftU32(XlaOp v,int distance)30 XlaOp RotateLeftU32(XlaOp v, int distance) {
31   return (v << ConstantR0<uint32>(v.builder(), distance)) |
32          ShiftRightLogical(v, ConstantR0<uint32>(v.builder(), 32 - distance));
33 }
34 
35 }  // namespace
36 
ThreeFry2x32(ThreeFry2x32State input,ThreeFry2x32State key)37 ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
38   XlaBuilder* builder = input[0].builder();
39   key[0] = BitcastConvertType(key[0], U32);
40   key[1] = BitcastConvertType(key[1], U32);
41 
42   // Rotation distances specified by the Threefry2x32 algorithm.
43   constexpr std::array<int, 8> rotations = {13, 15, 26, 6, 17, 29, 16, 24};
44   ThreeFry2x32State x;
45 
46   std::array<XlaOp, 3> ks;
47   // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
48   ks[2] = ConstantR0<uint32>(builder, 0x1BD11BDA);
49   for (int i = 0; i < 2; ++i) {
50     ks[i] = key[i];
51     x[i] = input[i];
52     ks[2] = ks[2] ^ key[i];
53   }
54 
55   x[0] = x[0] + ks[0];
56   x[1] = x[1] + ks[1];
57 
58   // Performs a single round of the Threefry2x32 algorithm, with a rotation
59   // amount 'rotation'.
60   auto round = [](ThreeFry2x32State v, int rotation) {
61     v[0] = v[0] + v[1];
62     v[1] = RotateLeftU32(v[1], rotation);
63     v[1] = v[0] ^ v[1];
64     return v;
65   };
66 
67   // There are no known statistical flaws with 13 rounds of Threefry2x32.
68   // We are conservative and use 20 rounds.
69   x = round(x, rotations[0]);
70   x = round(x, rotations[1]);
71   x = round(x, rotations[2]);
72   x = round(x, rotations[3]);
73   x[0] = x[0] + ks[1];
74   x[1] = x[1] + ks[2] + ConstantR0<uint32>(builder, 1);
75 
76   x = round(x, rotations[4]);
77   x = round(x, rotations[5]);
78   x = round(x, rotations[6]);
79   x = round(x, rotations[7]);
80   x[0] = x[0] + ks[2];
81   x[1] = x[1] + ks[0] + ConstantR0<uint32>(builder, 2);
82 
83   x = round(x, rotations[0]);
84   x = round(x, rotations[1]);
85   x = round(x, rotations[2]);
86   x = round(x, rotations[3]);
87   x[0] = x[0] + ks[0];
88   x[1] = x[1] + ks[1] + ConstantR0<uint32>(builder, 3);
89 
90   x = round(x, rotations[4]);
91   x = round(x, rotations[5]);
92   x = round(x, rotations[6]);
93   x = round(x, rotations[7]);
94   x[0] = x[0] + ks[1];
95   x[1] = x[1] + ks[2] + ConstantR0<uint32>(builder, 4);
96 
97   x = round(x, rotations[0]);
98   x = round(x, rotations[1]);
99   x = round(x, rotations[2]);
100   x = round(x, rotations[3]);
101   x[0] = x[0] + ks[2];
102   x[1] = x[1] + ks[0] + ConstantR0<uint32>(builder, 5);
103 
104   return x;
105 }
106 
107 // Returns the inputs with unique counter values for ThreeFry2x32.
GetInputs(const int64 size,XlaBuilder * builder)108 ThreeFry2x32State GetInputs(const int64 size, XlaBuilder* builder) {
109   ThreeFry2x32State inputs;
110   inputs[0] = Iota(builder, U32, size);
111   inputs[1] = inputs[0] + ConstantR0<uint32>(builder, size);
112   return inputs;
113 }
114 
StatelessRngUniformU32(std::array<XlaOp,2> key,const Shape & shape)115 XlaOp StatelessRngUniformU32(std::array<XlaOp, 2> key, const Shape& shape) {
116   XlaBuilder* builder = key[0].builder();
117   const int64 size = ShapeUtil::ElementsIn(shape);
118   const int64 half_size = CeilOfRatio<int64>(size, 2);
119   const bool size_is_odd = (half_size * 2 != size);
120   ThreeFry2x32State inputs = GetInputs(half_size, builder);
121   ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
122   if (size_is_odd) {
123     outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
124   }
125   auto result = ConcatInDim(builder, outputs, 0);
126   return Reshape(result, AsInt64Slice(shape.dimensions()));
127 }
128 
Uint64ToUint32s(XlaOp u64)129 ThreeFry2x32State Uint64ToUint32s(XlaOp u64) {
130   auto builder = u64.builder();
131   auto const32 = ConstantR0WithType(builder, U64, 32);
132   auto fst = ConvertElementType(u64, U32);
133   auto snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
134   return {fst, snd};
135 }
136 
Uint32sToUint64(ThreeFry2x32State u32s)137 XlaOp Uint32sToUint64(ThreeFry2x32State u32s) {
138   auto builder = u32s[0].builder();
139   return ConvertElementType(u32s[0], U64) |
140          ShiftLeft(ConvertElementType(u32s[1], U64),
141                    ConstantR0WithType(builder, U64, 32));
142 }
143 
StatelessRngUniformU64(std::array<XlaOp,2> key,const Shape & shape)144 XlaOp StatelessRngUniformU64(std::array<XlaOp, 2> key, const Shape& shape) {
145   XlaBuilder* builder = key[0].builder();
146   const int64 size = ShapeUtil::ElementsIn(shape);
147   ThreeFry2x32State inputs = GetInputs(size, builder);
148   ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
149   // low 32 bit: outputs[0], high 32 bit: outputs[1]
150   auto result = Uint32sToUint64(outputs);
151   return Reshape(result, AsInt64Slice(shape.dimensions()));
152 }
153 
StatelessRngUniformF32(XlaOp bits,XlaOp minval,XlaOp maxval)154 XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
155   XlaBuilder* builder = bits.builder();
156 
157   // Form 23 random mantissa bits, with a leading 1 bit. The leading 1 bit
158   // forces the random bits into the mantissa.
159   constexpr int kFloatBits = 32;
160   constexpr int kMantissaBits = 23;
161   bits = ShiftRightLogical(
162              bits, ConstantR0<uint32>(builder, kFloatBits - kMantissaBits)) |
163          ConstantR0<uint32>(builder, absl::bit_cast<uint32>(1.0f));
164   auto floats = BitcastConvertType(bits, F32);
165 
166   // We have a floating point number in the range [1.0, 2.0).
167   // Subtract 1.0f to shift to the range [0.0, 1.0)
168   floats = floats - ConstantR0<float>(builder, 1.0f);
169   // Multiply and add to shift to the range [minval, maxval).
170   return floats * (maxval - minval) + minval;
171 }
172 
StatelessRngUniformInt(XlaOp bits,XlaOp minval,XlaOp maxval,PrimitiveType type,PrimitiveType unsigned_type)173 XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
174                              PrimitiveType type, PrimitiveType unsigned_type) {
175   XlaBuilder* builder = bits.builder();
176   auto range = BitcastConvertType(maxval, unsigned_type) -
177                BitcastConvertType(minval, unsigned_type);
178   auto dist = Rem(bits, range);
179   auto dist_div_2 =
180       ShiftRightLogical(dist, ConstantR0WithType(builder, unsigned_type, 1));
181 
182   return minval + BitcastConvertType(dist_div_2, type) +
183          BitcastConvertType(dist - dist_div_2, type);
184 }
185 
StatelessRngUniform(std::array<XlaOp,2> seeds,const Shape & shape,XlaOp minval,XlaOp maxval)186 XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
187                           XlaOp minval, XlaOp maxval) {
188   XlaBuilder* builder = seeds[0].builder();
189   PrimitiveType type = shape.element_type();
190   switch (type) {
191     case F32: {
192       auto bits = StatelessRngUniformU32(seeds, shape);
193       return StatelessRngUniformF32(bits, minval, maxval);
194     }
195     case S32: {
196       auto bits = StatelessRngUniformU32(seeds, shape);
197       return StatelessRngUniformInt(bits, minval, maxval, type, U32);
198     }
199     case S64: {
200       auto bits = StatelessRngUniformU64(seeds, shape);
201       return StatelessRngUniformInt(bits, minval, maxval, type, U64);
202     }
203     default:
204       return builder->ReportError(Unimplemented(
205           "Types other than F32, S32 and S64 are not implemented by "
206           "StatelessRngUniform."));
207   }
208 }
209 
210 }  // namespace xla
211