1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_DATA_RANDOM_SEED_OPS_H_ 17 #define TENSORFLOW_CORE_KERNELS_DATA_RANDOM_SEED_OPS_H_ 18 19 #include "tensorflow/core/framework/resource_mgr.h" 20 #include "tensorflow/core/kernels/data/dataset_utils.h" 21 #include "tensorflow/core/lib/random/philox_random.h" 22 #include "tensorflow/core/lib/random/random.h" 23 #include "tensorflow/core/lib/random/random_distributions.h" 24 25 namespace tensorflow { 26 namespace data { 27 28 // Represents a pair of random seeds. By TensorFlow convention, if both seeds 29 // are 0, then pseudo-random values are used instead. 30 class RandomSeeds { 31 public: RandomSeeds(int64 seed,int64 seed2)32 RandomSeeds(int64 seed, int64 seed2) 33 : input_seed_(seed), 34 input_seed2_(seed2), 35 seed_((seed | seed2) == 0 ? random::New64() : seed), 36 seed2_((seed | seed2) == 0 ? random::New64() : seed2) {} 37 input_seed()38 int64 input_seed() const { return input_seed_; } input_seed2()39 int64 input_seed2() const { return input_seed2_; } seed()40 int64 seed() const { return seed_; } seed2()41 int64 seed2() const { return seed2_; } 42 43 private: 44 const int64 input_seed_; 45 const int64 input_seed2_; 46 const int64 seed_; 47 const int64 seed2_; 48 }; 49 50 // Base class for seed generator resources. Subclasses customize how seeds are 51 // generated. 52 class SeedGenerator { 53 public: ~SeedGenerator()54 virtual ~SeedGenerator() {} 55 56 virtual int64 seed() const = 0; 57 virtual int64 seed2() const = 0; 58 virtual bool reshuffle_each_iteration() const = 0; 59 60 virtual void GenerateSeeds(int64* seed1, int64* seed2) = 0; 61 virtual void Reset() = 0; 62 num_random_samples()63 virtual int64 num_random_samples() const { 64 tf_shared_lock l(mu_); 65 return num_random_samples_; 66 } set_num_random_samples(int64 num_random_samples)67 virtual void set_num_random_samples(int64 num_random_samples) { 68 mutex_lock l(mu_); 69 num_random_samples_ = num_random_samples; 70 } 71 72 protected: 73 mutable mutex mu_; 74 int64 num_random_samples_ TF_GUARDED_BY(mu_) = 0; 75 }; 76 77 // A resource wrapping a shared instance of a seed generator. 78 class SeedGeneratorManager : public ResourceBase { 79 public: SeedGeneratorManager(SeedGenerator * seed_generator)80 explicit SeedGeneratorManager(SeedGenerator* seed_generator) 81 : seed_generator_(seed_generator) {} 82 83 std::string DebugString() const override; 84 get()85 std::shared_ptr<SeedGenerator> get() { return seed_generator_; } 86 87 private: 88 std::shared_ptr<SeedGenerator> seed_generator_; 89 }; 90 91 // Always generates the specified seed values. 92 class FixedSeedGenerator : public SeedGenerator { 93 public: FixedSeedGenerator(RandomSeeds seeds)94 explicit FixedSeedGenerator(RandomSeeds seeds) : seeds_(std::move(seeds)) {} 95 seed()96 int64 seed() const override { return seeds_.seed(); } seed2()97 int64 seed2() const override { return seeds_.seed(); } reshuffle_each_iteration()98 bool reshuffle_each_iteration() const override { return false; } 99 100 void GenerateSeeds(int64* seed1, int64* seed2) override; Reset()101 void Reset() override {} 102 103 private: 104 const RandomSeeds seeds_; 105 }; 106 107 // Generates different (but deterministically chosen) seed values. 108 class RandomSeedGenerator : public SeedGenerator { 109 public: RandomSeedGenerator(RandomSeeds seeds)110 explicit RandomSeedGenerator(RandomSeeds seeds) 111 : seeds_(std::move(seeds)), 112 parent_generator_(seeds_.seed(), seeds_.seed2()), 113 generator_(&parent_generator_) {} 114 seed()115 int64 seed() const override { return seeds_.seed(); } seed2()116 int64 seed2() const override { return seeds_.seed2(); } reshuffle_each_iteration()117 bool reshuffle_each_iteration() const override { return true; } 118 119 void GenerateSeeds(int64* seed1, int64* seed2) override; 120 void Reset() override; 121 122 private: 123 const RandomSeeds seeds_; 124 random::PhiloxRandom parent_generator_ TF_GUARDED_BY(mu_); 125 random::SingleSampleAdapter<random::PhiloxRandom> generator_ 126 TF_GUARDED_BY(mu_); 127 }; 128 129 // Creates an instance of seed generator resource and transfers ownership 130 // to the caller. 131 class AnonymousSeedGeneratorHandleOp 132 : public AnonymousResourceOp<SeedGeneratorManager> { 133 public: 134 explicit AnonymousSeedGeneratorHandleOp(OpKernelConstruction* ctx); 135 void Compute(OpKernelContext* ctx) override; 136 137 private: 138 string name() override; 139 Status CreateResource(OpKernelContext* ctx, 140 std::unique_ptr<FunctionLibraryDefinition> flib_def, 141 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr, 142 FunctionLibraryRuntime* lib, 143 SeedGeneratorManager** manager) override; 144 145 mutex mu_; 146 std::unique_ptr<RandomSeeds> seeds_ TF_GUARDED_BY(mu_); 147 bool reshuffle_; 148 }; 149 150 // Deletes an instance of seed generator resource. 151 class DeleteSeedGeneratorOp : public OpKernel { 152 public: DeleteSeedGeneratorOp(OpKernelConstruction * ctx)153 explicit DeleteSeedGeneratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 154 155 void Compute(OpKernelContext* ctx) override; 156 }; 157 158 } // namespace data 159 } // namespace tensorflow 160 161 #endif // TENSORFLOW_CORE_KERNELS_DATA_RANDOM_SEED_OPS_H_ 162