• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operations that generate constants.
16
17See the [constants guide](https://tensorflow.org/api_guides/python/constant_op).
18"""
19
20# Must be separate from array_ops to avoid a cyclic dependency.
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26from tensorflow.core.framework import attr_value_pb2
27from tensorflow.core.framework import types_pb2
28from tensorflow.python.eager import context
29from tensorflow.python.eager import execute
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import op_callbacks
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.framework import tensor_util
35from tensorflow.python.profiler import trace
36from tensorflow.python.util.tf_export import tf_export
37
38
def _eager_reshape(tensor, shape, ctx):
  """Runs the `Reshape` op eagerly on `tensor`.

  Args:
    tensor: eager Tensor to reshape (must already be an eager Tensor).
    shape: target shape. It is converted through
      `execute.args_to_matching_eager` with the int32/int64 dtypes.
    ctx: eager context the op is executed in.

  Returns:
    The single output of the executed `Reshape` op.
  """
  value_dtype_enum = tensor._datatype_enum()  # pylint: disable=protected-access
  shape_dtype_enum, shape_tensors = execute.args_to_matching_eager(
      [shape], ctx, [dtypes.int32, dtypes.int64], dtypes.int32)
  (shape_tensor,) = shape_tensors
  outputs = execute.execute(
      b"Reshape",
      1,
      inputs=[tensor, shape_tensor],
      attrs=("T", value_dtype_enum, "Tshape", shape_dtype_enum),
      ctx=ctx)
  # Exactly one output is expected; the unpacking fails loudly otherwise.
  (reshaped,) = outputs
  return reshaped
49
50
def _eager_fill(dims, value, ctx):
  """Runs the `Fill` op eagerly.

  Args:
    dims: output dimensions; converted to an int32 eager tensor.
    value: eager Tensor providing the fill value (must already be an eager
      Tensor). Its dtype is used for the op's `T` attr.
    ctx: eager context the op is executed in.

  Returns:
    The single output of the executed `Fill` op.
  """
  value_dtype_enum = value.dtype.as_datatype_enum
  dims_tensor = convert_to_eager_tensor(dims, ctx, dtypes.int32)
  outputs = execute.execute(
      b"Fill",
      1,
      inputs=[dims_tensor, value],
      attrs=("T", value_dtype_enum, "index_type", types_pb2.DT_INT32),
      ctx=ctx)
  (filled,) = outputs
  return filled
60
61
def _eager_identity(tensor, ctx):
  """Runs the `Identity` op eagerly on `tensor`.

  Args:
    tensor: eager Tensor to pass through (must already be an eager Tensor).
    ctx: eager context the op is executed in.

  Returns:
    The single output of the executed `Identity` op.
  """
  outputs = execute.execute(
      b"Identity",
      1,
      inputs=[tensor],
      attrs=("T", tensor.dtype.as_datatype_enum),
      ctx=ctx)
  (identity,) = outputs
  return identity
68
69
def _eager_const(tensor, ctx):
  """Copies a constant to the current device via the `_EagerConst` op.

  Args:
    tensor: tensor whose dtype is used for the op's `T` attr and which is
      passed as the op's only input.
    ctx: eager context the op is executed in.

  Returns:
    The single output of the executed `_EagerConst` op.
  """
  outputs = execute.execute(
      b"_EagerConst",
      1,
      inputs=[tensor],
      attrs=("T", tensor.dtype.as_datatype_enum),
      ctx=ctx)
  (copied,) = outputs
  return copied
76
77
def convert_to_eager_tensor(value, ctx, dtype=None):
  """Builds an `EagerTensor` from `value`.

  A `value` that already is an `EagerTensor` is returned unchanged once its
  dtype has been checked against `dtype`; anything else is wrapped in a new
  `EagerTensor` on `ctx.device_name`.

  Note that this function could return cached copies of created constants for
  performance reasons.

  Args:
    value: value to convert to EagerTensor.
    ctx: value of context.context().
    dtype: optional desired dtype of the converted EagerTensor.

  Returns:
    EagerTensor created from value.

  Raises:
    TypeError: if `value` is already an `EagerTensor` and its dtype differs
      from the requested `dtype`.
  """
  if isinstance(value, ops.EagerTensor):
    dtype_mismatch = dtype is not None and value.dtype != dtype
    if dtype_mismatch:
      raise TypeError("Expected tensor with type %r not %r" % (
          dtype, value.dtype))
    return value

  dtype_enum = dtype
  if dtype_enum is not None:
    try:
      dtype_enum = dtype.as_datatype_enum
    except AttributeError:
      # `dtype` has no `as_datatype_enum` attribute; normalize it through
      # `dtypes.as_dtype` before reading the enum.
      dtype_enum = dtypes.as_dtype(dtype).as_datatype_enum
  ctx.ensure_initialized()
  return ops.EagerTensor(value, ctx.device_name, dtype_enum)
107
108
@tf_export(v1=["constant"])
def constant_v1(
    value, dtype=None, shape=None, name="Const", verify_shape=False):
  """Creates a constant tensor.

  The returned tensor holds values of type `dtype`, as determined by `value`
  and, when given, `shape` (see the examples below).

  `value` may be a single constant or a list of values of type `dtype`. When
  `value` is a list and `shape` is given, the list must not contain more
  elements than `shape` implies. If it contains fewer, the remaining entries
  are filled with the last element of the list.

  `shape` is optional. When present, it gives the dimensions of the resulting
  tensor; otherwise the shape of `value` is used.

  When `dtype` is omitted, the type is inferred from the type of `value`.

  For example:

  ```python
  # Constant 1-D Tensor populated with value list.
  tensor = tf.constant([1, 2, 3, 4, 5, 6, 7]) => [1 2 3 4 5 6 7]

  # Constant 2-D tensor populated with scalar value -1.
  tensor = tf.constant(-1.0, shape=[2, 3]) => [[-1. -1. -1.]
                                               [-1. -1. -1.]]
  ```

  How `tf.constant` differs from `tf.fill`:

  *   `tf.constant` accepts arbitrary constants, whereas `tf.fill` only
      produces uniform tensors from a scalar.
  *   `tf.constant` puts a `Const` node holding the exact value into the
      computation graph at graph construction time, while `tf.fill` creates an
      Op in the graph that is expanded at runtime.
  *   Since `tf.constant` only embeds constant values in the graph, it cannot
      handle dynamic shapes that depend on other runtime Tensors; `tf.fill`
      can.

  Args:
    value:          A constant value (or list) of output type `dtype`.

    dtype:          The type of the elements of the resulting tensor.

    shape:          Optional dimensions of resulting tensor.

    name:           Optional name for the tensor.

    verify_shape:   Boolean that enables verification of a shape of values.

  Returns:
    A Constant Tensor.

  Raises:
    TypeError: if shape is incorrectly specified or unsupported.
  """
  # The v1 endpoint honors `verify_shape` and never allows broadcasting.
  return _constant_impl(
      value=value,
      dtype=dtype,
      shape=shape,
      name=name,
      verify_shape=verify_shape,
      allow_broadcast=False)
172
173
@tf_export("constant", v1=[])
def constant(value, dtype=None, shape=None, name="Const"):
  """Creates a constant tensor from a tensor-like object.

  Note: All eager `tf.Tensor` values are immutable (in contrast to
  `tf.Variable`), so there is nothing especially _constant_ about the value
  returned from `tf.constant`, and this function is not fundamentally
  different from `tf.convert_to_tensor`. The name `tf.constant` comes from the
  `value` being embedded in a `Const` node in the `tf.Graph`. `tf.constant` is
  useful for asserting that the value can be embedded that way.

  When `dtype` is omitted, the type is inferred from the type of `value`.

  >>> # Constant 1-D Tensor from a python list.
  >>> tf.constant([1, 2, 3, 4, 5, 6])
  <tf.Tensor: shape=(6,), dtype=int32,
      numpy=array([1, 2, 3, 4, 5, 6], dtype=int32)>
  >>> # Or a numpy array
  >>> a = np.array([[1, 2, 3], [4, 5, 6]])
  >>> tf.constant(a)
  <tf.Tensor: shape=(2, 3), dtype=int64, numpy=
    array([[1, 2, 3],
           [4, 5, 6]])>

  When `dtype` is given, the resulting tensor values are cast to the requested
  `dtype`.

  >>> tf.constant([1, 2, 3, 4, 5, 6], dtype=tf.float64)
  <tf.Tensor: shape=(6,), dtype=float64,
      numpy=array([1., 2., 3., 4., 5., 6.])>

  When `shape` is given, the `value` is reshaped to match. Scalars are
  expanded to fill the `shape`:

  >>> tf.constant(0, shape=(2, 3))
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
    array([[0, 0, 0],
           [0, 0, 0]], dtype=int32)>
  >>> tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
    array([[1, 2, 3],
           [4, 5, 6]], dtype=int32)>

  Passing an eager Tensor as the `value` has no effect; `tf.constant` even
  transmits gradients:

  >>> v = tf.Variable([0.0])
  >>> with tf.GradientTape() as g:
  ...     loss = tf.constant(v + v)
  >>> g.gradient(loss, v).numpy()
  array([2.], dtype=float32)

  However, because `tf.constant` embeds the value in the `tf.Graph`, this
  fails for symbolic tensors:

  >>> with tf.compat.v1.Graph().as_default():
  ...   i = tf.compat.v1.placeholder(shape=[None, None], dtype=tf.float32)
  ...   t = tf.constant(i)
  Traceback (most recent call last):
  ...
  TypeError: ...

  `tf.constant` creates tensors on the current device. Inputs which are
  already tensors keep their placements unchanged.

  Related Ops:

  * `tf.convert_to_tensor` is similar but:
    * It has no `shape` argument.
    * Symbolic tensors are allowed to pass through.

    >>> with tf.compat.v1.Graph().as_default():
    ...   i = tf.compat.v1.placeholder(shape=[None, None], dtype=tf.float32)
    ...   t = tf.convert_to_tensor(i)

  * `tf.fill`: differs in a few ways:
    *   `tf.constant` supports arbitrary constants, not just uniform scalar
        Tensors like `tf.fill`.
    *   `tf.fill` creates an Op in the graph that is expanded at runtime, so it
        can efficiently represent large tensors.
    *   Since `tf.fill` does not embed the value, it can produce dynamically
        sized outputs.

  Args:
    value: A constant value (or list) of output type `dtype`.
    dtype: The type of the elements of the resulting tensor.
    shape: Optional dimensions of resulting tensor.
    name: Optional name for the tensor.

  Returns:
    A Constant Tensor.

  Raises:
    TypeError: if shape is incorrectly specified or unsupported.
    ValueError: if called on a symbolic tensor.
  """
  # The v2 endpoint never verifies the shape and always allows broadcasting.
  return _constant_impl(
      value=value,
      dtype=dtype,
      shape=shape,
      name=name,
      verify_shape=False,
      allow_broadcast=True)
273
274
def _constant_impl(
    value, dtype, shape, name, verify_shape, allow_broadcast):
  """Implementation of constant.

  When executing eagerly, this delegates to `_constant_eager_impl`, wrapped in
  a "tf.constant" trace if `trace.enabled` is set. Otherwise it adds a `Const`
  op carrying the tensor proto of `value` to the default graph.

  Args:
    value: value to embed.
    dtype: requested dtype, or None.
    shape: requested shape, or None.
    name: name for the created graph op (not used on the eager path).
    verify_shape: forwarded to the eager implementation and to
      `tensor_util.make_tensor_proto`.
    allow_broadcast: forwarded to `tensor_util.make_tensor_proto` (graph path
      only).

  Returns:
    The eager tensor, or the first output of the created `Const` op (possibly
    replaced by the output of the registered op callbacks).
  """
  ctx = context.context()
  if ctx.executing_eagerly():
    if not trace.enabled:
      return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
    with trace.Trace("tf.constant"):
      return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)

  graph = ops.get_default_graph()
  value_attr = attr_value_pb2.AttrValue()
  value_attr.tensor.CopyFrom(
      tensor_util.make_tensor_proto(
          value, dtype=dtype, shape=shape, verify_shape=verify_shape,
          allow_broadcast=allow_broadcast))
  # The op's dtype attr is taken from the proto that was actually built.
  dtype_attr = attr_value_pb2.AttrValue(type=value_attr.tensor.dtype)
  node_attrs = {"value": value_attr, "dtype": dtype_attr}
  const_op = graph._create_op_internal(  # pylint: disable=protected-access
      "Const", [], [dtype_attr.type], attrs=node_attrs, name=name)
  result = const_op.outputs[0]

  # TODO(b/147670703): Once the special-op creation code paths
  # are unified. Remove this callback handling.
  if not op_callbacks.should_invoke_op_callbacks():
    return result
  callback_outputs = op_callbacks.invoke_op_callbacks(
      "Const", tuple(), node_attrs, (result,), op_name=name, graph=graph)
  if callback_outputs is not None:
    (result,) = callback_outputs
  return result
304
305
def _constant_eager_impl(ctx, value, dtype, shape, verify_shape):
  """Creates a constant on the current device.

  `value` is first converted to an eager tensor. If `shape` is given and
  differs from the tensor's shape, the tensor is reshaped when the element
  counts match, or used as a fill value when it has exactly one element.

  Args:
    ctx: eager context.
    value: value to convert.
    dtype: requested dtype, or None.
    shape: requested shape, or None to keep the shape of `value`.
    verify_shape: if true, a shape mismatch raises instead of being fixed up.

  Returns:
    The resulting eager tensor.

  Raises:
    TypeError: if `verify_shape` is set and the shapes differ, or if `value`
      can be neither reshaped nor broadcast to `shape`.
  """
  tensor = convert_to_eager_tensor(value, ctx, dtype)
  if shape is None:
    return tensor

  target_shape = tensor_shape.as_shape(shape)
  if target_shape == tensor.shape:
    return tensor
  if verify_shape:
    raise TypeError("Expected Tensor's shape: %s, got %s." %
                    (tuple(target_shape), tuple(tensor.shape)))

  num_elements = tensor.shape.num_elements()
  # TODO(josh11b): Implement shape -> eager tensor conversion.
  if num_elements == target_shape.num_elements():
    return _eager_reshape(tensor, target_shape.as_list(), ctx)

  if num_elements != 1:
    raise TypeError("Eager execution of tf.constant with unsupported shape "
                    "(value has %d elements, shape is %s with %d elements)." %
                    (num_elements, target_shape,
                     target_shape.num_elements()))

  if tensor.dtype == dtypes.bool:
    # We don't have a Fill kernel for bool dtype on GPU. So we first run
    # Fill on CPU and then copy to GPU if needed.
    with ops.device("/device:CPU:0"):
      cpu_filled = _eager_fill(
          target_shape.as_list(), _eager_identity(tensor, ctx), ctx)
    return _eager_identity(cpu_filled, ctx)
  return _eager_fill(target_shape.as_list(), tensor, ctx)
333
334
def is_constant(tensor_or_op):
  """Returns whether `tensor_or_op` is, or is produced by, a `Const` op.

  Args:
    tensor_or_op: an `ops.Tensor`, whose producing `.op` is inspected, or any
      other object exposing a `type` attribute (treated as the op itself).

  Returns:
    True if the inspected op's `type` equals "Const".
  """
  is_tensor = isinstance(tensor_or_op, ops.Tensor)
  op = tensor_or_op.op if is_tensor else tensor_or_op
  return op.type == "Const"
341
342
def _constant_tensor_conversion_function(v, dtype=None, name=None,
                                         as_ref=False):
  """Tensor conversion function that turns `v` into a tensor via `constant`.

  Args:
    v: value to convert.
    dtype: optional dtype forwarded to `constant`.
    name: optional name forwarded to `constant`.
    as_ref: accepted for signature compatibility; ignored.

  Returns:
    The result of `constant(v, dtype=dtype, name=name)`.
  """
  del as_ref  # Unused.
  return constant(v, dtype=dtype, name=name)
347
348
# Register the constant-based conversion function for Python lists/tuples and,
# in a second registration, for `object`.
# NOTE(review): the trailing 100 / 200 are presumably registration priorities
# -- confirm against `ops.register_tensor_conversion_function`.
ops.register_tensor_conversion_function(
    (list, tuple), _constant_tensor_conversion_function, 100)
ops.register_tensor_conversion_function(
    object, _constant_tensor_conversion_function, 200)
353
354
355def _tensor_shape_tensor_conversion_function(s,
356                                             dtype=None,
357                                             name=None,
358                                             as_ref=False):
359  """Function to convert TensorShape to Tensor."""
360  _ = as_ref
361  if not s.is_fully_defined():
362    raise ValueError(
363        "Cannot convert a partially known TensorShape to a Tensor: %s" % s)
364  s_list = s.as_list()
365  int64_value = 0
366  for dim in s_list:
367    if dim >= 2**31:
368      int64_value = dim
369      break
370
371  if dtype is not None:
372    if dtype not in (dtypes.int32, dtypes.int64):
373      raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype)
374    if dtype == dtypes.int32 and int64_value:
375      raise ValueError("Cannot convert a TensorShape to dtype int32; "
376                       "a dimension is too large (%s)" % int64_value)
377  else:
378    dtype = dtypes.int64 if int64_value else dtypes.int32
379  if name is None:
380    name = "shape_as_tensor"
381  return constant(s_list, dtype=dtype, name=name)
382
383
# Register the TensorShape -> Tensor conversion function.
# NOTE(review): the trailing 100 is presumably the registration priority --
# confirm against `ops.register_tensor_conversion_function`.
ops.register_tensor_conversion_function(
    tensor_shape.TensorShape, _tensor_shape_tensor_conversion_function, 100)
386
387
388def _dimension_tensor_conversion_function(d,
389                                          dtype=None,
390                                          name=None,
391                                          as_ref=False):
392  """Function to convert Dimension to Tensor."""
393  _ = as_ref
394  if d.value is None:
395    raise ValueError("Cannot convert an unknown Dimension to a Tensor: %s" % d)
396  if dtype is not None:
397    if dtype not in (dtypes.int32, dtypes.int64):
398      raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype)
399  else:
400    dtype = dtypes.int32
401  if name is None:
402    name = "shape_as_tensor"
403  return constant(d.value, dtype=dtype, name=name)
404
405
# Register the Dimension -> Tensor conversion function.
# NOTE(review): the trailing 100 is presumably the registration priority --
# confirm against `ops.register_tensor_conversion_function`.
ops.register_tensor_conversion_function(
    tensor_shape.Dimension, _dimension_tensor_conversion_function, 100)
408