1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ 17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ 18 19 #include <array> 20 21 #include "tensorflow/compiler/xla/client/xla_builder.h" 22 #include "tensorflow/compiler/xla/xla_data.pb.h" 23 24 namespace xla { 25 26 // Records the bits and state generated by a random number generator. 27 struct RngOutput { 28 XlaOp value; 29 XlaOp state; 30 }; 31 32 // A BitGenerator returns random bits and updated random bit generator state. 33 // 34 // key: is a value input to a random number generator that can affect the 35 // sequence of number it will generate. A random number generator constructs 36 // its seed using the key and the initial state. The tf2xla bridge passes the 37 // seed operand of a tensorflow random operation as a key to the random bit 38 // generator, for example. 39 // initial_state: initial_state is the initial state of the current random 40 // number generation. It could be 0 for a stateless random operation, and 41 // the returned state from a previous execution for a stateful random 42 // operation. 43 // shape: the shape of the random bits. 44 using BitGeneratorTy = std::function<RngOutput(XlaOp key, XlaOp initial_state, 45 const xla::Shape& shape)>; 46 47 // Implements the ThreeFry counter-based PRNG algorithm. 48 // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. 49 // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf 50 RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, 51 const xla::Shape& shape); 52 53 // Implements the Philox algorithm to generate random numbers in parallel. 54 // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. 55 // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf 56 // 57 // The paper presents a few variants of the Philox algorithm, we picked the 58 // 4x32_10 version of the algorithm for the following reasons: 59 // . 4x32 uses 32-bit multiplication which is fast on GPUs. 60 // . The authors recommend the 10-round variant, and TensorFlow also uses it. 61 RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, 62 const Shape& shape); 63 // Returns a scrambled pair of (state, key) from a single key. 64 std::pair<XlaOp, XlaOp> ScramblePhiloxKey(XlaOp key); 65 66 // Uses the given bit generator to generate random bits and then converts the 67 // random bits to random numbers of uniform distribution in the given range. 68 // Returns the random numbers and the state of the random number generator. 69 // This function is for shape with floating point element types. 70 RngOutput UniformFloatingPointDistribution(XlaOp key, XlaOp initial_state, 71 BitGeneratorTy bit_generator, 72 XlaOp minval, XlaOp maxval, 73 const xla::Shape& shape); 74 75 // Similar to UniformFloatingPointDistribution but for shape with integer 76 // element types. 77 RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state, 78 BitGeneratorTy bit_generator, XlaOp minval, 79 XlaOp maxval, const xla::Shape& shape); 80 81 // Uses the given bit generator to generate random bits and then converts the 82 // random bits to random numbers of normal distribution. 83 // Returns the random numbers and the state of the random number generator. 84 RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state, 85 BitGeneratorTy bit_generator, 86 const xla::Shape& shape); 87 88 // Concatenates scalars into a vector. 89 xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, 90 absl::Span<const xla::XlaOp> scalars); 91 92 // Increases Philox counter (an uint128) by a delta (an uint64). 93 xla::XlaOp PhiloxIncreaseCounter(xla::XlaOp counter, xla::XlaOp delta); 94 95 } // namespace xla 96 97 #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ 98