• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""For seeding individual ops based on a graph-level seed.
17"""
18
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23from tensorflow.python.eager import context
24from tensorflow.python.framework import ops
25from tensorflow.python.util import deprecation
26from tensorflow.python.util.tf_export import tf_export
27
28
# Graph-level seed that `get_seed` pairs with an explicit op seed when no
# graph-level (or eager global) seed has been set.
DEFAULT_GRAPH_SEED = 87654321
# Largest signed 32-bit integer. Seeds are reduced modulo this value so they
# fit in 32 bits, and `get_seed` uses it to replace a (0, 0) seed pair.
_MAXINT32 = (1 << 31) - 1
31
32
def _truncate_seed(seed):
  """Reduces `seed` modulo `_MAXINT32` so that it fits in a 32-bit integer.

  For a positive modulus Python's remainder is never negative, so the result
  lies in [0, _MAXINT32) and can never equal `_MAXINT32` itself.
  """
  _, remainder = divmod(seed, _MAXINT32)
  return remainder
35
36
@tf_export(v1=['random.get_seed', 'get_seed'])
@deprecation.deprecated_endpoints('get_seed')
def get_seed(op_seed):
  """Returns the local seeds an operation should use given an op-specific seed.

  Given an operation-specific seed, `op_seed`, this helper returns a pair of
  seeds derived from the graph-level seed (the global seed of the eager
  context when executing eagerly) and the op-level seed. Many random
  operations use this pair internally so that users can change the seed
  either globally for a graph or only for specific operations.

  For details on how the graph-level seed interacts with op seeds, see
  `tf.random.set_random_seed`.

  Args:
    op_seed: integer, or None when the operation has no seed of its own.

  Returns:
    A tuple of two integers that should be used for the local seed of this
    operation, or `(None, None)` when neither a graph-level nor an op-level
    seed is set.
  """
  in_eager_mode = context.executing_eagerly()
  graph_level_seed = (context.global_seed() if in_eager_mode
                      else ops.get_default_graph().seed)

  if graph_level_seed is None:
    if op_seed is None:
      # Neither seed is set.
      seeds = (None, None)
    else:
      # Only the op seed is set: pair it with the fixed default graph seed.
      seeds = (DEFAULT_GRAPH_SEED, _truncate_seed(op_seed))
  else:
    if op_seed is None:
      # No op seed was given, so derive one. Also set the default graph's
      # `_seed_used` flag when the graph object defines that attribute.
      graph = ops.get_default_graph()
      # pylint: disable=protected-access
      if hasattr(graph, '_seed_used'):
        graph._seed_used = True
      if in_eager_mode:
        op_seed = context.internal_operation_seed()
      else:
        op_seed = graph._last_id
      # pylint: enable=protected-access
    seeds = (_truncate_seed(graph_level_seed), _truncate_seed(op_seed))

  # Avoid (0, 0) as the C++ ops interpret it as nondeterminism, which would
  # be unexpected since Python docs say nondeterminism is (None, None).
  # `_truncate_seed` never yields `_MAXINT32`, so the replacement pair cannot
  # collide with any other seed pair this function produces.
  return (0, _MAXINT32) if seeds == (0, 0) else seeds
85
86
@tf_export(v1=['random.set_random_seed', 'set_random_seed'])
def set_random_seed(seed):
  """Sets the graph-level random seed.

  Operations that rely on a random seed actually derive it from two seeds:
  the graph-level and operation-level seeds. This sets the graph-level seed
  on the current default graph or, when executing eagerly, the global seed
  of the eager context.

  Its interactions with operation-level seeds are as follows:

    1. If neither the graph-level nor the operation seed is set:
      A random seed is used for this op.
    2. If the graph-level seed is set, but the operation seed is not:
      The system deterministically picks an operation seed in conjunction
      with the graph-level seed so that it gets a unique random sequence.
    3. If the graph-level seed is not set, but the operation seed is set:
      A default graph-level seed and the specified operation seed are used to
      determine the random sequence.
    4. If both the graph-level and the operation seed are set:
      Both seeds are used in conjunction to determine the random sequence.

  To illustrate the user-visible effects, consider these examples:

  To generate different sequences across sessions, set neither
  graph-level nor op-level seeds:

  ```python
  a = tf.random_uniform([1])
  b = tf.random_normal([1])

  print("Session 1")
  with tf.Session() as sess1:
    print(sess1.run(a))  # generates 'A1'
    print(sess1.run(a))  # generates 'A2'
    print(sess1.run(b))  # generates 'B1'
    print(sess1.run(b))  # generates 'B2'

  print("Session 2")
  with tf.Session() as sess2:
    print(sess2.run(a))  # generates 'A3'
    print(sess2.run(a))  # generates 'A4'
    print(sess2.run(b))  # generates 'B3'
    print(sess2.run(b))  # generates 'B4'
  ```

  To generate the same repeatable sequence for an op across sessions, set the
  seed for the op:

  ```python
  a = tf.random_uniform([1], seed=1)
  b = tf.random_normal([1])

  # Repeatedly running this block with the same graph will generate the same
  # sequence of values for 'a', but different sequences of values for 'b'.
  print("Session 1")
  with tf.Session() as sess1:
    print(sess1.run(a))  # generates 'A1'
    print(sess1.run(a))  # generates 'A2'
    print(sess1.run(b))  # generates 'B1'
    print(sess1.run(b))  # generates 'B2'

  print("Session 2")
  with tf.Session() as sess2:
    print(sess2.run(a))  # generates 'A1'
    print(sess2.run(a))  # generates 'A2'
    print(sess2.run(b))  # generates 'B3'
    print(sess2.run(b))  # generates 'B4'
  ```

  To make the random sequences generated by all ops be repeatable across
  sessions, set a graph-level seed:

  ```python
  tf.random.set_random_seed(1234)
  a = tf.random_uniform([1])
  b = tf.random_normal([1])

  # Repeatedly running this block with the same graph will generate the same
  # sequences of 'a' and 'b'.
  print("Session 1")
  with tf.Session() as sess1:
    print(sess1.run(a))  # generates 'A1'
    print(sess1.run(a))  # generates 'A2'
    print(sess1.run(b))  # generates 'B1'
    print(sess1.run(b))  # generates 'B2'

  print("Session 2")
  with tf.Session() as sess2:
    print(sess2.run(a))  # generates 'A1'
    print(sess2.run(a))  # generates 'A2'
    print(sess2.run(b))  # generates 'B1'
    print(sess2.run(b))  # generates 'B2'
  ```

  Args:
    seed: integer.
  """
  if not context.executing_eagerly():
    # Graph mode: the seed lives on the current default graph.
    ops.get_default_graph().seed = seed
    return
  # Eager mode: the seed lives on the eager context.
  context.set_global_seed(seed)
181    seed: integer.
182  """
183  if context.executing_eagerly():
184    context.set_global_seed(seed)
185  else:
186    ops.get_default_graph().seed = seed
187
188
@tf_export('random.set_seed', v1=[])
def set_seed(seed):
  """Sets the global random seed.

  Operations that rely on a random seed actually derive it from two seeds:
  the global and operation-level seeds. This sets the global seed. When
  executing eagerly the seed is stored on the eager context; otherwise it
  becomes the seed of the current default graph.

  Its interactions with operation-level seeds are as follows:

    1. If neither the global seed nor the operation seed is set:
      A random seed is used for this op.
    2. If the global seed is set, but the operation seed is not:
      The system deterministically picks an operation seed in conjunction
      with the global seed so that it gets a unique random sequence.
    3. If the global seed is not set, but the operation seed is set:
      A default global seed and the specified operation seed are used to
      determine the random sequence.
    4. If both the global and the operation seed are set:
      Both seeds are used in conjunction to determine the random sequence.

  To illustrate the user-visible effects, consider these examples.

  If neither the global seed nor the operation seed is set, every call to the
  random op and every re-run of the program produce different results:

  ```python
  print(tf.random.uniform([1]))  # generates 'A1'
  print(tf.random.uniform([1]))  # generates 'A2'
  ```

  (now close the program and run it again)

  ```python
  print(tf.random.uniform([1]))  # generates 'A3'
  print(tf.random.uniform([1]))  # generates 'A4'
  ```

  If the global seed is set but the operation seed is not, successive calls
  to the random op still produce different results, but every re-run of the
  program repeats the same sequence:

  ```python
  tf.random.set_seed(1234)
  print(tf.random.uniform([1]))  # generates 'A1'
  print(tf.random.uniform([1]))  # generates 'A2'
  ```

  (now close the program and run it again)

  ```python
  tf.random.set_seed(1234)
  print(tf.random.uniform([1]))  # generates 'A1'
  print(tf.random.uniform([1]))  # generates 'A2'
  ```

  The second call yields 'A2' rather than 'A1' because, with no operation seed
  given, it is assigned a different operation seed derived from the global
  seed.

  If the operation seed is set, every re-run of the program repeats the same
  sequence:

  ```python
  print(tf.random.uniform([1], seed=1))  # generates 'A1'
  print(tf.random.uniform([1], seed=1))  # generates 'A2'
  ```

  (now close the program and run it again)

  ```python
  print(tf.random.uniform([1], seed=1))  # generates 'A1'
  print(tf.random.uniform([1], seed=1))  # generates 'A2'
  ```

  Args:
    seed: integer.
  """
  # TODO(go/tf2-random): change doc, update to match design doc
  # The TF2 endpoint shares its implementation with the TF1 one.
  set_random_seed(seed)
281
282  Args:
283    seed: integer.
284  """
285  # TODO(go/tf2-random): change doc, update to match design doc
286  set_random_seed(seed)
287