• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_DATA_RANDOM_SEED_OPS_H_
17 #define TENSORFLOW_CORE_KERNELS_DATA_RANDOM_SEED_OPS_H_
18 
19 #include "tensorflow/core/framework/resource_mgr.h"
20 #include "tensorflow/core/kernels/data/dataset_utils.h"
21 #include "tensorflow/core/lib/random/philox_random.h"
22 #include "tensorflow/core/lib/random/random.h"
23 #include "tensorflow/core/lib/random/random_distributions.h"
24 
25 namespace tensorflow {
26 namespace data {
27 
28 // Represents a pair of random seeds. By TensorFlow convention, if both seeds
29 // are 0, then pseudo-random values are used instead.
30 class RandomSeeds {
31  public:
RandomSeeds(int64 seed,int64 seed2)32   RandomSeeds(int64 seed, int64 seed2)
33       : input_seed_(seed),
34         input_seed2_(seed2),
35         seed_((seed | seed2) == 0 ? random::New64() : seed),
36         seed2_((seed | seed2) == 0 ? random::New64() : seed2) {}
37 
input_seed()38   int64 input_seed() const { return input_seed_; }
input_seed2()39   int64 input_seed2() const { return input_seed2_; }
seed()40   int64 seed() const { return seed_; }
seed2()41   int64 seed2() const { return seed2_; }
42 
43  private:
44   const int64 input_seed_;
45   const int64 input_seed2_;
46   const int64 seed_;
47   const int64 seed2_;
48 };
49 
50 // Base class for seed generator resources. Subclasses customize how seeds are
51 // generated.
52 class SeedGenerator {
53  public:
~SeedGenerator()54   virtual ~SeedGenerator() {}
55 
56   virtual int64 seed() const = 0;
57   virtual int64 seed2() const = 0;
58   virtual bool reshuffle_each_iteration() const = 0;
59 
60   virtual void GenerateSeeds(int64* seed1, int64* seed2) = 0;
61   virtual void Reset() = 0;
62 
num_random_samples()63   virtual int64 num_random_samples() const {
64     tf_shared_lock l(mu_);
65     return num_random_samples_;
66   }
set_num_random_samples(int64 num_random_samples)67   virtual void set_num_random_samples(int64 num_random_samples) {
68     mutex_lock l(mu_);
69     num_random_samples_ = num_random_samples;
70   }
71 
72  protected:
73   mutable mutex mu_;
74   int64 num_random_samples_ TF_GUARDED_BY(mu_) = 0;
75 };
76 
77 // A resource wrapping a shared instance of a seed generator.
78 class SeedGeneratorManager : public ResourceBase {
79  public:
SeedGeneratorManager(SeedGenerator * seed_generator)80   explicit SeedGeneratorManager(SeedGenerator* seed_generator)
81       : seed_generator_(seed_generator) {}
82 
83   std::string DebugString() const override;
84 
get()85   std::shared_ptr<SeedGenerator> get() { return seed_generator_; }
86 
87  private:
88   std::shared_ptr<SeedGenerator> seed_generator_;
89 };
90 
91 // Always generates the specified seed values.
92 class FixedSeedGenerator : public SeedGenerator {
93  public:
FixedSeedGenerator(RandomSeeds seeds)94   explicit FixedSeedGenerator(RandomSeeds seeds) : seeds_(std::move(seeds)) {}
95 
seed()96   int64 seed() const override { return seeds_.seed(); }
seed2()97   int64 seed2() const override { return seeds_.seed(); }
reshuffle_each_iteration()98   bool reshuffle_each_iteration() const override { return false; }
99 
100   void GenerateSeeds(int64* seed1, int64* seed2) override;
Reset()101   void Reset() override {}
102 
103  private:
104   const RandomSeeds seeds_;
105 };
106 
107 // Generates different (but deterministically chosen) seed values.
108 class RandomSeedGenerator : public SeedGenerator {
109  public:
RandomSeedGenerator(RandomSeeds seeds)110   explicit RandomSeedGenerator(RandomSeeds seeds)
111       : seeds_(std::move(seeds)),
112         parent_generator_(seeds_.seed(), seeds_.seed2()),
113         generator_(&parent_generator_) {}
114 
seed()115   int64 seed() const override { return seeds_.seed(); }
seed2()116   int64 seed2() const override { return seeds_.seed2(); }
reshuffle_each_iteration()117   bool reshuffle_each_iteration() const override { return true; }
118 
119   void GenerateSeeds(int64* seed1, int64* seed2) override;
120   void Reset() override;
121 
122  private:
123   const RandomSeeds seeds_;
124   random::PhiloxRandom parent_generator_ TF_GUARDED_BY(mu_);
125   random::SingleSampleAdapter<random::PhiloxRandom> generator_
126       TF_GUARDED_BY(mu_);
127 };
128 
129 // Creates an instance of seed generator resource and transfers ownership
130 // to the caller.
131 class AnonymousSeedGeneratorHandleOp
132     : public AnonymousResourceOp<SeedGeneratorManager> {
133  public:
134   explicit AnonymousSeedGeneratorHandleOp(OpKernelConstruction* ctx);
135   void Compute(OpKernelContext* ctx) override;
136 
137  private:
138   string name() override;
139   Status CreateResource(OpKernelContext* ctx,
140                         std::unique_ptr<FunctionLibraryDefinition> flib_def,
141                         std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
142                         FunctionLibraryRuntime* lib,
143                         SeedGeneratorManager** manager) override;
144 
145   mutex mu_;
146   std::unique_ptr<RandomSeeds> seeds_ TF_GUARDED_BY(mu_);
147   bool reshuffle_;
148 };
149 
150 // Deletes an instance of seed generator resource.
151 class DeleteSeedGeneratorOp : public OpKernel {
152  public:
DeleteSeedGeneratorOp(OpKernelConstruction * ctx)153   explicit DeleteSeedGeneratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
154 
155   void Compute(OpKernelContext* ctx) override;
156 };
157 
158 }  // namespace data
159 }  // namespace tensorflow
160 
161 #endif  // TENSORFLOW_CORE_KERNELS_DATA_RANDOM_SEED_OPS_H_
162