• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Lint as python3
2# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Structured Tensors."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import logging
23import re
24from typing import Callable, Dict, List, Sequence, Tuple, Union
25
26import numpy as np
27
28from tensorflow.python.framework import composite_tensor
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import tensor_shape
33from tensorflow.python.framework import tensor_spec
34from tensorflow.python.framework import type_spec
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import check_ops
37from tensorflow.python.ops import control_flow_ops
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops.ragged import ragged_factory_ops
40from tensorflow.python.ops.ragged import ragged_tensor
41from tensorflow.python.ops.ragged import row_partition as row_partition_lib
42from tensorflow.python.ops.ragged.row_partition import RowPartition
43from tensorflow.python.util import compat
44from tensorflow.python.util import nest
45
46
47class StructuredTensor(composite_tensor.CompositeTensor):
48  """A multidimensional collection of structures with the same schema.
49
50  A **`StructuredTensor`** is a multi-dimensional collection of ***structures***
51  with the same ***schema***, where:
52
53  * A ***schema*** is a collection of fields, each of which has a name and type.
54  * A ***structure*** maps each field in the schema to a tensor value (which
55    could be a nested StructuredTensor).
56
57  As an important special case, a 1D `StructuredTensor` encodes a 2D table,
58  where columns are heterogeneous `Tensor`s, and rows are the aligned elements
59  in each of those `Tensor`s.
60
61  Internally, StructuredTensors use a "field-major" encoding: for each leaf
62  field, there is a single tensor that stores the value of that field for all
63  structures in the `StructuredTensor`.
64
65  ### Examples
66
67  >>> # A scalar StructuredTensor describing a single person.
68  >>> s1 = StructuredTensor.from_pyval(
69  ...     {"age": 82, "nicknames": ["Bob", "Bobby"]})
70  >>> s1.shape
71  TensorShape([])
72  >>> s1["age"]
73  <tf.Tensor: shape=(), dtype=int32, numpy=82>
74
75  >>> # A vector StructuredTensor describing three people.
76  >>> s2 = StructuredTensor.from_pyval([
77  ...     {"age": 12, "nicknames": ["Josaphine"]},
78  ...     {"age": 82, "nicknames": ["Bob", "Bobby"]},
79  ...     {"age": 42, "nicknames": ["Elmo"]}])
80  >>> s2.shape
81  TensorShape([3])
82  >>> s2[0]["age"]
83  <tf.Tensor: shape=(), dtype=int32, numpy=12>
84
85
86  ### Field Paths
87
88  A *field path* is a tuple of field names, specifying the path to a nested
89  field.
90  """
91
92  #=============================================================================
93  # Common Types
94  #=============================================================================
95  # pylint: disable=invalid-name
  # A field name is the key used to look up a field.  It is either a single
  # string (a field owned directly by a StructuredTensor) or a sequence of
  # strings giving the path to a field inside nested (embedded)
  # StructuredTensors.
  FieldName = Union[str, Sequence[str]]

  # Each field holds exactly one of the following kinds of tensor value.
  FieldValue = Union[ops.Tensor, ragged_tensor.RaggedTensor, 'StructuredTensor']

  # Function that takes a FieldValue as input and returns the transformed
  # FieldValue.
  FieldFn = Callable[[FieldValue], FieldValue]
106
107  # pylint: enable=invalid-name
108
109  #=============================================================================
110  # Constructor & Factory Methods
111  #=============================================================================
112
113  def __init__(self, fields, shape, nrows, row_partitions, internal=False):
114    """Private constructor -- use factory methods to create StructuredTensors.
115
116    This constructor builds a `StructuredTensor` from the given attributes,
117    performing minimal validation.
118
119    Args:
120      fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
121        `StructuredTensor`.  (This dict is not copied, so the caller must ensure
122        that it does not get mutated via leaked references.)
123      shape: `tf.TensorShape` with statically known rank.
124      nrows: scalar integer `tf.Tensor`, or `None` if `shape.rank==0`.
125      row_partitions: tuple of `RowPartition`s, with length `shape.rank-1`.
126      internal: Private key value, required to ensure that this private
127        constructor is *only* called from the factory methods.
128    """
129    if internal is not _structured_tensor_factory_key:
130      raise ValueError('StructuredTensor constructor is private; please use '
131                       'one of the factory methods instead (e.g., '
132                       'StructuredTensor.from_fields())')
133    assert isinstance(fields, dict), fields
134    assert isinstance(shape, tensor_shape.TensorShape), shape
135    assert nrows is None or isinstance(nrows, ops.Tensor), nrows
136    assert isinstance(row_partitions, tuple), row_partitions
137    self._fields = fields
138    self._shape = shape
139    self._nrows = nrows
140    self._row_partitions = row_partitions
141
142  @classmethod
143  def from_fields(cls,
144                  fields,
145                  shape=(),
146                  nrows=None,
147                  row_partitions=None,
148                  validate=False):
149    """Creates a `StructuredTensor` from a dictionary of fields.
150
151    Args:
152      fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
153        `StructuredTensor`, providing the values for individual fields in each
154        structure.  If `shape.rank > 0`, then every tensor in `fields` must have
155        the same shape in the first `shape.rank` dimensions; and that shape must
156        be compatible with `shape`; and `result[i1...iN][key] =
157        fields[key][i1...iN]` (where `N==shape.rank`).
158      shape: A `TensorShape`: static information about the shape of the
159        `StructuredTensor`.  Must have a known `rank`.  Defaults to scalar shape
160        (i.e. `rank=0`).
161      nrows: scalar integer tensor containing the number of rows in this
162        `StructuredTensor`.  Should only be specified if `shape.rank > 0`.
163        Default value is inferred from the `fields` values.  If `fields` is
164        empty, then this must be specified.
165      row_partitions: A list of `RowPartition`s describing the (possibly ragged)
166        shape of this `StructuredTensor`.  Should only be specified if
167        `shape.rank > 1`.  Default value is inferred from the `fields` values.
168        If `fields` is empty, then this must be specified.
169      validate: If true, then add runtime validation ops that check that the
170        field values all have compatible shapes in the outer `shape.rank`
171        dimensions.
172
173    Returns:
174      A `StructuredTensor`.
175
176    Examples:
177
178      >>> StructuredTensor.from_fields({'x': 1, 'y': [1, 2, 3]})
179      <StructuredTensor(
180        fields={
181          "x": tf.Tensor(1, shape=(), dtype=int32),
182          "y": tf.Tensor([1 2 3], shape=(3,), dtype=int32)},
183        shape=())>
184
185      >>> StructuredTensor.from_fields({'foo': [1, 2], 'bar': [3, 4]},
186      ...                              shape=[2])
187      <StructuredTensor(
188        fields={
189          "bar": tf.Tensor([3 4], shape=(2,), dtype=int32),
190          "foo": tf.Tensor([1 2], shape=(2,), dtype=int32)},
191        shape=(2,))>
192    """
193    shape = tensor_shape.as_shape(shape)
194    rank = shape.rank
195    if rank is None:
196      raise ValueError("StructuredTensor's shape must have known rank.")
197    if not isinstance(fields, dict):
198      raise TypeError('fields must be a dictionary, got %s' %
199                      type(fields).__name__)
200    if rank < 2 and row_partitions:
201      raise ValueError('row_partitions must be None or [] if shape.rank<2')
202    if rank == 0 and nrows is not None:
203      raise ValueError('nrows must be None if shape.rank==0')
204    if row_partitions is not None:
205      row_partitions = tuple(row_partitions)
206      if len(row_partitions) != max(0, rank - 1):
207        raise ValueError('len(row_partitions) must be shape.rank-1')
208    elif rank < 2:
209      row_partitions = ()
210
211    fields = dict(fields)  # Make a private copy.
212    with ops.name_scope(None, 'StructuredTensor', fields.values()):
213
214      # Validate keys and convert field values to tensors.
215      for key, value in fields.items():
216        if not isinstance(key, str):
217          raise TypeError('Unexpected type for key in `fields`: %r' % key)
218        if not _FIELD_NAME_RE.match(key):
219          raise ValueError('Field name %r is not currently allowed.' % key)
220        fields[key] = _convert_to_structured_field_value(value)
221
222      # Determine dtype for row_partitions and nrows.
223      shape_dtype = _find_shape_dtype(fields, nrows, row_partitions)
224      if nrows is not None:
225        nrows = ops.convert_to_tensor(nrows, shape_dtype)
226
227      # Get the static TensorShape for this StructuredTensor.
228      if rank > 0:
229        for key, value in fields.items():
230          if not shape.is_compatible_with(value.shape[:rank]):
231            raise ValueError('Field {} has shape {}, which is incompatible '
232                             'with the shape that was specified or inferred '
233                             'from other fields: {}'.format(
234                                 key, value.shape[:rank], shape))
235          shape = shape.merge_with(value.shape[:rank])
236
237      if rank == 1:
238        # Find a consistent value for `nrows`.
239        static_nrows = tensor_shape.dimension_at_index(shape, 0)
240        for value in fields.values():
241          nrows, static_nrows = _merge_nrows(nrows, static_nrows, value,
242                                             shape_dtype, validate)
243        if nrows is None:
244          if static_nrows.value is None:
245            raise ValueError('nrows must be specified if rank==1 '
246                             'and `fields` is empty.')
247          else:
248            nrows = constant_op.constant(static_nrows.value, shape_dtype)
249
250      if rank > 1:
251        # Find a consistent list of RowPartitions.
252        for value in fields.values():
253          row_partitions = _merge_row_partitions(row_partitions, value, rank,
254                                                 shape_dtype, validate)
255        if row_partitions is None:
256          if not shape.is_fully_defined():
257            raise ValueError('row_partitions must be specified if rank>1 '
258                             'and `fields` is empty.')
259          else:
260            row_partitions = _row_partitions_for_uniform_shape(
261                np.array(shape.as_list(), dtype=shape_dtype.as_numpy_dtype),
262                shape.rank)
263        assert len(row_partitions) == rank - 1
264        nrows = row_partitions[0].nrows()
265        # Update all field values to use the shared RowPartition objects.
266        fields = dict([(k, _replace_row_partitions(v, row_partitions))
267                       for (k, v) in fields.items()])
268
269    return cls(
270        fields,
271        shape,
272        nrows,
273        row_partitions,
274        internal=_structured_tensor_factory_key)
275
276  def with_updates(
277      self,
278      updates: Dict[FieldName, Union[FieldValue, FieldFn, None]],
279      validate: bool = False
280  ) -> 'StructuredTensor':
281    """Creates a new `StructuredTensor` with the updated fields.
282
283    If this `StructuredTensor` is a scalar, and `k` is the `FieldName` being
284    updated and `v` the new value, then:
285
286    ```
287    result[k] = v              # If (k, v) is in updates and v is a FieldValue
288    result[k] = f(self[k])     # If (k, f) is in updates and f is a FieldFn
289    result[k] = self[k]        # If k is in self.field_names but not in updates
290    ```
291
292    If this `StructuredTensor` has rank `N` and shape `[D1...DN]`, then each
293    FieldValue `v` in `updates` must have shape `[D1...DN, ...]`, that is,
294    prefixed with the same shape as the `StructuredTensor`. Then the resulting
295    `StructuredTensor` will have:
296
297    ```
298    result[i1...iN][k] = v[i1...iN]                        # (k, v) in updates
299    result[i1...iN][k] = f(self.field_value(k))[i1...iN]   # (k, f) in updates
300    result[i1...iN][k] = self[i1...iN][k]                  # k not in updates
301    ```
302
303    Note that `result.shape` is always equal to `self.shape` (but the shapes
304    of nested StructuredTensors may be changed if they are updated with new
305    values).
306
307    Args:
308      updates: A dictionary mapping `FieldName` to either a `FieldValue` to be
309        used to update, or a `FieldFn` that will transform the value for the
310        given `FieldName`. `FieldName` can be a string for a direct field, or a
311        sequence of strings to refer to a nested sub-field. `FieldFn` is a
312        function that takes a `FieldValue` as input and should return a
313        `FieldValue`. All other fields are copied over to the new
314        `StructuredTensor`. New `FieldName` can be given (to add new fields),
315        but only to existing `StructuredTensor`, it won't automatically create
316        new nested structures -- but one can create a whole `StructureTensor`
317        sub-structure and set that into an existing structure. If the new value
318        is set to `None`, it is removed.
319      validate: If true, then add runtime validation ops that check that the
320        field values all have compatible shapes in the outer `shape.rank`
321        dimensions.
322
323    Returns:
324      A `StructuredTensor`.
325
326    Raises:
327      `ValueError`: If the any of the `FieldName` keys points to non-existent
328        sub-structures, if parent and child nodes are updated, if shapes
329        change, if a delete update is given for a non-existant field, or if a
330        `FieldFn` transforming function is given for a `FieldName` that doesn't
331        yet exist.
332
333    Examples:
334
335    >>> shoes_us = StructuredTensor.from_pyval([
336    ...    {"age": 12, "nicknames": ["Josaphine"],
337    ...       "shoes": {"sizes": [8.0, 7.5, 7.5]}},
338    ...    {"age": 82, "nicknames": ["Bob", "Bobby"],
339    ...        "shoes": {"sizes": [11.0, 11.5, 12.0]}},
340    ...    {"age": 42, "nicknames": ["Elmo"],
341    ...        "shoes": {"sizes": [9.0, 9.5, 10.0]}}])
342    >>> def us_to_europe(t):
343    ...   return tf.round(t * 2.54 + 17.0)  # Rough approximation.
344    >>> shoe_sizes_key = ("shoes", "sizes")
345    >>> shoes_eu = shoes_us.with_updates({shoe_sizes_key: us_to_europe})
346    >>> shoes_eu.field_value(shoe_sizes_key)
347    <tf.RaggedTensor [[37.0, 36.0, 36.0], [45.0, 46.0, 47.0],
348    [40.0, 41.0, 42.0]]>
349    """
350    updates_items = [(_normalize_field_name_to_tuple(name), value)
351                     for name, value in updates.items()]
352
353    # Sort by keys and check for updates of both parent and child nodes.
354    updates_items = sorted(updates_items)
355    for i in range(1, len(updates_items)):
356      # Parent of a node would precede node in the sorted order.
357      name = updates_items[i][0]  # item[0] is the name, item[1] is the value.
358      prev_name = updates_items[i - 1][0]
359      if name[:len(prev_name)] == prev_name:
360        raise ValueError(
361            '`StructuredTensor.with_updates` does not allow both parent and '
362            'child nodes to be updated: parent={}, child={}. If needed you can '
363            'update child nodes in the parent update value.'.format(
364                prev_name, name))
365    return self._with_updates_impl((), updates_items, validate)
366
367  def _with_updates_impl(
368      self,
369      error_prefix: Tuple[str],
370      updates: List[Tuple[FieldName, Union[FieldValue, FieldFn]]],
371      validate: bool) -> 'StructuredTensor':
372    """Recursive part of `with_updates` implementation."""
373    # Get current fields.
374    new_fields = dict(self._fields)
375
376    # Convert field name to string with full path for error messages.
377    def name_fullpath(name: Sequence[str]) -> str:
378      return str(error_prefix + (name,))
379
380    # Apply value if a function or the value itself.
381    def apply_value(name: str, value: Union['FieldValue',
382                                            'FieldFn']) -> 'FieldValue':
383      if callable(value):
384        # `value` is actually a transforming function.
385        if name not in new_fields:
386          raise ValueError(
387              '`StructuredTensor.with_updates` cannot update the field {} '
388              'because a transforming function was given, but that field '
389              'does not already exist.'.format(name_fullpath(name)))
390        value = value(new_fields[name])
391      return value
392
393    # Merge updates.
394    for name, value in updates:
395      if not name or not name[0]:
396        raise ValueError(
397            '`StructuredTensor.with_updates` does not allow empty names '
398            '{}.'.format(name_fullpath(name)))
399
400      if len(name) == 1:
401        name = name[0]
402        if value is None:
403          if name not in new_fields:
404            raise ValueError(
405                '`StructuredTensor.with_updates` cannot delete field '
406                '{} because it is not present.'.format(name_fullpath(name)))
407          new_fields.pop(name)
408        else:
409          new_fields[name] = apply_value(name, value)
410      else:
411        # Recursive
412        prefix = name[0]
413        suffix = name[1:]
414        if prefix not in new_fields:
415          raise ValueError(
416              '`StructuredTensor.with_updates` cannot create new sub-field '
417              '{} if parent field {} is not set.'.format(
418                  error_prefix + tuple(name), name_fullpath(prefix)))
419        current_value = new_fields[prefix]
420        if not isinstance(current_value, StructuredTensor):
421          raise ValueError(
422              '`StructuredTensor.with_updates` cannot create new sub-field '
423              '{} if parent structure {} is not a `StructuredTensor` that '
424              'can contain sub-structures -- it is a `{}`.'.format(
425                  error_prefix + tuple(name), name_fullpath(prefix),
426                  type(current_value)))
427        one_update = [(suffix, value)]
428
429        # Accessing protected member in recursion.
430        # FutureWork: optimize by aggregating the recursions, instead of
431        #   calling one at a time.
432        # pylint: disable=protected-access
433        value = current_value._with_updates_impl(error_prefix + (prefix,),
434                                                 one_update, validate)
435        # pylint: enable=protected-access
436        new_fields[prefix] = value
437
438    # TODO(edloper): When validate=True, only validate the modified fields.
439    try:
440      return StructuredTensor.from_fields(
441          new_fields,
442          shape=self.shape,
443          row_partitions=self._row_partitions,
444          nrows=self._nrows,
445          validate=validate)
446
447    except ValueError as e:
448      msg = '`StructuredTensor.with_updates` failed'
449      if error_prefix:
450        msg = '{} for field {}'.format(msg, error_prefix)
451      raise ValueError('{}: {}'.format(msg, e))
452
453  def _promote_helper(self, source_path, new_parent_path):
454    """Creates a promoted field without adding it to the structure.
455
456    Args:
457      source_path: the source path in the structured tensor.
458      new_parent_path: the new parent path. Must be a prefix of source_path.
459
460    Returns:
461      a composite tensor of source_path promoted.
462    Raises:
463      ValueError: if the shape of the field is unknown and the right strategy
464      cannot be determined.
465    """
466    current_field = self.field_value(source_path)
467    new_parent_rank = self.field_value(new_parent_path).rank
468    parent_rank = self.field_value(source_path[:-1]).rank
469    if new_parent_rank == parent_rank:
470      return current_field
471    current_field_rank = current_field.shape.rank
472    if current_field_rank is None:
473      raise ValueError('Cannot determine if dimensions should be merged.')
474    inner_dim = min(parent_rank, current_field_rank - 1)
475    if inner_dim <= new_parent_rank:
476      return current_field
477    return _merge_dims_generic(current_field, new_parent_rank, inner_dim)
478
479  def promote(self, source_path, new_name):
480    """Promotes a field, merging dimensions between grandparent and parent.
481
482    >>> d = [
483    ...  {'docs': [{'tokens':[1, 2]}, {'tokens':[3]}]},
484    ...  {'docs': [{'tokens':[7]}]}]
485    >>> st = StructuredTensor.from_pyval(d)
486    >>> st2 =st.promote(('docs','tokens'), 'docs_tokens')
487    >>> st2[0]['docs_tokens']
488    <tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>
489    >>> st2[1]['docs_tokens']
490    <tf.Tensor: shape=(1,), dtype=int32, numpy=array([7], dtype=int32)>
491
492    Args:
493      source_path: the path of the field or substructure to promote; must have
494        length at least 2.
495      new_name: the name of the new field (must be a string).
496
497    Returns:
498      a modified structured tensor with the new field as a child of the
499      grandparent of the source_path.
500
501    Raises:
502      ValueError: if source_path is not a list or a tuple or has a length
503        less than two, or new_name is not a string, or the rank
504        of source_path is unknown and it is needed.
505    """
506    if not isinstance(new_name, str):
507      raise ValueError('new_name is not a string')
508    if not isinstance(source_path, (list, tuple)):
509      raise ValueError('source_path must be a list or tuple')
510
511    if len(source_path) < 2:
512      raise ValueError('source_path must have length at least two')
513
514    grandparent_path = source_path[:-2]
515    new_field = self._promote_helper(source_path, grandparent_path)
516    new_path = grandparent_path + (new_name,)
517    return self.with_updates({new_path: new_field})
518
519  #=============================================================================
520  # Properties
521  #=============================================================================
522
523  @property
524  def rank(self):
525    """The rank of this StructuredTensor.  Guaranteed not to be `None`."""
526    return self._shape.rank
527
528  @property
529  def shape(self):
530    """The static shape of this StructuredTensor.
531
532    The returned `TensorShape` is guaranteed to have a known rank, but the
533    individual dimension sizes may be unknown.
534
535    Returns:
536      `tf.TensorShape`
537    """
538    return self._shape
539
540  # TODO(edloper): Make this a func instead of a property?  Or make nrows
541  # a property instead of a func?  Seems like these should be consistent.
542  @property
543  def row_partitions(self):
544    """A tuple of `RowPartition`s defining the shape of this `StructuredTensor`.
545
546    When `self.rank <= 1`, this tuple will be empty.
547
548    When `self.rank > 1`, these `RowPartitions` define the shape of the
549    `StructuredTensor` by describing how a flat (1D) list of structures can be
550    repeatedly partitioned to form a higher-dimensional object.  In particular,
551    the flat list is first partitioned into sublists using `row_partitions[-1]`,
552    and then those sublists are further partitioned using `row_partitions[-2]`,
553    etc.  The following examples show the row partitions used to describe
554    several different `StructuredTensor`, each of which contains 8 copies of
555    the same structure (`x`):
556
557    >>> x = {'a': 1, 'b': ['foo', 'bar', 'baz']}       # shape = [] (scalar)
558
559    >>> s1 = [[x, x, x, x], [x, x, x, x]]              # shape = [2, 4]
560    >>> StructuredTensor.from_pyval(s1).row_partitions
561    (tf.RowPartition(row_splits=tf.Tensor([0 4 8], shape=(3,),
562                                          dtype=int64)),)
563
564    >>> s2 = [[x, x], [x, x], [x, x], [x, x]]          # shape = [4, 2]
565    >>> StructuredTensor.from_pyval(s2).row_partitions
566    (tf.RowPartition(row_splits=tf.Tensor([0 2 4 6 8], shape=(5,),
567                                          dtype=int64)),)
568
569    >>> s3 = [[x, x, x], [], [x, x, x, x], [x]]        # shape = [2, None]
570    >>> StructuredTensor.from_pyval(s3).row_partitions
571    (tf.RowPartition(row_splits=tf.Tensor([0 3 3 7 8], shape=(5,),
572                                          dtype=int64)),)
573
574    >>> s4 = [[[x, x], [x, x]], [[x, x], [x, x]]]      # shape = [2, 2, 2]
575    >>> StructuredTensor.from_pyval(s4).row_partitions
576    (tf.RowPartition(row_splits=tf.Tensor([0 2 4], shape=(3,), dtype=int64)),
577     tf.RowPartition(row_splits=tf.Tensor([0 2 4 6 8], shape=(5,),
578                                          dtype=int64)))
579
580
581    >>> s5 = [[[x, x], [x]], [[x, x]], [[x, x], [x]]]  # shape = [3, None, None]
582    >>> StructuredTensor.from_pyval(s5).row_partitions
583    (tf.RowPartition(row_splits=tf.Tensor([0 2 3 5], shape=(4,), dtype=int64)),
584     tf.RowPartition(row_splits=tf.Tensor([0 2 3 5 7 8], shape=(6,),
585                                          dtype=int64)))
586
587    Note that shapes for nested fields (such as `x['b']` in the above example)
588    are not considered part of the shape of a `StructuredTensor`, and are not
589    included in `row_partitions`.
590
591    If this `StructuredTensor` has a ragged shape (i.e., if any of the
592    `row_partitions` is not uniform in size), then all fields will be encoded
593    as either `RaggedTensor`s or `StructuredTensor`s with these `RowPartition`s
594    used to define their outermost `self.rank` dimensions.
595
596    Returns:
597      A `tuple` of `RowPartition` objects with length `self.rank - 1`
598      (or `0` if `self.rank < 2`)
599
600    """
601    return self._row_partitions
602
603  def nrows(self):
604    """The number of rows in this StructuredTensor (if rank>0).
605
606    This means the length of the outer-most dimension of the StructuredTensor.
607
608    Notice that if `self.rank > 1`, then this equals the number of rows
609    of the first row partition. That is,
610    `self.nrows() == self.row_partitions[0].nrows()`.
611
612    Otherwise `self.nrows()` will be the first dimension of the field values.
613
614    Returns:
615      A scalar integer `Tensor` (or `None` if `self.rank == 0`).
616    """
617    return self._nrows
618
619  def _is_eager(self):
620    """True if all fields are composed of eager tensors."""
621    tensors = nest.flatten(self, expand_composites=True)
622    return all(isinstance(t, ops.EagerTensor) for t in tensors)
623
624  #=============================================================================
625  # Encoding
626  #=============================================================================
627
628  def field_names(self):
629    """Returns the string field names for this `StructuredTensor`."""
630    return tuple(self._fields.keys())
631
632  def field_value(self, field_name):
633    """Returns the tensor value for the specified field or path.
634
635    If `field_name` is a `string`, then it names a field directly owned by this
636    `StructuredTensor`.  If this `StructuredTensor` has shape `[D1...DN]`, then
637    the returned tensor will have shape `[D1...DN, V1...VM]`, where the slice
638    `result[d1...dN]` contains the field value for the structure at
639    `self[d1...dN]`.
640
641    If `field_name` is a `tuple` of `string`, then it specifies a path to a
642    field owned by nested `StructuredTensor`.  In particular,
643    `struct.field_value((f1, f2, ..., fN))` is equivalent to
644    `struct.field_value(f1).field_value(f2)....field_value(fN)`
645
646    Args:
647      field_name: `string` or `tuple` of `string`: The field whose values should
648        be returned.
649
650    Returns:
651      `Tensor`, `StructuredTensor`, or `RaggedTensor`.
652
653    Raises:
654      KeyError: If the given field_name is not found.
655    """
656    if isinstance(field_name, (list, tuple)):
657      value = self
658      for f in field_name:
659        if not isinstance(value, StructuredTensor):
660          raise KeyError('Field path {} not found in {}'.format(
661              field_name, self))
662        value = value.field_value(f)
663      return value
664    return self._fields[field_name]
665
666  #=============================================================================
667  # Operators
668  #=============================================================================
669
670  # TODO(edloper): Add support for ellipsis and/or newaxis?
671  def __getitem__(self, key):
672    """Returns the specified piece of this StructuredTensor.
673
674    * If `struct_tensor` is scalar (i.e., a single structure), then
675      `struct_tensor[f]` returns the value of field `f` (where `f` must be a
676      string).
677
678    * If `struct_tensor` is non-scalar (i.e., a vector or higher-dimensional
679      tensor of structures), `struct_tensor[i]` selects an element or slice of
680      the tensor using standard Python semantics (e.g., negative values index
681      from the end).  `i` may have any of the following types:
682
683      * `int` constant
684      * `string` constant
685      * scalar integer `Tensor`
686      * `slice` containing integer constants and/or scalar integer
687        `Tensor`s
688
689    #### Multidimensional indexing
690
691    `StructuredTensor` supports multidimensional indexing.  I.e., `key` may be a
692    `tuple` of values, indexing or slicing multiple dimensions at once.  For
693    example, if `people` is a vector of structures, each of which has a vector-
694    valued `names` field, then `people[3, 'names', 0]` is equivalent to
695    `people[3]['names'][0]`; and `people[:, 'names', :]` will return a (possibly
696    ragged) matrix of names, with shape `[num_people, num_names_per_person]`.
697
698    Args:
699      key: Indicates which piece of the StructuredTensor to return.
700
701    Returns:
702      A `Tensor`, `StructuredTensor`, or `RaggedTensor`.
703    """
704    if isinstance(key, list):
705      key = tuple(key)
706    elif not isinstance(key, tuple):
707      key = (key,)
708    if not key:
709      return self
710
711    if self._shape.rank == 0:
712      return self._scalar_getitem(key)
713    else:
714      return self._tensor_getitem(key)
715
716  def _scalar_getitem(self, key):
717    if (isinstance(key[0], slice) and key[0].start is None and
718        key[0].stop is None and key[0].step is None):
719      fields = dict((field_name, field_value.__getitem__(key[1:]))
720                    for (field_name, field_value) in self._fields.items())
721      return StructuredTensor.from_fields(fields, self._shape)
722
723    elif not isinstance(key[0], compat.bytes_or_text_types):
724      raise ValueError('Key for indexing a StructuredTensor must be a '
725                       "string or a full slice (':')")
726
727    return self._fields[key[0]].__getitem__(key[1:])
728
729  def _tensor_getitem(self, key):
730    rank = self._shape.rank
731    if len(key) <= rank:
732      new_fields = dict((field_name, field_value.__getitem__(key))
733                        for (field_name, field_value) in self._fields.items())
734      result_shape = self.shape.as_list()
735      for d, k in enumerate(key):
736        if isinstance(k, slice):
737          if not (k.start is None and k.stop is None and k.step is None):
738            # TODO(edloper): Better static shape analysis here.
739            result_shape[d] = None
740        elif isinstance(k, (int, ops.Tensor)):
741          result_shape[d] = -1  # mark for deletion
742        elif k is None:
743          raise ValueError('Slicing not supported for tf.newaxis')
744        else:
745          # Ellipsis, tf.newaxis:
746          raise ValueError('Slicing not supported for %r' % k)
747      result_shape = [d for d in result_shape if d != -1]
748      return StructuredTensor.from_fields(new_fields, result_shape)
749
750    else:
751      if not isinstance(key[rank], compat.bytes_or_text_types):
752        # TODO(edloper): Also support full slice here?
753        raise ValueError('Key for indexing a StructuredTensor must be a string')
754      return self._fields[key[rank]].__getitem__(key[:rank] + key[rank + 1:])
755
756  def __repr__(self):
757    fields = sorted(self._fields.items())
758    fields = ((k, str(v).replace('\n', '\n            ')) for k, v in fields)
759    fields = ('"{}": {}'.format(k, v) for k, v in fields)
760    dict_repr = ',\n        '.join(fields)
761    return ('<StructuredTensor(\n'
762            '    fields={\n'
763            '        %s},\n'
764            '    shape=%s)>' % (dict_repr, self._shape))
765
766  #=============================================================================
767  # Conversion
768  #=============================================================================
769
770  def to_pyval(self):
771    """Returns this StructuredTensor as a nested Python dict or list of dicts.
772
773    Converts this `StructuredTensor` to a nested python value:
774
775    * `StructTensors` with `rank=0` are converted into a dictionary, with an
776      entry for each field.  Field names are used as keys and field values are
777      converted to python values.  In particular:
778
779      * Scalar Tensor fields are converted to simple values (such as
780        `int` or `float` or `string`)
781      * Non-scalar Tensor fields and RaggedTensor fields are converted to
782        nested lists of simple values.
783      * StructuredTensor fields are converted recursively using `to_pyval`.
784
785    * `StructTensors` with `rank>0` are converted to nested python `list`s,
786      containing one dictionary for each structure (where each structure's
787      dictionary is defined as described above).
788
789    Requires that all fields are Eager tensors.
790
791    >>> StructuredTensor.from_fields(
792    ...     {'a': [1, 2, 3]}, [3]).to_pyval()
793    [{'a': 1}, {'a': 2}, {'a': 3}]
794
795    Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`.
796
797    Returns:
798      A nested Python dict or list of dicts.
799    """
800    if not self._is_eager():
801      raise ValueError(
802          'StructuredTensor.to_pyval() is only supported in eager mode.')
803
804    # Convert each field value to a nested list.
805    result = {}
806    for (key, value) in self._fields.items():
807      if isinstance(value, ops.EagerTensor):
808        value = value.numpy()
809      if isinstance(value, np.ndarray):
810        value = value.tolist()
811      elif isinstance(value, ragged_tensor.RaggedTensor):
812        value = value.to_list()
813      elif isinstance(value, StructuredTensor):
814        value = value.to_pyval()
815      # TODO(edloper): Throw an exception if value is an unexpected type.
816      result[key] = value
817
818    # If rank>0, then re-group each value from dict-of-list to list-of-dict.
819    if len(self._shape) > 0:  # pylint: disable=g-explicit-length-test
820      if not result:  # special-case for StructuredTensors w/ no fields.
821        return _empty_dict_pylist_from_row_partitions(self._row_partitions,
822                                                      self._nrows)
823      return _pyval_field_major_to_node_major(
824          list(result.keys()), list(result.values()), self._shape.rank)
825    else:
826      return result
827
828  @classmethod
829  def from_pyval(cls, pyval, typespec=None):
830    """Constructs a StructuredTensor from a nested Python structure.
831
832    >>> StructuredTensor.from_pyval(
833    ...     {'a': [1, 2, 3], 'b': [[4, 5], [6, 7]]})
834    <StructuredTensor(
835        fields={
836          "a": tf.Tensor([1 2 3], shape=(3,), dtype=int32),
837          "b": <tf.RaggedTensor [[4, 5], [6, 7]]>},
838        shape=())>
839
840    Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`.
841
842    Args:
843      pyval: The nested Python structure that should be used to create the new
844        `StructuredTensor`.
845      typespec: A `StructuredTensorSpec` specifying the expected type for each
846        field. If not specified, then all nested dictionaries are turned into
847        StructuredTensors, and all nested lists are turned into Tensors (if
848        rank<2) or RaggedTensors (if rank>=2).
849
850    Returns:
851      A `StructuredTensor`.
852    """
853    return cls._from_pyval(pyval, typespec, ())
854
855  @classmethod
856  def _from_pyval(cls, pyval, typespec, path_so_far):
857    """Helper function for from_pyval.
858
859
860    Args:
861      pyval: The nested Python structure that should be used to create the new
862        `StructuredTensor`.
863      typespec: A `StructuredTensorSpec` specifying the expected type for each
864        field. If not specified, then all nested dictionaries are turned into
865        StructuredTensors, and all nested lists are turned into Tensors (if
866        rank<2) or RaggedTensors (if rank>=2).
867      path_so_far: the path of fields that led here (for error messages).
868
869    Returns:
870      A `StructuredTensor`.
871    """
872    if isinstance(pyval, dict):
873      return cls._from_pydict(pyval, typespec, path_so_far)
874    elif isinstance(pyval, (list, tuple)):
875      keys = set()
876      rank = _pyval_find_struct_keys_and_depth(pyval, keys)
877      if rank is not None:
878        return cls._from_pylist_of_dict(pyval, keys, rank, typespec,
879                                        path_so_far)
880      else:
881        return cls._from_pylist_of_value(pyval, typespec, path_so_far)
882    else:
883      return cls._from_pyscalar(pyval, typespec, path_so_far)
884
885  @classmethod
886  def _from_pydict(cls, pyval, typespec, path_so_far):
887    """Converts python dictionary `pyval` to a StructuredTensor with rank=0."""
888    if typespec is None:
889      fields = dict((k, cls._from_pyval(v, None, path_so_far + (k,)))
890                    for (k, v) in pyval.items())
891    else:
892      spec_shape = typespec._shape  # pylint: disable=protected-access
893      field_specs = typespec._field_specs  # pylint: disable=protected-access
894      if not (isinstance(typespec, StructuredTensorSpec) and
895              spec_shape.rank == 0 and set(pyval) == set(field_specs)):
896        raise ValueError('Value at %r does not match typespec: %r vs %r' %
897                         (path_so_far, pyval, typespec))
898      fields = dict((k, cls._from_pyval(v, field_specs[k], path_so_far + (k,)))
899                    for (k, v) in pyval.items())
900    return StructuredTensor.from_fields(fields=fields, shape=(), validate=False)
901
902  @classmethod
903  def _from_pylist_of_dict(cls, pyval, keys, rank, typespec, path_so_far):
904    """Converts python list `pyval` to a StructuredTensor with rank>1."""
905    fields = dict((key, []) for key in keys)
906    for child in pyval:
907      _pyval_update_fields(child, fields, 1)
908    if typespec is None:
909      shape = tensor_shape.TensorShape([None] * rank)
910      for (key, target) in fields.items():
911        fields[key] = cls._from_pyval(target, None, path_so_far + (key,))
912    else:
913      field_specs = typespec._field_specs  # pylint: disable=protected-access
914      if ((not isinstance(typespec, StructuredTensorSpec)) or
915          (set(fields) - set(field_specs))):
916        raise ValueError('Value at %r does not match typespec: %r vs %r' %
917                         (path_so_far, pyval, typespec))
918      shape = typespec._shape
919      if shape.rank < rank:
920        raise ValueError('Value at %r does not match typespec (rank mismatch): '
921                         '%r vs %r' % (path_so_far, pyval, typespec))
922      for (key, spec) in field_specs.items():
923        fields[key] = cls._from_pyval(
924            fields.get(key, []), spec, path_so_far + (key,))
925    try:
926      if not fields and typespec is None:
927        # TODO(b/183245576): handle cases where the typespec is known
928        # but the dictionary is empty.
929        return StructuredTensor._from_pylist_of_empty_dict(pyval, rank)
930      return StructuredTensor.from_fields(
931          fields=fields, shape=shape, validate=False)
932    except Exception as exc:
933      raise ValueError('Error parsing path %r' % (path_so_far,)) from exc
934
935  @classmethod
936  def _from_pylist_of_empty_dict(cls, pyval, rank):
937    """Converts a pylist of empty dictionaries to StructuredTensors."""
938    if rank == 0:
939      return StructuredTensor.from_fields(fields={}, shape=(), validate=False)
940    elif rank == 1:
941      nrows = len(pyval)
942      shape = (nrows,)
943      return StructuredTensor.from_fields(fields={}, shape=shape, nrows=nrows)
944    elif rank > 1:
945      ragged_zeros = ragged_factory_ops.constant(_dicts_to_zeros(pyval))
946      nrows = len(pyval)
947      shape = tensor_shape.TensorShape([len(pyval)] + ([None] * (rank - 1)))
948      return StructuredTensor.from_fields(
949          fields={},
950          shape=shape,
951          row_partitions=ragged_zeros._nested_row_partitions,  # pylint:disable=protected-access
952          nrows=nrows)
953
954  @classmethod
955  def _from_pylist_of_value(cls, pyval, typespec, path_so_far):
956    """Converts python list `pyval` to a Tensor or RaggedTensor with rank>1."""
957    if typespec is None:
958      try:
959        return ragged_factory_ops.constant(pyval)
960      except Exception as exc:
961        raise ValueError('Error parsing path %r' % (path_so_far,)) from exc
962    elif isinstance(typespec, tensor_spec.TensorSpec):
963      try:
964        result = constant_op.constant(pyval, typespec.dtype)
965      except Exception as exc:
966        raise ValueError('Error parsing path %r' % (path_so_far,)) from exc
967      if not typespec.shape.is_compatible_with(result.shape):
968        raise ValueError('Value at %r does not match typespec: %r vs %r' %
969                         (path_so_far, typespec, pyval))
970      return result
971    elif isinstance(typespec, ragged_tensor.RaggedTensorSpec):
972      # pylint: disable=protected-access
973      try:
974        return ragged_factory_ops.constant(
975            pyval,
976            dtype=typespec._dtype,
977            ragged_rank=typespec._ragged_rank,
978            row_splits_dtype=typespec._row_splits_dtype,
979            inner_shape=typespec._shape[typespec._ragged_rank + 1:])
980      except Exception as exc:
981        raise ValueError('Error parsing path %r' % (path_so_far,)) from exc
982    elif isinstance(typespec, StructuredTensorSpec):
983      empty_rank = _pyval_empty_list_depth(pyval)
984      if empty_rank is None:
985        raise ValueError('Value at %r does not match typespec: %r vs %r' %
986                         (path_so_far, typespec, pyval))
987      else:
988        return cls._from_pylist_of_dict(pyval, set(), empty_rank, typespec,
989                                        path_so_far)
990    else:
991      raise ValueError('Value at %r does not match typespec: %r vs %r' %
992                       (path_so_far, typespec, pyval))
993
994  @classmethod
995  def _from_pyscalar(cls, pyval, typespec, path_so_far):
996    """Converts python scalar value `pyval` to a Tensor."""
997    if typespec is None:
998      try:
999        return constant_op.constant(pyval)
1000      except Exception as exc:
1001        raise ValueError('Error parsing path %r' % (path_so_far,)) from exc
1002    else:
1003      if not (isinstance(typespec, tensor_spec.TensorSpec) and
1004              typespec.shape.rank == 0):
1005        raise ValueError('Value at %r does not match typespec: %r vs %r' %
1006                         (path_so_far, typespec, pyval))
1007      # TODO(edloper): Check that typespec.shape matches.
1008      return constant_op.constant(pyval, typespec.dtype)
1009
1010  #=============================================================================
1011  # Transforms
1012  #=============================================================================
1013
1014  # TODO(edloper): Add a 'validate' option here?
1015  # TODO(edloper): Unify nomenclature with RaggedTensor.  Should RaggedTensor
1016  # have a partition_outer_dimension method?
1017  def partition_outer_dimension(self, row_partition):
1018    """Partitions the outer dimension of this StructuredTensor.
1019
1020    Returns a new `StructuredTensor` with the same values as `self`, where
1021    the outer dimension is partitioned into two (possibly ragged) dimensions.
1022    Requires that this StructuredTensor have an outer dimension (i.e.,
1023    `self.shape.rank > 0`).
1024
1025    >>> st = StructuredTensor.from_pyval(
1026    ...     [{'foo': 12}, {'foo': 33}, {'foo': 99}])
1027    >>> partition = RowPartition.from_row_lengths([2, 0, 1])
1028    >>> st.partition_outer_dimension(partition)
1029    <StructuredTensor(
1030      fields={
1031        "foo": <tf.RaggedTensor [[12, 33], [], [99]]>},
1032      shape=(3, None))>
1033
1034    Args:
1035      row_partition: A `RowPartition`.
1036
1037    Returns:
1038      A `StructuredTensor` with rank `values.rank + 1`.
1039    """
1040    if not isinstance(row_partition, RowPartition):
1041      raise TypeError('row_partition must be a RowPartition.')
1042    if self.shape.rank == 0:
1043      raise ValueError('Shape %s must have rank at least 1' % self.shape)
1044    return _partition_outer_dimension(self, row_partition)
1045
1046  def merge_dims(self, outer_axis, inner_axis):
1047    """Merges outer_axis...inner_axis into a single dimension.
1048
1049    Returns a copy of this RaggedTensor with the specified range of dimensions
1050    flattened into a single dimension, with elements in row-major order.
1051
1052    >>> st = StructuredTensor.from_pyval(
1053    ...     [[{'foo': 12}, {'foo': 33}], [], [{'foo': 99}]])
1054    >>> st.merge_dims(0, 1)
1055    <StructuredTensor(
1056      fields={
1057        "foo": tf.Tensor([12 33 99], shape=(3,), dtype=int32)},
1058      shape=(3,))>
1059
1060    Args:
1061      outer_axis: `int`: The first dimension in the range of dimensions to
1062        merge. May be negative (to index from the last dimension).
1063      inner_axis: `int`: The last dimension in the range of dimensions to merge.
1064        May be negative (to index from the last dimension).
1065
1066    Returns:
1067      A copy of this tensor, with the specified dimensions merged into a
1068      single dimension.  The shape of the returned tensor will be
1069      `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N`
1070      is the total number of slices in the merged dimensions.
1071    """
1072    outer_axis = array_ops.get_positive_axis(
1073        outer_axis,
1074        self.shape.rank,
1075        axis_name='outer_axis',
1076        ndims_name='rank(self)')
1077    inner_axis = array_ops.get_positive_axis(
1078        inner_axis,
1079        self.shape.rank,
1080        axis_name='inner_axis',
1081        ndims_name='rank(self)')
1082    if not outer_axis <= inner_axis:
1083      raise ValueError('Expected outer_axis (%d) to be less than or equal to '
1084                       'inner_axis (%d)' % (outer_axis, inner_axis))
1085    return _merge_dims(self, outer_axis, inner_axis)
1086
1087  #=============================================================================
1088  # Composite Tensor
1089  #=============================================================================
1090
1091  @property
1092  def _type_spec(self):
1093    return StructuredTensorSpec.from_value(self)
1094
1095
@type_spec.register('tf.StructuredTensorSpec')
class StructuredTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for `StructuredTensor`s.

  A spec records the shape of the StructuredTensor and the `TypeSpec` of each
  of its fields.
  """

  __slots__ = ['_shape', '_field_specs']

  def __init__(self, shape, field_specs):
    """Build a type specification for a StructuredTensor.

    Args:
      shape: The shape of the StructuredTensor.  shape.rank must not be None.
      field_specs: A dictionary mapping from field name to TypeSpec, specifying
        the tensor type used to encode each field. These TypeSpecs should
        specify the type of the entire field (including outer dimensions which
        correspond to `shape`).  For example, if `shape=[2, 3]`, and field 'x'
        contains an int32 vector of size `10` for each structure, then
        `field_specs['x']` should be `tf.TensorSpec([2, 3, 10], tf.int32)`.

    Raises:
      TypeError: If `shape` has unknown rank; if `field_specs` is not a
        dictionary; if any key is not a string; or if any value is not a
        `StructuredTensorSpec`, `TensorSpec`, or `RaggedTensorSpec`.
    """
    shape = tensor_shape.as_shape(shape)

    # Perform a few sanity checks on the inputs.
    if shape.rank is None:
      raise TypeError("StructuredTensor's shape must have known rank.")
    if not isinstance(field_specs, dict):
      raise TypeError('field_specs must be a dictionary.')
    for key, value in field_specs.items():
      if not isinstance(key, str):
        raise TypeError('field_specs must be a dictionary with string keys.')
      if not isinstance(value, (StructuredTensorSpec, tensor_spec.TensorSpec,
                                ragged_tensor.RaggedTensorSpec)):
        raise TypeError('field_specs must be a dictionary with '
                        'TypeSpec values.')

    self._shape = shape
    # Store a copy, so later changes to the caller's dict don't affect the
    # spec.
    self._field_specs = dict(field_specs)

  @property
  def value_type(self):
    """The Python type of the values described by this spec."""
    return StructuredTensor

  def _to_components(self, value):
    """Encodes `value` as the tuple `(fields, nrows, row_partitions)`.

    `nrows` is encoded as the empty tuple `()` when `value.nrows()` is None.
    """
    nrows = () if value.nrows() is None else value.nrows()
    return (value._fields, nrows, value.row_partitions)

  def _from_components(self, components):
    """Rebuilds a StructuredTensor from `components`.

    Besides the current `(fields, nrows, row_partitions)` encoding, two
    deprecated encodings are accepted (each logs a warning): a bare dictionary
    of fields, and a `(nrows, row_partitions)` pair without any fields.

    Args:
      components: The encoded value, in one of the forms described above.

    Returns:
      A `StructuredTensor`.
    """
    if isinstance(components, dict):
      # Deprecated encoding: just the dictionary of fields.
      logging.warning('Loading deprecated encoding for StructuredTensorSpec.')
      return StructuredTensor.from_fields(components, self._shape,
                                          validate=False)
    elif not isinstance(components[0], dict):
      # Deprecated encoding: `(nrows, row_partitions)`, with no fields.
      logging.warning('Loading deprecated encoding for StructuredTensorSpec.')
      fields = {}
      nrows, row_partitions = components
      if isinstance(nrows, tuple) and not nrows:
        nrows = None  # empty rank-0 structured tensor
      return StructuredTensor.from_fields(fields, self._shape, nrows=nrows,
                                          row_partitions=row_partitions,
                                          validate=False)

    # Current encoding, as produced by `_to_components`.
    (fields, nrows, row_partitions) = components
    if isinstance(nrows, tuple) and not nrows:
      nrows = None  # empty rank-0 structured tensor
    return StructuredTensor(fields, self._shape, nrows, row_partitions,
                            internal=_structured_tensor_factory_key)

  @property
  def _component_specs(self):
    """Specs for the `(fields, nrows, row_partitions)` components.

    The `nrows` spec is `()` for rank-0 shapes and a scalar int64 `TensorSpec`
    otherwise; there is one `RowPartitionSpec` for each dimension after the
    first.
    """
    if self._shape.rank == 0:
      nrows_spec = ()
    else:
      nrows_spec = tensor_spec.TensorSpec([], dtypes.int64)

    # For rank 0 the multiplier is negative, which yields an empty tuple.
    row_partition_specs = ((row_partition_lib.RowPartitionSpec(),)
                           * (self._shape.rank - 1))
    return (self._field_specs, nrows_spec, row_partition_specs)

  @classmethod
  def from_value(cls, value):
    """Returns the spec for `value`, using its shape and each field's spec."""
    field_specs = dict((k, type_spec.type_spec_from_value(v))
                       for (k, v) in value._fields.items())
    return cls(value.shape, field_specs)

  def _serialize(self):
    """Returns the `(shape, field_specs)` pair that defines this spec."""
    return (self._shape, self._field_specs)

  def _batch(self, batch_size):
    """Returns a spec with a new outer `batch_size` dimension.

    The dimension is prepended to the shape, and every field spec is batched.
    """
    # pylint: disable=protected-access
    return StructuredTensorSpec(
        tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
        dict((k, v._batch(batch_size)) for (k, v) in self._field_specs.items()))

  def _unbatch(self):
    """Returns a spec with the outermost dimension removed.

    The first dimension is dropped from the shape, and every field spec is
    unbatched.
    """
    # pylint: disable=protected-access
    return StructuredTensorSpec(
        self._shape[1:],
        dict((k, v._unbatch()) for (k, v) in self._field_specs.items()))

  @property
  def _flat_tensor_specs(self):
    """The flat tensor specs of all fields, in sorted field-name order."""
    # pylint: disable=protected-access
    result = []
    for _, field_spec in sorted(self._field_specs.items(), key=lambda t: t[0]):
      result.extend(field_spec._flat_tensor_specs)
    return result

  def _to_tensor_list(self, value):
    """Returns the flat list of tensors that encodes `value`."""
    return self._to_tensor_list_internal(value, batched=False)

  def _to_batched_tensor_list(self, value):
    """Returns the flat list of batched tensors that encodes `value`."""
    return self._to_tensor_list_internal(value, batched=True)

  def _from_compatible_tensor_list(self, tensor_list):
    """Rebuilds a StructuredTensor from a flat list of tensors.

    Fields are visited in sorted field-name order, and each one consumes as
    many tensors from `tensor_list` as its spec has flat tensor specs.

    Args:
      tensor_list: A flat list of tensors holding each field's tensors in
        turn.

    Returns:
      A `StructuredTensor`.
    """
    # pylint: disable=protected-access
    fields = {}
    pos = 0
    for field_name, field_spec in sorted(
        self._field_specs.items(), key=lambda t: t[0]):
      num_tensors_for_field = len(field_spec._flat_tensor_specs)
      field_tensors = tensor_list[pos:pos + num_tensors_for_field]
      fields[field_name] = field_spec._from_compatible_tensor_list(
          field_tensors)
      pos += num_tensors_for_field
    return StructuredTensor.from_fields(fields, self._shape)

  def _to_tensor_list_internal(self, value, batched):
    """Returns a flat list holding each field's (batched) tensor_list.

    Fields are visited in sorted field-name order, and the tensor list
    produced by each field's spec is appended to the result.

    Args:
      value: A StructuredTensor (conforming to `self`).
      batched: A boolean. if True, produce `batched_tensor_list` for each field
        otherwise produce `tensor_list`.

    Returns:
      A list.
    """
    result = []
    for field_name, field_spec in sorted(
        self._field_specs.items(), key=lambda t: t[0]):
      # pylint: disable=protected-access
      field_value = value._fields[field_name]
      if batched:
        result.extend(field_spec._to_batched_tensor_list(field_value))
      else:
        result.extend(field_spec._to_tensor_list(field_value))

    return result
1245
1246
# Regular expression used to determine whether a string is a valid field name:
# an ASCII letter followed by any number of ASCII letters, digits, or
# underscores.
# Note: we plan to relax (or possibly eliminate) this in the future; you
# should not rely on the fact that some field names are currently disallowed.
_FIELD_NAME_RE = re.compile('^[a-zA-Z][a-zA-Z0-9_]*$')
1251
1252#=============================================================================
# Helper functions
1254#=============================================================================
1255# TODO(edloper): Move some of these helpers to row_partition.py?
1256
1257
def _convert_to_structured_field_value(value):
  """Converts `value` to a Tensor, RaggedTensor, or StructuredTensor.

  Args:
    value: The value to convert.  Tensors, RaggedTensors, and
      StructuredTensors are returned unchanged.

  Returns:
    A `Tensor`, `RaggedTensor`, or `StructuredTensor`.

  Raises:
    TypeError: If `value` can not be converted to a tensor.
  """
  if isinstance(value,
                (ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)):
    return value
  elif ragged_tensor.is_ragged(value):
    return ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
  else:
    try:
      return ops.convert_to_tensor(value)
    except (ValueError, TypeError) as exc:
      # `value` is wrapped in a 1-tuple so that a tuple-valued `value` is
      # formatted as a single object instead of being unpacked as multiple
      # `%` arguments (which would fail with an unrelated formatting error).
      raise TypeError(
          'Unexpected type for value in `fields`: %r' % (value,)) from exc
1270
1271
def _find_shape_dtype(fields, nrows, row_partitions):
  """Returns the single dtype shared by fields, nrows, & row_partitions.

  Args:
    fields: A dictionary mapping field names to field values.
    nrows: The number of rows; only consulted if it is a `Tensor`.
    row_partitions: An iterable of row partitions, or `None`.

  Returns:
    The dtype used by the ragged fields' row splits, the non-scalar
    StructuredTensor fields' `nrows`, `nrows`, and `row_partitions`; or
    `int64` if none of them supplies a dtype.

  Raises:
    ValueError: If more than one distinct dtype is found.
  """
  candidates = set()
  for field in fields.values():
    if isinstance(field, ragged_tensor.RaggedTensor):
      candidates.add(field.row_splits.dtype)
    elif isinstance(field, StructuredTensor) and field.rank > 0:
      candidates.add(field.nrows().dtype)
  if isinstance(nrows, ops.Tensor):
    candidates.add(nrows.dtype)
  if row_partitions is not None:
    candidates.update(partition.dtype for partition in row_partitions)

  if len(candidates) > 1:
    raise ValueError('field values have incompatible row_partition dtypes.')
  return candidates.pop() if candidates else dtypes.int64
1291
1292
def _merge_nrows(nrows, static_nrows, value, dtype, validate):
  """Merges `nrows` with `nrows(value)`.

  Checks that `value` has the expected number of rows (`nrows`).  When both
  row counts are statically known they are compared immediately; otherwise,
  if `validate` is true, a validation op is added that checks that the
  `nrows` values match.

  Args:
    nrows: scalar integer Tensor, or `None` if not known yet.
    static_nrows: tf.Dimension: static value of nrows, if known.
    value: Tensor or RaggedTensor or StructuredTensor
    dtype: dtype for `nrows`.
    validate: bool -- whether to add validation ops.

  Returns:
    A tuple `(nrows, static_nrows)`.

  Raises:
    ValueError: If the statically known row counts are incompatible.
  """
  value_static_nrows = tensor_shape.dimension_at_index(value.shape, 0)
  value_nrows = (
      array_ops.shape(value, out_type=dtype)[0]
      if isinstance(value, ops.Tensor) else value.nrows())

  if nrows is None:
    merged_nrows = value_nrows
  elif (value_static_nrows.value is not None and
        static_nrows.value is not None):
    # Both counts are known statically, so they can be compared right away.
    if not value_static_nrows.is_compatible_with(static_nrows):
      raise ValueError('fields have incompatible nrows')
    merged_nrows = value_nrows  # No need to add an assertion op.
  elif validate:
    nrows_check = check_ops.assert_equal(
        nrows, value_nrows, message='fields have incompatible nrows')
    merged_nrows = control_flow_ops.with_dependencies([nrows_check], nrows)
  else:
    merged_nrows = nrows

  return merged_nrows, static_nrows.merge_with(value_static_nrows)
1328
1329
def _merge_row_partitions(row_partitions, value, rank, dtype, validate):
  """Merges `row_partitions` with `row_partitions(value)`.

  Args:
    row_partitions: The row partitions collected so far, or `None`.
    value: Tensor or RaggedTensor or StructuredTensor.
    rank: The rank of the enclosing StructuredTensor; `rank - 1` row
      partitions are taken from `value`.
    dtype: dtype used when row partitions are derived from a tensor's shape.
    validate: bool, forwarded to `merge_precomputed_encodings`.

  Returns:
    A tuple of row partitions: those of `value` if `row_partitions` is `None`,
    otherwise the pairwise merge of `row_partitions` with those of `value`.
  """
  if isinstance(value, ops.Tensor):
    value_partitions = _row_partitions_for_tensor(value, rank, dtype)
  elif isinstance(value, ragged_tensor.RaggedTensor):
    value_partitions = _row_partitions_for_ragged_tensor(value, rank, dtype)
  else:
    assert isinstance(value, StructuredTensor), type(value)
    value_partitions = value.row_partitions[:rank - 1]

  assert len(value_partitions) == rank - 1
  if row_partitions is None:
    return tuple(value_partitions)

  merged = []
  for (existing, incoming) in zip(row_partitions, value_partitions):
    merged.append(existing.merge_precomputed_encodings(incoming, validate))
  return tuple(merged)
1350
1351
def _row_partitions_for_tensor(value, rank, dtype):
  """Returns the row partitions for the outer `rank` dims of a tf.Tensor.

  The partitions are derived from the tensor's dynamic shape, which is
  computed with dtype `dtype`.
  """
  return _row_partitions_for_uniform_shape(
      array_ops.shape(value, out_type=dtype), rank)
1356
1357
def _row_partitions_for_ragged_tensor(value, rank, dtype):
  """Returns `rank - 1` row partitions for a tf.RaggedTensor.

  The ragged tensor's own nested row partitions are used first; if there are
  fewer than `rank - 1` of them, the rest are derived from the shape of
  `value.flat_values`.
  """
  assert rank > 1
  num_wanted = rank - 1
  partitions = value._nested_row_partitions[:num_wanted]  # pylint: disable=protected-access
  if len(partitions) < num_wanted:
    partitions += _row_partitions_for_tensor(
        value.flat_values, rank - len(partitions), dtype)
  assert len(partitions) == num_wanted
  return partitions
1367
1368
def _row_partitions_for_uniform_shape(shape, rank):
  """Returns row partitions for the given shape Tensor.

  Args:
    shape: A vector describing a uniform shape.
    rank: The number of dimensions to generate row partitions for

  Returns:
    A tuple of (rank-1) `RowPartition`s with uniform row length.
  """
  # shape_cumprod[i] is the number of slices spanned by dimensions 0..i.
  shape_cumprod = math_ops.cumprod(shape[:rank])
  partitions = []
  for axis in range(rank - 1):
    partitions.append(
        RowPartition.from_uniform_row_length(
            uniform_row_length=shape[axis + 1],
            nvals=shape_cumprod[axis + 1],
            nrows=shape_cumprod[axis]))
  return tuple(partitions)
1387
1388
1389def _pyval_field_major_to_node_major(keys, values, depth):
1390  """Regroup each field (k, v) from dict-of-list to list-of-dict.
1391
1392  Given a "field-major" encoding of the StructuredTensor (which maps each key to
1393  a single nested list containing the values for all structs), return a
1394  corresponding "node-major" encoding, consisting of a nested list of dicts.
1395
1396  Args:
1397    keys: The field names (list of string).  Must not be empty.
1398    values: The field values (list of python values).  Must have the same length
1399      as `keys`.
1400    depth: The list depth at which dictionaries should be created.
1401
1402  Returns:
1403    A nested list of dict, with depth `depth`.
1404  """
1405  assert keys
1406  if depth == 0:
1407    return dict(zip(keys, values))
1408  nvals = len(values[0])
1409  assert all(nvals == len(values[i]) for i in range(1, len(values)))
1410  return [
1411      _pyval_field_major_to_node_major(keys, value_slice, depth - 1)
1412      for value_slice in zip(*values)
1413  ]
1414
1415
1416def _empty_dict_pylist_from_row_partitions(row_partitions, nrows):
1417  """Returns a python list of empty dicts from the given row partitions.
1418
1419  Args:
1420    row_partitions: The row-partitions describing the ragged shape of the
1421      result.
1422    nrows: The number of rows in the outermost row-partition.  (Or if
1423      `len(row_partitions)==0`, then the number of empty dicts to return.)
1424
1425  Returns:
1426    A nested python list whose leaves (if any) are empty python dicts.
1427  """
1428  if not row_partitions:
1429    return [{} for _ in range(nrows)]
1430  else:
1431    values = _empty_dict_pylist_from_row_partitions(
1432        row_partitions[1:], row_partitions[0].row_splits()[-1])
1433    splits = row_partitions[0].row_splits()
1434    return [values[splits[i]:splits[i + 1]] for i in range(len(splits) - 1)]
1435
1436
def _pyval_find_struct_keys_and_depth(pyval, keys):
  """Collects dictionary keys from `pyval` and reports the dictionary depth.

  Args:
    pyval: A nested structure of lists, tuples, and dictionaries.
    keys: (output parameter) A set, which is updated in place with every key
      found in the nested dictionaries.

  Returns:
    The number of list/tuple levels wrapping the dictionaries in `pyval`
    (`0` if `pyval` is itself a dictionary), or `None` if `pyval` does not
    contain any dictionaries.

  Raises:
    ValueError: If dictionaries are found at inconsistent depths.
  """
  if isinstance(pyval, dict):
    keys.update(pyval.keys())
    return 0

  if not isinstance(pyval, (list, tuple)):
    return None

  found_depth = None
  for item in pyval:
    item_depth = _pyval_find_struct_keys_and_depth(item, keys)
    if item_depth is None:
      # This item holds no dictionaries, so it does not constrain the depth.
      continue
    candidate = item_depth + 1
    if found_depth is None:
      found_depth = candidate
    elif found_depth != candidate:
      raise ValueError('Inconsistent depth of dictionaries')
  return found_depth
1466
1467
def _pyval_update_fields(pyval, fields, depth):
  """Appends the field values found in `pyval` to the lists in `fields`.

  Args:
    pyval: A python `dict`, or nested list/tuple of `dict`, whose value(s)
      should be appended to `fields`.
    fields: A dictionary mapping string keys to (nested) lists of field
      values.  It is updated in place.
    depth: The depth at which `pyval` should be appended to the field values;
      `1` appends directly to each top-level list in `fields`.

  Raises:
    ValueError: If `pyval` is not a dict, list, or tuple.
  """
  is_struct = isinstance(pyval, dict)
  if not (is_struct or isinstance(pyval, (list, tuple))):
    raise ValueError('Expected dict or nested list/tuple of dict')

  for name, column in fields.items():
    # Walk down to the most recently appended list at the requested depth.
    remaining = depth - 1
    while remaining > 0:
      column = column[-1]
      remaining -= 1
    # A dict contributes its value; a list/tuple opens a new (empty) sub-list
    # that the recursive calls below will fill.
    column.append(pyval[name] if is_struct else [])

  if not is_struct:
    for item in pyval:
      _pyval_update_fields(item, fields, depth + 1)
1489
1490
def _pyval_empty_list_depth(pyval):
  """Finds the max nesting depth of a structure made only of empty lists.

  Args:
    pyval: A nested python list.

  Returns:
    The maximum depth of empty lists in `pyval` (`1` for `[]`), or `None` if
    `pyval` contains anything other than nested empty lists.
  """
  if not isinstance(pyval, list):
    return None
  if not pyval:
    return 1

  child_depths = list(map(_pyval_empty_list_depth, pyval))
  # A single non-list leaf anywhere below disqualifies the whole structure.
  if None in child_depths:
    return None
  return 1 + max(child_depths)
1511
1512
def _replace_row_partitions(value, new_partitions):
  """Rebuilds `value` so its outer row partitions are `new_partitions`.

  This is used to ensure that all fields in a `StructuredTensor` use identical
  `RowPartition` objects for the shared dimensions.  In particular,
  `StructuredTensor.from_fields` first merges all of the row partitions from
  any fields, and then replaces the outer row partitions of all fields with
  the merged row partitions (using this function).

  Args:
    value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    new_partitions: A sequence of row-partitions that should be used by
      `value`.  Must be equivalent to `value`'s current row partitions.

  Returns:
    A value that is equivalent to `value`, where outer row partitions have been
    replaced by `new_partitions`.  Plain tensors, and any value when
    `new_partitions` is empty, are returned unchanged.
  """
  if isinstance(value, ops.Tensor) or not new_partitions:
    return value

  if isinstance(value, ragged_tensor.RaggedTensor):
    # Peel off one partition per ragged level and recurse into the values.
    return ragged_tensor.RaggedTensor._from_row_partition(  # pylint: disable=protected-access
        values=_replace_row_partitions(value.values, new_partitions[1:]),
        row_partition=new_partitions[0])

  assert isinstance(value, StructuredTensor)
  # Each field receives the full list of new partitions.
  updated_fields = {
      name: _replace_row_partitions(field, new_partitions)
      for name, field in value._fields.items()  # pylint: disable=protected-access
  }
  # Partitions beyond those being replaced are carried over as they are.
  untouched_partitions = value.row_partitions[len(new_partitions):]
  return StructuredTensor(
      fields=updated_fields,
      shape=value.shape,
      nrows=value.nrows(),
      row_partitions=new_partitions + untouched_partitions,
      internal=_structured_tensor_factory_key)
1550
1551
def _partition_outer_dimension(value, row_partition):
  """Partitions the outer dimension of `value` using `row_partition`.

  Examples:

    >>> partition = RowPartition.from_row_lengths([2, 0, 1])
    >>> _partition_outer_dimension(tf.constant([1, 2, 3]), partition)
    <tf.RaggedTensor [[1, 2], [], [3]]>

    >>> struct_value = StructuredTensor.from_pyval(
    ...     [{'x': 1}, {'x': 2}, {'x': 3}])
    >>> _partition_outer_dimension(struct_value, partition)
    <StructuredTensor(
      fields={
        "x": <tf.RaggedTensor [[1, 2], [], [3]]>},
      shape=(3, None))>

  Args:
    value: Tensor, RaggedTensor, or StructuredTensor
    row_partition: RowPartition

  Returns:
    A value where `result.rank = value.rank + 1`: a reshaped `Tensor` when
    `value` is a `Tensor` and `row_partition` has a uniform row length, a
    `RaggedTensor` for any other `Tensor` or `RaggedTensor` input, and a
    `StructuredTensor` when `value` is a `StructuredTensor`.
  """
  has_uniform_rows = row_partition.uniform_row_length() is not None

  if has_uniform_rows and isinstance(value, ops.Tensor):
    # A uniform partition of a dense tensor is just a reshape to
    # [nrows, uniform_row_length] + value.shape[1:].
    outer_dims = [row_partition.nrows(), row_partition.uniform_row_length()]
    inner_dims = array_ops.shape(value, out_type=row_partition.dtype)[1:]
    return array_ops.reshape(
        value, array_ops.concat([outer_dims, inner_dims], axis=0))

  if isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)):
    return ragged_tensor.RaggedTensor._from_row_partition(  # pylint: disable=protected-access
        value, row_partition)

  assert isinstance(value, StructuredTensor)
  # The two new leading dimensions use whatever is statically known about the
  # partition.
  outer_shape = tensor_shape.TensorShape(
      [row_partition.static_nrows, row_partition.static_uniform_row_length])
  partitioned_fields = {
      name: _partition_outer_dimension(field, row_partition)
      for name, field in value._fields.items()  # pylint: disable=protected-access
  }
  return StructuredTensor(
      fields=partitioned_fields,
      shape=outer_shape.concatenate(value.shape[1:]),
      nrows=row_partition.nrows(),
      row_partitions=(row_partition,) + value.row_partitions,
      internal=_structured_tensor_factory_key)
1601
1602
def _merge_dims(value, outer_axis, inner_axis):
  """Merges `outer_axis...inner_axis` of `value` into a single dimension.

  Args:
    value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    outer_axis: Python int; the first (outermost) axis to merge.
    inner_axis: Python int; the last (innermost) axis to merge, inclusive.
      Must be strictly greater than `outer_axis`.

  Returns:
    `value` with axes `outer_axis` through `inner_axis` collapsed into one
    dimension.
  """
  assert outer_axis < inner_axis
  if isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)):
    return ragged_tensor.merge_dims(value, outer_axis, inner_axis)
  else:
    assert isinstance(value, StructuredTensor)

    # Build the new fields.
    fields = dict((k, _merge_dims(v, outer_axis, inner_axis))
                  for (k, v) in value._fields.items())

    # Build the new shape.  The merged dimension covers axes `outer_axis`
    # through `inner_axis` *inclusive* (the kept tail starts at
    # `inner_axis + 1`), so its static size is the product of all of those
    # dimensions -- or None if any of them is statically unknown.
    value_shape = value.shape
    shape = (
        value_shape[:outer_axis] +
        [value_shape[outer_axis:inner_axis + 1].num_elements()] +
        value_shape[inner_axis + 1:])

    # Build the new row_partitions & nrows
    if outer_axis == 0:
      if inner_axis == value.shape.rank - 1:
        # Every axis was merged: the result is a vector with one row per
        # value of the innermost partition.
        partitions = ()
        nrows = value.row_partitions[-1].nvals()
      else:
        partitions = value.row_partitions[inner_axis:]
        nrows = partitions[0].nrows()
    else:
      # Use tf.gather to merge row_splits from the merged row partitions.
      merged_splits = value.row_partitions[outer_axis - 1].row_splits()
      for dim in range(outer_axis, inner_axis):
        merged_splits = array_ops.gather(value.row_partitions[dim].row_splits(),
                                         merged_splits)

      partitions = (
          value.row_partitions[:outer_axis - 1] +
          (RowPartition.from_row_splits(merged_splits),) +
          value.row_partitions[inner_axis:])
      nrows = partitions[0].nrows()

    return StructuredTensor(
        fields,
        shape,
        nrows,
        partitions,
        internal=_structured_tensor_factory_key)
1649
1650
# Unique private sentinel.  The helpers in this module pass it as the
# `internal` argument when they construct a `StructuredTensor` directly.
# NOTE(review): presumably the constructor checks for this object to reject
# direct construction by outside callers -- confirm against `__init__`.
_structured_tensor_factory_key = object()  # unique private object
1652
1653
def _normalize_field_name_to_tuple(name: 'FieldName') -> Sequence[str]:
  """Normalizes a field name given as a string, list, or tuple to a tuple.

  Args:
    name: A single field name (`str`), or a path of names as a `list` or
      `tuple`.

  Returns:
    A tuple of names: `(name,)` for a string, `tuple(name)` for a list, and
    `name` itself when it is already a tuple.
  """
  if isinstance(name, str):
    normalized = (name,)
  elif isinstance(name, list):
    normalized = tuple(name)
  else:
    assert isinstance(name, tuple)
    normalized = name
  return normalized
1662
1663
def _dicts_to_zeros(pyval):
  """Replaces every dictionary in a (nested) pylist with the integer zero.

  Args:
    pyval: A dictionary, or an iterable (possibly nested) whose leaves are
      dictionaries.

  Returns:
    `0` if `pyval` is a dictionary; otherwise a list with the same nesting as
    `pyval` in which each dictionary has been replaced by `0`.
  """
  if not isinstance(pyval, dict):
    return list(map(_dicts_to_zeros, pyval))
  return 0
1669
1670
def _merge_dims_generic(source, outer, inner):
  """Merges outer_axis...inner_axis into a single dimension.

  If outer == inner, this is a NOOP. If inner < outer, then this fails.
  If inner >= source.shape.rank, then the behavior is undefined.

  Args:
    source: a tensor, ragged tensor, or structured tensor.
    outer: a python int, indicating the first dimension to compress (must be
      nonnegative).
    inner: a python int, indicating the last dimension to compress (must be
      nonnegative).  NOTE(review): this was previously described as "the first
      dimension to keep (of the tail)", but `_merge_dims` in this file treats
      the inner axis as part of the merged range -- confirm.

  Returns:
    source with outer_axis...inner_axis merged into a single dimension.

  """
  # Structured tensors provide their own merge; everything else goes through
  # the ragged implementation.
  if not isinstance(source, StructuredTensor):
    return ragged_tensor.merge_dims(source, outer, inner)
  return source.merge_dims(outer, inner)
1692