• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operations that generate constants.
16
17See the [constants guide](https://tensorflow.org/api_guides/python/constant_op).
18"""
19
20# Must be separate from array_ops to avoid a cyclic dependency.
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import numpy as np
27import six
28
29from tensorflow.core.framework import attr_value_pb2
30from tensorflow.core.framework import types_pb2
31from tensorflow.python.eager import context
32from tensorflow.python.eager import execute
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import tensor_shape
36from tensorflow.python.framework import tensor_util
37from tensorflow.python.util.tf_export import tf_export
38
39
def _eager_reshape(tensor, shape, ctx):
  """Runs the `Reshape` op directly on an eager tensor.

  Args:
    tensor: eager Tensor to reshape; its `_datatype_enum()` supplies the "T"
      attribute of the op.
    shape: requested shape. It is converted with
      `execute.args_to_matching_eager`, which is also handed `dtypes.int32`.
    ctx: eager context used for both the conversion and the op execution.

  Returns:
    The single output produced by the `Reshape` op.
  """
  value_dtype_enum = tensor._datatype_enum()  # pylint: disable=protected-access
  shape_dtype_enum, converted = execute.args_to_matching_eager(
      [shape], ctx, dtypes.int32)
  (shape_tensor,) = converted
  op_attrs = ("T", value_dtype_enum, "Tshape", shape_dtype_enum)
  outputs = execute.execute(
      b"Reshape", 1, inputs=[tensor, shape_tensor], attrs=op_attrs, ctx=ctx)
  # Exactly one output is requested, so unpack the one-element result.
  (reshaped,) = outputs
  return reshaped
50
51
def _eager_fill(dims, value, ctx):
  """Runs the `Fill` op directly on an eager tensor.

  Args:
    dims: output dimensions; converted to an int32 eager tensor.
    value: eager Tensor holding the fill value; its dtype supplies the "T"
      attribute of the op.
    ctx: eager context used for both the conversion and the op execution.

  Returns:
    The single output produced by the `Fill` op.
  """
  value_dtype_enum = value.dtype.as_datatype_enum
  dims_tensor = convert_to_eager_tensor(dims, ctx, dtypes.int32)
  op_attrs = ("T", value_dtype_enum, "index_type", types_pb2.DT_INT32)
  outputs = execute.execute(
      b"Fill", 1, inputs=[dims_tensor, value], attrs=op_attrs, ctx=ctx)
  # Exactly one output is requested, so unpack the one-element result.
  (filled,) = outputs
  return filled
61
62
def _eager_identity(tensor, ctx):
  """Runs the `Identity` op directly on an eager tensor.

  Args:
    tensor: eager Tensor to pass through the op; its dtype supplies the "T"
      attribute.
    ctx: eager context the op is executed in.

  Returns:
    The single output produced by the `Identity` op.
  """
  op_attrs = ("T", tensor.dtype.as_datatype_enum)
  outputs = execute.execute(
      b"Identity", 1, inputs=[tensor], attrs=op_attrs, ctx=ctx)
  # Exactly one output is requested, so unpack the one-element result.
  (identity,) = outputs
  return identity
69
70
def convert_to_eager_tensor(value, ctx, dtype=None):
  """Converts the given `value` to an `EagerTensor`.

  An existing `EagerTensor` is returned unchanged (after a dtype check). Python
  float and integer scalars go through the context's scalar cache, so the
  returned tensor may be backed by a previously created constant.

  Args:
    value: value to convert to EagerTensor.
    ctx: value of context.context().
    dtype: optional desired dtype of the converted EagerTensor. Either an
      object exposing `as_datatype_enum` or anything `dtypes.as_dtype` accepts.

  Returns:
    EagerTensor created from value.

  Raises:
    TypeError: if `value` is already an EagerTensor and its dtype differs from
      the requested `dtype`.
  """
  if isinstance(value, ops.EagerTensor):
    dtype_mismatch = dtype is not None and value.dtype != dtype
    if dtype_mismatch:
      raise TypeError("Expected tensor with type %r not %r" % (
          dtype, value.dtype))
    return value

  # Reduce the requested dtype to its datatype enum (None when unspecified).
  dtype_enum = None
  if dtype is not None:
    try:
      dtype_enum = dtype.as_datatype_enum
    except AttributeError:
      dtype_enum = dtypes.as_dtype(dtype).as_datatype_enum

  device = ctx.device_name
  handle = ctx._handle  # pylint: disable=protected-access

  is_python_scalar = isinstance(value, (float,) + six.integer_types)
  if not is_python_scalar:
    return ops.EagerTensor(value, handle, device, dtype_enum)

  # Use a scalar cache. This will put each scalar of each type only once on
  # each device. Scalars don't use much device memory but copying scalars can
  # trigger memcpys which are slow.
  # The Python type is part of the key, so equal values of different types
  # (e.g. 1, 1.0 and True) get separate entries.
  cache_key = (device, value, dtype_enum, type(value))
  cache = ctx.scalar_cache()
  cached = cache.get(cache_key, None)
  if cached is not None:
    # Build a new EagerTensor, passing the cached tensor as the extra argument.
    return ops.EagerTensor(value, handle, device, dtype_enum, cached)
  created = ops.EagerTensor(value, handle, device, dtype_enum)
  cache[cache_key] = created
  return created
115
116
@tf_export(v1=["constant"])
def constant_v1(
    value, dtype=None, shape=None, name="Const", verify_shape=False):
  """Creates a constant tensor.

  The returned tensor holds values of type `dtype`, determined by the `value`
  argument and, when given, the `shape` argument (see the examples below).

  `value` may be a single constant or a list of values of type `dtype`. When
  `value` is a list, it must not contain more elements than `shape` (if
  specified) implies. If it contains fewer elements than `shape` implies, the
  last element of the list is used to fill the remaining entries.

  Under eager execution, however, a `value` whose number of elements differs
  from that of `shape` is only accepted when it holds exactly one element,
  which is then used to fill a tensor of the requested shape; any other
  mismatch raises `TypeError`.

  `shape` is optional. When present, it gives the dimensions of the resulting
  tensor; otherwise the shape of `value` is used.

  When `dtype` is not specified, the type is inferred from the type of
  `value`.

  For example:

  ```python
  # Constant 1-D Tensor populated with value list.
  tensor = tf.constant([1, 2, 3, 4, 5, 6, 7]) => [1 2 3 4 5 6 7]

  # Constant 2-D tensor populated with scalar value -1.
  tensor = tf.constant(-1.0, shape=[2, 3]) => [[-1. -1. -1.]
                                               [-1. -1. -1.]]
  ```

  `tf.constant` differs from `tf.fill` in a few ways:

  *   `tf.constant` supports arbitrary constants, not just uniform scalar
      Tensors like `tf.fill`.
  *   `tf.constant` creates a `Const` node in the computation graph with the
      exact value at graph construction time. On the other hand, `tf.fill`
      creates an Op in the graph that is expanded at runtime.
  *   Because `tf.constant` only embeds constant values in the graph, it does
      not support dynamic shapes based on other runtime Tensors, whereas
      `tf.fill` does.

  Args:
    value:          A constant value (or list) of output type `dtype`.

    dtype:          The type of the elements of the resulting tensor.

    shape:          Optional dimensions of resulting tensor.

    name:           Optional name for the tensor.

    verify_shape:   Boolean that enables verification of the shape of `value`
                    against `shape`.

  Returns:
    A Constant Tensor.

  Raises:
    TypeError: if shape is incorrectly specified or unsupported.
  """
  return _constant_impl(
      value=value,
      dtype=dtype,
      shape=shape,
      name=name,
      verify_shape=verify_shape,
      allow_broadcast=False)
180
181
@tf_export("constant", v1=[])
def constant(value, dtype=None, shape=None, name="Const"):
  """Creates a constant tensor.

  The returned tensor holds values of type `dtype`, determined by the `value`
  argument and, when given, the `shape` argument (see the examples below).

  `value` may be a single constant or a list of values of type `dtype`. When
  `value` is a list, it must not contain more elements than `shape` (if
  specified) implies. If it contains fewer elements than `shape` implies, the
  last element of the list is used to fill the remaining entries.

  Under eager execution, however, a `value` whose number of elements differs
  from that of `shape` is only accepted when it holds exactly one element,
  which is then used to fill a tensor of the requested shape; any other
  mismatch raises `TypeError`.

  `shape` is optional. When present, it gives the dimensions of the resulting
  tensor; otherwise the shape of `value` is used.

  When `dtype` is not specified, the type is inferred from the type of
  `value`.

  For example:

  ```python
  # Constant 1-D Tensor populated with value list.
  tensor = tf.constant([1, 2, 3, 4, 5, 6]) => [1 2 3 4 5 6]

  # Constant 2-D Tensor populated with value list, reshaped to `shape`.
  tensor = tf.constant([1, 2, 3, 4, 5, 6], shape=(2,3))
       => [[1 2 3], [4 5 6]]

  # Constant 2-D tensor populated with scalar value -1.
  tensor = tf.constant(-1.0, shape=[2, 3]) => [[-1. -1. -1.]
                                               [-1. -1. -1.]]
  ```

  `tf.constant` differs from `tf.fill` in a few ways:

  *   `tf.constant` supports arbitrary constants, not just uniform scalar
      Tensors like `tf.fill`.
  *   `tf.constant` creates a `Const` node in the computation graph with the
      exact value at graph construction time. On the other hand, `tf.fill`
      creates an Op in the graph that is expanded at runtime.
  *   Because `tf.constant` only embeds constant values in the graph, it does
      not support dynamic shapes based on other runtime Tensors, whereas
      `tf.fill` does.

  Args:
    value:          A constant value (or list) of output type `dtype`.

    dtype:          The type of the elements of the resulting tensor.

    shape:          Optional dimensions of resulting tensor.

    name:           Optional name for the tensor.

  Returns:
    A Constant Tensor.

  Raises:
    TypeError: if shape is incorrectly specified or unsupported.
  """
  return _constant_impl(
      value=value,
      dtype=dtype,
      shape=shape,
      name=name,
      verify_shape=False,
      allow_broadcast=True)
246
247
def _constant_impl(
    value, dtype, shape, name, verify_shape, allow_broadcast):
  """Implementation of constant.

  In graph mode a `Const` op is added to the default graph. Under eager
  execution `value` is converted to an eager tensor and, if needed, reshaped
  or filled to match `shape`.

  Args:
    value: constant value (or list) to embed.
    dtype: requested element type, or None.
    shape: requested shape, or None to keep the shape of `value`.
    name: name given to the `Const` op; only used in graph mode.
    verify_shape: in graph mode, forwarded to `tensor_util.make_tensor_proto`;
      under eager execution, a shape mismatch raises instead of being resolved
      by reshape/fill.
    allow_broadcast: forwarded to `tensor_util.make_tensor_proto`; only used in
      graph mode.

  Returns:
    An eager tensor (eager execution) or the output of the new `Const` op
    (graph mode).

  Raises:
    TypeError: under eager execution, if `shape` differs from the shape of
      `value` and either `verify_shape` is set or the mismatch can be resolved
      neither by a reshape nor by filling from a single element.
  """
  ctx = context.context()

  if not ctx.executing_eagerly():
    # Graph mode: embed the value as the attribute of a `Const` op.
    graph = ops.get_default_graph()
    value_attr = attr_value_pb2.AttrValue()
    value_attr.tensor.CopyFrom(
        tensor_util.make_tensor_proto(
            value,
            dtype=dtype,
            shape=shape,
            verify_shape=verify_shape,
            allow_broadcast=allow_broadcast))
    dtype_attr = attr_value_pb2.AttrValue(type=value_attr.tensor.dtype)
    const_op = graph.create_op(
        "Const", [], [dtype_attr.type],
        attrs={"value": value_attr, "dtype": dtype_attr},
        name=name)
    return const_op.outputs[0]

  eager_value = convert_to_eager_tensor(value, ctx, dtype)
  if shape is None:
    return eager_value
  shape = tensor_shape.as_shape(shape)
  if shape == eager_value.shape:
    return eager_value
  if verify_shape:
    raise TypeError("Expected Tensor's shape: %s, got %s." % (
        tuple(shape), tuple(eager_value.shape)))

  value_count = eager_value.shape.num_elements()
  target_count = shape.num_elements()
  # TODO(josh11b): Implement shape -> eager tensor conversion.
  if value_count == target_count:
    return _eager_reshape(eager_value, shape.as_list(), ctx)
  if value_count != 1:
    raise TypeError("Eager execution of tf.constant with unsupported shape "
                    "(value has %d elements, shape is %s with %d elements)." %
                    (value_count, shape, target_count))

  # A single element is used to fill a tensor of the requested shape.
  if t_is_bool(eager_value) if False else eager_value.dtype == dtypes.bool:
    # We don't have a Fill kernel for bool dtype on GPU. So we first run
    # Fill on CPU and then copy to GPU if needed.
    with ops.device("/device:CPU:0"):
      filled_on_cpu = _eager_fill(shape.as_list(), eager_value.cpu(), ctx)
    return _eager_identity(filled_on_cpu, ctx)
  return _eager_fill(shape.as_list(), eager_value, ctx)
291
292
def is_constant(tensor_or_op):
  """Returns whether the given tensor or op comes from a `Const` op.

  Args:
    tensor_or_op: an `ops.Tensor`, in which case its producing `op` is
      inspected, or any other object exposing a `type` attribute.

  Returns:
    True if the inspected op's `type` equals "Const".
  """
  is_tensor = isinstance(tensor_or_op, ops.Tensor)
  candidate_op = tensor_or_op.op if is_tensor else tensor_or_op
  return candidate_op.type == "Const"
299
300
def _constant_tensor_conversion_function(v, dtype=None, name=None,
                                         as_ref=False):
  """Converts `v` to a tensor by delegating to `constant`.

  Args:
    v: value to convert.
    dtype: optional element type forwarded to `constant`.
    name: optional name forwarded to `constant`.
    as_ref: accepted but ignored.

  Returns:
    The tensor returned by `constant`.
  """
  del as_ref  # Unused.
  return constant(v, dtype=dtype, name=name)
305
306
# Register `_constant_tensor_conversion_function` as the converter for Python
# lists/tuples, NumPy arrays, NumPy scalars and, finally, plain `object`.
# NOTE(review): the trailing integer is presumably the conversion priority, with
# the `object` registration (200) acting as a lower-precedence catch-all than
# the specific types (100) -- confirm against
# `ops.register_tensor_conversion_function`.
ops.register_tensor_conversion_function(
    (list, tuple), _constant_tensor_conversion_function, 100)
ops.register_tensor_conversion_function(
    np.ndarray, _constant_tensor_conversion_function, 100)
ops.register_tensor_conversion_function(
    np.generic, _constant_tensor_conversion_function, 100)
ops.register_tensor_conversion_function(
    object, _constant_tensor_conversion_function, 200)
315
316
def _tensor_shape_tensor_conversion_function(s,
                                             dtype=None,
                                             name=None,
                                             as_ref=False):
  """Function to convert TensorShape to Tensor.

  Args:
    s: fully defined `TensorShape` to convert.
    dtype: optional result dtype; must be `dtypes.int32` or `dtypes.int64`.
      When omitted, int64 is chosen if any dimension is >= 2**31, otherwise
      int32.
    name: optional op name; defaults to "shape_as_tensor".
    as_ref: accepted but ignored.

  Returns:
    The constant tensor holding the dimensions of `s`.

  Raises:
    ValueError: if `s` is not fully defined, or if `dtype` is int32 while a
      dimension is >= 2**31.
    TypeError: if `dtype` is neither int32 nor int64.
  """
  del as_ref  # Unused.
  if not s.is_fully_defined():
    raise ValueError(
        "Cannot convert a partially known TensorShape to a Tensor: %s" % s)
  dims = s.as_list()
  # First dimension too large for int32, or 0 when every dimension fits.
  oversized_dim = next((dim for dim in dims if dim >= 2**31), 0)

  if dtype is None:
    dtype = dtypes.int64 if oversized_dim else dtypes.int32
  elif dtype not in (dtypes.int32, dtypes.int64):
    raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype)
  elif dtype == dtypes.int32 and oversized_dim:
    raise ValueError("Cannot convert a TensorShape to dtype int32; "
                     "a dimension is too large (%s)" % oversized_dim)

  op_name = "shape_as_tensor" if name is None else name
  return constant(dims, dtype=dtype, name=op_name)
344
345
# Register the TensorShape -> Tensor converter defined above.
# NOTE(review): 100 is presumably the conversion priority -- confirm against
# `ops.register_tensor_conversion_function`.
ops.register_tensor_conversion_function(
    tensor_shape.TensorShape, _tensor_shape_tensor_conversion_function, 100)
348
349
def _dimension_tensor_conversion_function(d,
                                          dtype=None,
                                          name=None,
                                          as_ref=False):
  """Function to convert Dimension to Tensor.

  Args:
    d: `Dimension` to convert; its `value` must be known (not None).
    dtype: optional result dtype; must be `dtypes.int32` or `dtypes.int64`.
      Defaults to int32.
    name: optional op name; defaults to "shape_as_tensor".
    as_ref: accepted but ignored.

  Returns:
    The constant scalar tensor holding `d.value`.

  Raises:
    ValueError: if `d.value` is None.
    TypeError: if `dtype` is neither int32 nor int64.
  """
  _ = as_ref
  if d.value is None:
    raise ValueError("Cannot convert an unknown Dimension to a Tensor: %s" % d)
  if dtype is not None:
    if dtype not in (dtypes.int32, dtypes.int64):
      # The message names Dimension: this converter handles Dimension objects,
      # not TensorShape ones.
      raise TypeError("Cannot convert a Dimension to dtype: %s" % dtype)
  else:
    dtype = dtypes.int32
  if name is None:
    # Same default as the TensorShape converter; left unchanged so that the
    # names of created ops do not change.
    name = "shape_as_tensor"
  return constant(d.value, dtype=dtype, name=name)
366
367
# Register the Dimension -> Tensor converter defined above.
# NOTE(review): 100 is presumably the conversion priority -- confirm against
# `ops.register_tensor_conversion_function`.
ops.register_tensor_conversion_function(
    tensor_shape.Dimension, _dimension_tensor_conversion_function, 100)
370