1# Lint as python3 2# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Structured Tensors.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import logging 23import re 24from typing import Callable, Dict, List, Sequence, Tuple, Union 25 26import numpy as np 27 28from tensorflow.python.framework import composite_tensor 29from tensorflow.python.framework import constant_op 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import ops 32from tensorflow.python.framework import tensor_shape 33from tensorflow.python.framework import tensor_spec 34from tensorflow.python.framework import type_spec 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import check_ops 37from tensorflow.python.ops import control_flow_ops 38from tensorflow.python.ops import math_ops 39from tensorflow.python.ops.ragged import ragged_factory_ops 40from tensorflow.python.ops.ragged import ragged_tensor 41from tensorflow.python.ops.ragged import row_partition as row_partition_lib 42from tensorflow.python.ops.ragged.row_partition import RowPartition 43from tensorflow.python.util import compat 44from tensorflow.python.util import nest 45 46 47class StructuredTensor(composite_tensor.CompositeTensor): 48 """A multidimensional 
collection of structures with the same schema. 49 50 A **`StructuredTensor`** is a multi-dimensional collection of ***structures*** 51 with the same ***schema***, where: 52 53 * A ***schema*** is a collection of fields, each of which has a name and type. 54 * A ***structure*** maps each field in the schema to a tensor value (which 55 could be a nested StructuredTensor). 56 57 As an important special case, a 1D `StructuredTensor` encodes a 2D table, 58 where columns are heterogeneous `Tensor`s, and rows are the aligned elements 59 in each of those `Tensor`s. 60 61 Internally, StructuredTensors use a "field-major" encoding: for each leaf 62 field, there is a single tensor that stores the value of that field for all 63 structures in the `StructuredTensor`. 64 65 ### Examples 66 67 >>> # A scalar StructuredTensor describing a single person. 68 >>> s1 = StructuredTensor.from_pyval( 69 ... {"age": 82, "nicknames": ["Bob", "Bobby"]}) 70 >>> s1.shape 71 TensorShape([]) 72 >>> s1["age"] 73 <tf.Tensor: shape=(), dtype=int32, numpy=82> 74 75 >>> # A vector StructuredTensor describing three people. 76 >>> s2 = StructuredTensor.from_pyval([ 77 ... {"age": 12, "nicknames": ["Josaphine"]}, 78 ... {"age": 82, "nicknames": ["Bob", "Bobby"]}, 79 ... {"age": 42, "nicknames": ["Elmo"]}]) 80 >>> s2.shape 81 TensorShape([3]) 82 >>> s2[0]["age"] 83 <tf.Tensor: shape=(), dtype=int32, numpy=12> 84 85 86 ### Field Paths 87 88 A *field path* is a tuple of field names, specifying the path to a nested 89 field. 90 """ 91 92 #============================================================================= 93 # Common Types 94 #============================================================================= 95 # pylint: disable=invalid-name 96 # Field names work as key, and they can be a sequence to refer to the 97 # sub-levels (embedded) StructuredTensor's. 98 FieldName = Union[str, Sequence[str]] 99 100 # Each field may contain one of the following types of Tensors. 
101 FieldValue = Union[ops.Tensor, ragged_tensor.RaggedTensor, 'StructuredTensor'] 102 103 # Function that takes a FieldValue as input and returns the transformed 104 # FieldValue. 105 FieldFn = Callable[[FieldValue], FieldValue] 106 107 # pylint: enable=invalid-name 108 109 #============================================================================= 110 # Constructor & Factory Methods 111 #============================================================================= 112 113 def __init__(self, fields, shape, nrows, row_partitions, internal=False): 114 """Private constructor -- use factory methods to create StructuredTensors. 115 116 This constructor builds a `StructuredTensor` from the given attributes, 117 performing minimal validation. 118 119 Args: 120 fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or 121 `StructuredTensor`. (This dict is not copied, so the caller must ensure 122 that it does not get mutated via leaked references.) 123 shape: `tf.TensorShape` with statically known rank. 124 nrows: scalar integer `tf.Tensor`, or `None` if `shape.rank==0`. 125 row_partitions: tuple of `RowPartition`s, with length `shape.rank-1`. 126 internal: Private key value, required to ensure that this private 127 constructor is *only* called from the factory methods. 
128 """ 129 if internal is not _structured_tensor_factory_key: 130 raise ValueError('StructuredTensor constructor is private; please use ' 131 'one of the factory methods instead (e.g., ' 132 'StructuredTensor.from_fields())') 133 assert isinstance(fields, dict), fields 134 assert isinstance(shape, tensor_shape.TensorShape), shape 135 assert nrows is None or isinstance(nrows, ops.Tensor), nrows 136 assert isinstance(row_partitions, tuple), row_partitions 137 self._fields = fields 138 self._shape = shape 139 self._nrows = nrows 140 self._row_partitions = row_partitions 141 142 @classmethod 143 def from_fields(cls, 144 fields, 145 shape=(), 146 nrows=None, 147 row_partitions=None, 148 validate=False): 149 """Creates a `StructuredTensor` from a dictionary of fields. 150 151 Args: 152 fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or 153 `StructuredTensor`, providing the values for individual fields in each 154 structure. If `shape.rank > 0`, then every tensor in `fields` must have 155 the same shape in the first `shape.rank` dimensions; and that shape must 156 be compatible with `shape`; and `result[i1...iN][key] = 157 fields[key][i1...iN]` (where `N==shape.rank`). 158 shape: A `TensorShape`: static information about the shape of the 159 `StructuredTensor`. Must have a known `rank`. Defaults to scalar shape 160 (i.e. `rank=0`). 161 nrows: scalar integer tensor containing the number of rows in this 162 `StructuredTensor`. Should only be specified if `shape.rank > 0`. 163 Default value is inferred from the `fields` values. If `fields` is 164 empty, then this must be specified. 165 row_partitions: A list of `RowPartition`s describing the (possibly ragged) 166 shape of this `StructuredTensor`. Should only be specified if 167 `shape.rank > 1`. Default value is inferred from the `fields` values. 168 If `fields` is empty, then this must be specified. 
169 validate: If true, then add runtime validation ops that check that the 170 field values all have compatible shapes in the outer `shape.rank` 171 dimensions. 172 173 Returns: 174 A `StructuredTensor`. 175 176 Examples: 177 178 >>> StructuredTensor.from_fields({'x': 1, 'y': [1, 2, 3]}) 179 <StructuredTensor( 180 fields={ 181 "x": tf.Tensor(1, shape=(), dtype=int32), 182 "y": tf.Tensor([1 2 3], shape=(3,), dtype=int32)}, 183 shape=())> 184 185 >>> StructuredTensor.from_fields({'foo': [1, 2], 'bar': [3, 4]}, 186 ... shape=[2]) 187 <StructuredTensor( 188 fields={ 189 "bar": tf.Tensor([3 4], shape=(2,), dtype=int32), 190 "foo": tf.Tensor([1 2], shape=(2,), dtype=int32)}, 191 shape=(2,))> 192 """ 193 shape = tensor_shape.as_shape(shape) 194 rank = shape.rank 195 if rank is None: 196 raise ValueError("StructuredTensor's shape must have known rank.") 197 if not isinstance(fields, dict): 198 raise TypeError('fields must be a dictionary, got %s' % 199 type(fields).__name__) 200 if rank < 2 and row_partitions: 201 raise ValueError('row_partitions must be None or [] if shape.rank<2') 202 if rank == 0 and nrows is not None: 203 raise ValueError('nrows must be None if shape.rank==0') 204 if row_partitions is not None: 205 row_partitions = tuple(row_partitions) 206 if len(row_partitions) != max(0, rank - 1): 207 raise ValueError('len(row_partitions) must be shape.rank-1') 208 elif rank < 2: 209 row_partitions = () 210 211 fields = dict(fields) # Make a private copy. 212 with ops.name_scope(None, 'StructuredTensor', fields.values()): 213 214 # Validate keys and convert field values to tensors. 215 for key, value in fields.items(): 216 if not isinstance(key, str): 217 raise TypeError('Unexpected type for key in `fields`: %r' % key) 218 if not _FIELD_NAME_RE.match(key): 219 raise ValueError('Field name %r is not currently allowed.' % key) 220 fields[key] = _convert_to_structured_field_value(value) 221 222 # Determine dtype for row_partitions and nrows. 
223 shape_dtype = _find_shape_dtype(fields, nrows, row_partitions) 224 if nrows is not None: 225 nrows = ops.convert_to_tensor(nrows, shape_dtype) 226 227 # Get the static TensorShape for this StructuredTensor. 228 if rank > 0: 229 for key, value in fields.items(): 230 if not shape.is_compatible_with(value.shape[:rank]): 231 raise ValueError('Field {} has shape {}, which is incompatible ' 232 'with the shape that was specified or inferred ' 233 'from other fields: {}'.format( 234 key, value.shape[:rank], shape)) 235 shape = shape.merge_with(value.shape[:rank]) 236 237 if rank == 1: 238 # Find a consistent value for `nrows`. 239 static_nrows = tensor_shape.dimension_at_index(shape, 0) 240 for value in fields.values(): 241 nrows, static_nrows = _merge_nrows(nrows, static_nrows, value, 242 shape_dtype, validate) 243 if nrows is None: 244 if static_nrows.value is None: 245 raise ValueError('nrows must be specified if rank==1 ' 246 'and `fields` is empty.') 247 else: 248 nrows = constant_op.constant(static_nrows.value, shape_dtype) 249 250 if rank > 1: 251 # Find a consistent list of RowPartitions. 252 for value in fields.values(): 253 row_partitions = _merge_row_partitions(row_partitions, value, rank, 254 shape_dtype, validate) 255 if row_partitions is None: 256 if not shape.is_fully_defined(): 257 raise ValueError('row_partitions must be specified if rank>1 ' 258 'and `fields` is empty.') 259 else: 260 row_partitions = _row_partitions_for_uniform_shape( 261 np.array(shape.as_list(), dtype=shape_dtype.as_numpy_dtype), 262 shape.rank) 263 assert len(row_partitions) == rank - 1 264 nrows = row_partitions[0].nrows() 265 # Update all field values to use the shared RowPartition objects. 
266 fields = dict([(k, _replace_row_partitions(v, row_partitions)) 267 for (k, v) in fields.items()]) 268 269 return cls( 270 fields, 271 shape, 272 nrows, 273 row_partitions, 274 internal=_structured_tensor_factory_key) 275 276 def with_updates( 277 self, 278 updates: Dict[FieldName, Union[FieldValue, FieldFn, None]], 279 validate: bool = False 280 ) -> 'StructuredTensor': 281 """Creates a new `StructuredTensor` with the updated fields. 282 283 If this `StructuredTensor` is a scalar, and `k` is the `FieldName` being 284 updated and `v` the new value, then: 285 286 ``` 287 result[k] = v # If (k, v) is in updates and v is a FieldValue 288 result[k] = f(self[k]) # If (k, f) is in updates and f is a FieldFn 289 result[k] = self[k] # If k is in self.field_names but not in updates 290 ``` 291 292 If this `StructuredTensor` has rank `N` and shape `[D1...DN]`, then each 293 FieldValue `v` in `updates` must have shape `[D1...DN, ...]`, that is, 294 prefixed with the same shape as the `StructuredTensor`. Then the resulting 295 `StructuredTensor` will have: 296 297 ``` 298 result[i1...iN][k] = v[i1...iN] # (k, v) in updates 299 result[i1...iN][k] = f(self.field_value(k))[i1...iN] # (k, f) in updates 300 result[i1...iN][k] = self[i1...iN][k] # k not in updates 301 ``` 302 303 Note that `result.shape` is always equal to `self.shape` (but the shapes 304 of nested StructuredTensors may be changed if they are updated with new 305 values). 306 307 Args: 308 updates: A dictionary mapping `FieldName` to either a `FieldValue` to be 309 used to update, or a `FieldFn` that will transform the value for the 310 given `FieldName`. `FieldName` can be a string for a direct field, or a 311 sequence of strings to refer to a nested sub-field. `FieldFn` is a 312 function that takes a `FieldValue` as input and should return a 313 `FieldValue`. All other fields are copied over to the new 314 `StructuredTensor`. 
New `FieldName` can be given (to add new fields), 315 but only to existing `StructuredTensor`, it won't automatically create 316 new nested structures -- but one can create a whole `StructureTensor` 317 sub-structure and set that into an existing structure. If the new value 318 is set to `None`, it is removed. 319 validate: If true, then add runtime validation ops that check that the 320 field values all have compatible shapes in the outer `shape.rank` 321 dimensions. 322 323 Returns: 324 A `StructuredTensor`. 325 326 Raises: 327 `ValueError`: If the any of the `FieldName` keys points to non-existent 328 sub-structures, if parent and child nodes are updated, if shapes 329 change, if a delete update is given for a non-existant field, or if a 330 `FieldFn` transforming function is given for a `FieldName` that doesn't 331 yet exist. 332 333 Examples: 334 335 >>> shoes_us = StructuredTensor.from_pyval([ 336 ... {"age": 12, "nicknames": ["Josaphine"], 337 ... "shoes": {"sizes": [8.0, 7.5, 7.5]}}, 338 ... {"age": 82, "nicknames": ["Bob", "Bobby"], 339 ... "shoes": {"sizes": [11.0, 11.5, 12.0]}}, 340 ... {"age": 42, "nicknames": ["Elmo"], 341 ... "shoes": {"sizes": [9.0, 9.5, 10.0]}}]) 342 >>> def us_to_europe(t): 343 ... return tf.round(t * 2.54 + 17.0) # Rough approximation. 344 >>> shoe_sizes_key = ("shoes", "sizes") 345 >>> shoes_eu = shoes_us.with_updates({shoe_sizes_key: us_to_europe}) 346 >>> shoes_eu.field_value(shoe_sizes_key) 347 <tf.RaggedTensor [[37.0, 36.0, 36.0], [45.0, 46.0, 47.0], 348 [40.0, 41.0, 42.0]]> 349 """ 350 updates_items = [(_normalize_field_name_to_tuple(name), value) 351 for name, value in updates.items()] 352 353 # Sort by keys and check for updates of both parent and child nodes. 354 updates_items = sorted(updates_items) 355 for i in range(1, len(updates_items)): 356 # Parent of a node would precede node in the sorted order. 357 name = updates_items[i][0] # item[0] is the name, item[1] is the value. 
358 prev_name = updates_items[i - 1][0] 359 if name[:len(prev_name)] == prev_name: 360 raise ValueError( 361 '`StructuredTensor.with_updates` does not allow both parent and ' 362 'child nodes to be updated: parent={}, child={}. If needed you can ' 363 'update child nodes in the parent update value.'.format( 364 prev_name, name)) 365 return self._with_updates_impl((), updates_items, validate) 366 367 def _with_updates_impl( 368 self, 369 error_prefix: Tuple[str], 370 updates: List[Tuple[FieldName, Union[FieldValue, FieldFn]]], 371 validate: bool) -> 'StructuredTensor': 372 """Recursive part of `with_updates` implementation.""" 373 # Get current fields. 374 new_fields = dict(self._fields) 375 376 # Convert field name to string with full path for error messages. 377 def name_fullpath(name: Sequence[str]) -> str: 378 return str(error_prefix + (name,)) 379 380 # Apply value if a function or the value itself. 381 def apply_value(name: str, value: Union['FieldValue', 382 'FieldFn']) -> 'FieldValue': 383 if callable(value): 384 # `value` is actually a transforming function. 385 if name not in new_fields: 386 raise ValueError( 387 '`StructuredTensor.with_updates` cannot update the field {} ' 388 'because a transforming function was given, but that field ' 389 'does not already exist.'.format(name_fullpath(name))) 390 value = value(new_fields[name]) 391 return value 392 393 # Merge updates. 
394 for name, value in updates: 395 if not name or not name[0]: 396 raise ValueError( 397 '`StructuredTensor.with_updates` does not allow empty names ' 398 '{}.'.format(name_fullpath(name))) 399 400 if len(name) == 1: 401 name = name[0] 402 if value is None: 403 if name not in new_fields: 404 raise ValueError( 405 '`StructuredTensor.with_updates` cannot delete field ' 406 '{} because it is not present.'.format(name_fullpath(name))) 407 new_fields.pop(name) 408 else: 409 new_fields[name] = apply_value(name, value) 410 else: 411 # Recursive 412 prefix = name[0] 413 suffix = name[1:] 414 if prefix not in new_fields: 415 raise ValueError( 416 '`StructuredTensor.with_updates` cannot create new sub-field ' 417 '{} if parent field {} is not set.'.format( 418 error_prefix + tuple(name), name_fullpath(prefix))) 419 current_value = new_fields[prefix] 420 if not isinstance(current_value, StructuredTensor): 421 raise ValueError( 422 '`StructuredTensor.with_updates` cannot create new sub-field ' 423 '{} if parent structure {} is not a `StructuredTensor` that ' 424 'can contain sub-structures -- it is a `{}`.'.format( 425 error_prefix + tuple(name), name_fullpath(prefix), 426 type(current_value))) 427 one_update = [(suffix, value)] 428 429 # Accessing protected member in recursion. 430 # FutureWork: optimize by aggregating the recursions, instead of 431 # calling one at a time. 432 # pylint: disable=protected-access 433 value = current_value._with_updates_impl(error_prefix + (prefix,), 434 one_update, validate) 435 # pylint: enable=protected-access 436 new_fields[prefix] = value 437 438 # TODO(edloper): When validate=True, only validate the modified fields. 
439 try: 440 return StructuredTensor.from_fields( 441 new_fields, 442 shape=self.shape, 443 row_partitions=self._row_partitions, 444 nrows=self._nrows, 445 validate=validate) 446 447 except ValueError as e: 448 msg = '`StructuredTensor.with_updates` failed' 449 if error_prefix: 450 msg = '{} for field {}'.format(msg, error_prefix) 451 raise ValueError('{}: {}'.format(msg, e)) 452 453 def _promote_helper(self, source_path, new_parent_path): 454 """Creates a promoted field without adding it to the structure. 455 456 Args: 457 source_path: the source path in the structured tensor. 458 new_parent_path: the new parent path. Must be a prefix of source_path. 459 460 Returns: 461 a composite tensor of source_path promoted. 462 Raises: 463 ValueError: if the shape of the field is unknown and the right strategy 464 cannot be determined. 465 """ 466 current_field = self.field_value(source_path) 467 new_parent_rank = self.field_value(new_parent_path).rank 468 parent_rank = self.field_value(source_path[:-1]).rank 469 if new_parent_rank == parent_rank: 470 return current_field 471 current_field_rank = current_field.shape.rank 472 if current_field_rank is None: 473 raise ValueError('Cannot determine if dimensions should be merged.') 474 inner_dim = min(parent_rank, current_field_rank - 1) 475 if inner_dim <= new_parent_rank: 476 return current_field 477 return _merge_dims_generic(current_field, new_parent_rank, inner_dim) 478 479 def promote(self, source_path, new_name): 480 """Promotes a field, merging dimensions between grandparent and parent. 481 482 >>> d = [ 483 ... {'docs': [{'tokens':[1, 2]}, {'tokens':[3]}]}, 484 ... 
{'docs': [{'tokens':[7]}]}] 485 >>> st = StructuredTensor.from_pyval(d) 486 >>> st2 =st.promote(('docs','tokens'), 'docs_tokens') 487 >>> st2[0]['docs_tokens'] 488 <tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)> 489 >>> st2[1]['docs_tokens'] 490 <tf.Tensor: shape=(1,), dtype=int32, numpy=array([7], dtype=int32)> 491 492 Args: 493 source_path: the path of the field or substructure to promote; must have 494 length at least 2. 495 new_name: the name of the new field (must be a string). 496 497 Returns: 498 a modified structured tensor with the new field as a child of the 499 grandparent of the source_path. 500 501 Raises: 502 ValueError: if source_path is not a list or a tuple or has a length 503 less than two, or new_name is not a string, or the rank 504 of source_path is unknown and it is needed. 505 """ 506 if not isinstance(new_name, str): 507 raise ValueError('new_name is not a string') 508 if not isinstance(source_path, (list, tuple)): 509 raise ValueError('source_path must be a list or tuple') 510 511 if len(source_path) < 2: 512 raise ValueError('source_path must have length at least two') 513 514 grandparent_path = source_path[:-2] 515 new_field = self._promote_helper(source_path, grandparent_path) 516 new_path = grandparent_path + (new_name,) 517 return self.with_updates({new_path: new_field}) 518 519 #============================================================================= 520 # Properties 521 #============================================================================= 522 523 @property 524 def rank(self): 525 """The rank of this StructuredTensor. Guaranteed not to be `None`.""" 526 return self._shape.rank 527 528 @property 529 def shape(self): 530 """The static shape of this StructuredTensor. 531 532 The returned `TensorShape` is guaranteed to have a known rank, but the 533 individual dimension sizes may be unknown. 
  # TODO(edloper): Make this a func instead of a property?  Or make nrows
  # a property instead of a func?  Seems like these should be consistent.
  @property
  def row_partitions(self):
    """A tuple of `RowPartition`s defining the shape of this `StructuredTensor`.

    When `self.rank <= 1`, this tuple will be empty.

    When `self.rank > 1`, these `RowPartitions` define the shape of the
    `StructuredTensor` by describing how a flat (1D) list of structures can be
    repeatedly partitioned to form a higher-dimensional object.  In particular,
    the flat list is first partitioned into sublists using `row_partitions[-1]`,
    and then those sublists are further partitioned using `row_partitions[-2]`,
    etc.  The following examples show the row partitions used to describe
    several different `StructuredTensor`, each of which contains 8 copies of
    the same structure (`x`):

    >>> x = {'a': 1, 'b': ['foo', 'bar', 'baz']}       # shape = [] (scalar)

    >>> s1 = [[x, x, x, x], [x, x, x, x]]              # shape = [2, 4]
    >>> StructuredTensor.from_pyval(s1).row_partitions
    (tf.RowPartition(row_splits=tf.Tensor([0 4 8], shape=(3,),
                                          dtype=int64)),)

    >>> s2 = [[x, x], [x, x], [x, x], [x, x]]          # shape = [4, 2]
    >>> StructuredTensor.from_pyval(s2).row_partitions
    (tf.RowPartition(row_splits=tf.Tensor([0 2 4 6 8], shape=(5,),
                                          dtype=int64)),)

    >>> s3 = [[x, x, x], [], [x, x, x, x], [x]]        # shape = [4, None]
    >>> StructuredTensor.from_pyval(s3).row_partitions
    (tf.RowPartition(row_splits=tf.Tensor([0 3 3 7 8], shape=(5,),
                                          dtype=int64)),)

    >>> s4 = [[[x, x], [x, x]], [[x, x], [x, x]]]      # shape = [2, 2, 2]
    >>> StructuredTensor.from_pyval(s4).row_partitions
    (tf.RowPartition(row_splits=tf.Tensor([0 2 4], shape=(3,), dtype=int64)),
     tf.RowPartition(row_splits=tf.Tensor([0 2 4 6 8], shape=(5,),
                                          dtype=int64)))


    >>> s5 = [[[x, x], [x]], [[x, x]], [[x, x], [x]]]  # shape = [3, None, None]
    >>> StructuredTensor.from_pyval(s5).row_partitions
    (tf.RowPartition(row_splits=tf.Tensor([0 2 3 5], shape=(4,), dtype=int64)),
     tf.RowPartition(row_splits=tf.Tensor([0 2 3 5 7 8], shape=(6,),
                                          dtype=int64)))

    Note that shapes for nested fields (such as `x['b']` in the above example)
    are not considered part of the shape of a `StructuredTensor`, and are not
    included in `row_partitions`.

    If this `StructuredTensor` has a ragged shape (i.e., if any of the
    `row_partitions` is not uniform in size), then all fields will be encoded
    as either `RaggedTensor`s or `StructuredTensor`s with these `RowPartition`s
    used to define their outermost `self.rank` dimensions.

    Returns:
      A `tuple` of `RowPartition` objects with length `self.rank - 1`
      (or `0` if `self.rank < 2`)

    """
    return self._row_partitions
616 """ 617 return self._nrows 618 619 def _is_eager(self): 620 """True if all fields are composed of eager tensors.""" 621 tensors = nest.flatten(self, expand_composites=True) 622 return all(isinstance(t, ops.EagerTensor) for t in tensors) 623 624 #============================================================================= 625 # Encoding 626 #============================================================================= 627 628 def field_names(self): 629 """Returns the string field names for this `StructuredTensor`.""" 630 return tuple(self._fields.keys()) 631 632 def field_value(self, field_name): 633 """Returns the tensor value for the specified field or path. 634 635 If `field_name` is a `string`, then it names a field directly owned by this 636 `StructuredTensor`. If this `StructuredTensor` has shape `[D1...DN]`, then 637 the returned tensor will have shape `[D1...DN, V1...VM]`, where the slice 638 `result[d1...dN]` contains the field value for the structure at 639 `self[d1...dN]`. 640 641 If `field_name` is a `tuple` of `string`, then it specifies a path to a 642 field owned by nested `StructuredTensor`. In particular, 643 `struct.field_value((f1, f2, ..., fN))` is equivalent to 644 `struct.field_value(f1).field_value(f2)....field_value(fN)` 645 646 Args: 647 field_name: `string` or `tuple` of `string`: The field whose values should 648 be returned. 649 650 Returns: 651 `Tensor`, `StructuredTensor`, or `RaggedTensor`. 652 653 Raises: 654 KeyError: If the given field_name is not found. 
655 """ 656 if isinstance(field_name, (list, tuple)): 657 value = self 658 for f in field_name: 659 if not isinstance(value, StructuredTensor): 660 raise KeyError('Field path {} not found in {}'.format( 661 field_name, self)) 662 value = value.field_value(f) 663 return value 664 return self._fields[field_name] 665 666 #============================================================================= 667 # Operators 668 #============================================================================= 669 670 # TODO(edloper): Add support for ellipsis and/or newaxis? 671 def __getitem__(self, key): 672 """Returns the specified piece of this StructuredTensor. 673 674 * If `struct_tensor` is scalar (i.e., a single structure), then 675 `struct_tensor[f]` returns the value of field `f` (where `f` must be a 676 string). 677 678 * If `struct_tensor` is non-scalar (i.e., a vector or higher-dimensional 679 tensor of structures), `struct_tensor[i]` selects an element or slice of 680 the tensor using standard Python semantics (e.g., negative values index 681 from the end). `i` may have any of the following types: 682 683 * `int` constant 684 * `string` constant 685 * scalar integer `Tensor` 686 * `slice` containing integer constants and/or scalar integer 687 `Tensor`s 688 689 #### Multidimensional indexing 690 691 `StructuredTensor` supports multidimensional indexing. I.e., `key` may be a 692 `tuple` of values, indexing or slicing multiple dimensions at once. For 693 example, if `people` is a vector of structures, each of which has a vector- 694 valued `names` field, then `people[3, 'names', 0]` is equivalent to 695 `people[3]['names'][0]`; and `people[:, 'names', :]` will return a (possibly 696 ragged) matrix of names, with shape `[num_people, num_names_per_person]`. 697 698 Args: 699 key: Indicates which piece of the StructuredTensor to return. 700 701 Returns: 702 A `Tensor`, `StructuredTensor`, or `RaggedTensor`. 
703 """ 704 if isinstance(key, list): 705 key = tuple(key) 706 elif not isinstance(key, tuple): 707 key = (key,) 708 if not key: 709 return self 710 711 if self._shape.rank == 0: 712 return self._scalar_getitem(key) 713 else: 714 return self._tensor_getitem(key) 715 716 def _scalar_getitem(self, key): 717 if (isinstance(key[0], slice) and key[0].start is None and 718 key[0].stop is None and key[0].step is None): 719 fields = dict((field_name, field_value.__getitem__(key[1:])) 720 for (field_name, field_value) in self._fields.items()) 721 return StructuredTensor.from_fields(fields, self._shape) 722 723 elif not isinstance(key[0], compat.bytes_or_text_types): 724 raise ValueError('Key for indexing a StructuredTensor must be a ' 725 "string or a full slice (':')") 726 727 return self._fields[key[0]].__getitem__(key[1:]) 728 729 def _tensor_getitem(self, key): 730 rank = self._shape.rank 731 if len(key) <= rank: 732 new_fields = dict((field_name, field_value.__getitem__(key)) 733 for (field_name, field_value) in self._fields.items()) 734 result_shape = self.shape.as_list() 735 for d, k in enumerate(key): 736 if isinstance(k, slice): 737 if not (k.start is None and k.stop is None and k.step is None): 738 # TODO(edloper): Better static shape analysis here. 739 result_shape[d] = None 740 elif isinstance(k, (int, ops.Tensor)): 741 result_shape[d] = -1 # mark for deletion 742 elif k is None: 743 raise ValueError('Slicing not supported for tf.newaxis') 744 else: 745 # Ellipsis, tf.newaxis: 746 raise ValueError('Slicing not supported for %r' % k) 747 result_shape = [d for d in result_shape if d != -1] 748 return StructuredTensor.from_fields(new_fields, result_shape) 749 750 else: 751 if not isinstance(key[rank], compat.bytes_or_text_types): 752 # TODO(edloper): Also support full slice here? 
753 raise ValueError('Key for indexing a StructuredTensor must be a string') 754 return self._fields[key[rank]].__getitem__(key[:rank] + key[rank + 1:]) 755 756 def __repr__(self): 757 fields = sorted(self._fields.items()) 758 fields = ((k, str(v).replace('\n', '\n ')) for k, v in fields) 759 fields = ('"{}": {}'.format(k, v) for k, v in fields) 760 dict_repr = ',\n '.join(fields) 761 return ('<StructuredTensor(\n' 762 ' fields={\n' 763 ' %s},\n' 764 ' shape=%s)>' % (dict_repr, self._shape)) 765 766 #============================================================================= 767 # Conversion 768 #============================================================================= 769 770 def to_pyval(self): 771 """Returns this StructuredTensor as a nested Python dict or list of dicts. 772 773 Converts this `StructuredTensor` to a nested python value: 774 775 * `StructTensors` with `rank=0` are converted into a dictionary, with an 776 entry for each field. Field names are used as keys and field values are 777 converted to python values. In particular: 778 779 * Scalar Tensor fields are converted to simple values (such as 780 `int` or `float` or `string`) 781 * Non-scalar Tensor fields and RaggedTensor fields are converted to 782 nested lists of simple values. 783 * StructuredTensor fields are converted recursively using `to_pyval`. 784 785 * `StructTensors` with `rank>0` are converted to nested python `list`s, 786 containing one dictionary for each structure (where each structure's 787 dictionary is defined as described above). 788 789 Requires that all fields are Eager tensors. 790 791 >>> StructuredTensor.from_fields( 792 ... {'a': [1, 2, 3]}, [3]).to_pyval() 793 [{'a': 1}, {'a': 2}, {'a': 3}] 794 795 Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`. 796 797 Returns: 798 A nested Python dict or list of dicts. 
799 """ 800 if not self._is_eager(): 801 raise ValueError( 802 'StructuredTensor.to_pyval() is only supported in eager mode.') 803 804 # Convert each field value to a nested list. 805 result = {} 806 for (key, value) in self._fields.items(): 807 if isinstance(value, ops.EagerTensor): 808 value = value.numpy() 809 if isinstance(value, np.ndarray): 810 value = value.tolist() 811 elif isinstance(value, ragged_tensor.RaggedTensor): 812 value = value.to_list() 813 elif isinstance(value, StructuredTensor): 814 value = value.to_pyval() 815 # TODO(edloper): Throw an exception if value is an unexpected type. 816 result[key] = value 817 818 # If rank>0, then re-group each value from dict-of-list to list-of-dict. 819 if len(self._shape) > 0: # pylint: disable=g-explicit-length-test 820 if not result: # special-case for StructuredTensors w/ no fields. 821 return _empty_dict_pylist_from_row_partitions(self._row_partitions, 822 self._nrows) 823 return _pyval_field_major_to_node_major( 824 list(result.keys()), list(result.values()), self._shape.rank) 825 else: 826 return result 827 828 @classmethod 829 def from_pyval(cls, pyval, typespec=None): 830 """Constructs a StructuredTensor from a nested Python structure. 831 832 >>> StructuredTensor.from_pyval( 833 ... {'a': [1, 2, 3], 'b': [[4, 5], [6, 7]]}) 834 <StructuredTensor( 835 fields={ 836 "a": tf.Tensor([1 2 3], shape=(3,), dtype=int32), 837 "b": <tf.RaggedTensor [[4, 5], [6, 7]]>}, 838 shape=())> 839 840 Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`. 841 842 Args: 843 pyval: The nested Python structure that should be used to create the new 844 `StructuredTensor`. 845 typespec: A `StructuredTensorSpec` specifying the expected type for each 846 field. If not specified, then all nested dictionaries are turned into 847 StructuredTensors, and all nested lists are turned into Tensors (if 848 rank<2) or RaggedTensors (if rank>=2). 849 850 Returns: 851 A `StructuredTensor`. 
852 """ 853 return cls._from_pyval(pyval, typespec, ()) 854 855 @classmethod 856 def _from_pyval(cls, pyval, typespec, path_so_far): 857 """Helper function for from_pyval. 858 859 860 Args: 861 pyval: The nested Python structure that should be used to create the new 862 `StructuredTensor`. 863 typespec: A `StructuredTensorSpec` specifying the expected type for each 864 field. If not specified, then all nested dictionaries are turned into 865 StructuredTensors, and all nested lists are turned into Tensors (if 866 rank<2) or RaggedTensors (if rank>=2). 867 path_so_far: the path of fields that led here (for error messages). 868 869 Returns: 870 A `StructuredTensor`. 871 """ 872 if isinstance(pyval, dict): 873 return cls._from_pydict(pyval, typespec, path_so_far) 874 elif isinstance(pyval, (list, tuple)): 875 keys = set() 876 rank = _pyval_find_struct_keys_and_depth(pyval, keys) 877 if rank is not None: 878 return cls._from_pylist_of_dict(pyval, keys, rank, typespec, 879 path_so_far) 880 else: 881 return cls._from_pylist_of_value(pyval, typespec, path_so_far) 882 else: 883 return cls._from_pyscalar(pyval, typespec, path_so_far) 884 885 @classmethod 886 def _from_pydict(cls, pyval, typespec, path_so_far): 887 """Converts python dictionary `pyval` to a StructuredTensor with rank=0.""" 888 if typespec is None: 889 fields = dict((k, cls._from_pyval(v, None, path_so_far + (k,))) 890 for (k, v) in pyval.items()) 891 else: 892 spec_shape = typespec._shape # pylint: disable=protected-access 893 field_specs = typespec._field_specs # pylint: disable=protected-access 894 if not (isinstance(typespec, StructuredTensorSpec) and 895 spec_shape.rank == 0 and set(pyval) == set(field_specs)): 896 raise ValueError('Value at %r does not match typespec: %r vs %r' % 897 (path_so_far, pyval, typespec)) 898 fields = dict((k, cls._from_pyval(v, field_specs[k], path_so_far + (k,))) 899 for (k, v) in pyval.items()) 900 return StructuredTensor.from_fields(fields=fields, shape=(), validate=False) 
901 902 @classmethod 903 def _from_pylist_of_dict(cls, pyval, keys, rank, typespec, path_so_far): 904 """Converts python list `pyval` to a StructuredTensor with rank>1.""" 905 fields = dict((key, []) for key in keys) 906 for child in pyval: 907 _pyval_update_fields(child, fields, 1) 908 if typespec is None: 909 shape = tensor_shape.TensorShape([None] * rank) 910 for (key, target) in fields.items(): 911 fields[key] = cls._from_pyval(target, None, path_so_far + (key,)) 912 else: 913 field_specs = typespec._field_specs # pylint: disable=protected-access 914 if ((not isinstance(typespec, StructuredTensorSpec)) or 915 (set(fields) - set(field_specs))): 916 raise ValueError('Value at %r does not match typespec: %r vs %r' % 917 (path_so_far, pyval, typespec)) 918 shape = typespec._shape 919 if shape.rank < rank: 920 raise ValueError('Value at %r does not match typespec (rank mismatch): ' 921 '%r vs %r' % (path_so_far, pyval, typespec)) 922 for (key, spec) in field_specs.items(): 923 fields[key] = cls._from_pyval( 924 fields.get(key, []), spec, path_so_far + (key,)) 925 try: 926 if not fields and typespec is None: 927 # TODO(b/183245576): handle cases where the typespec is known 928 # but the dictionary is empty. 
929 return StructuredTensor._from_pylist_of_empty_dict(pyval, rank) 930 return StructuredTensor.from_fields( 931 fields=fields, shape=shape, validate=False) 932 except Exception as exc: 933 raise ValueError('Error parsing path %r' % (path_so_far,)) from exc 934 935 @classmethod 936 def _from_pylist_of_empty_dict(cls, pyval, rank): 937 """Converts a pylist of empty dictionaries to StructuredTensors.""" 938 if rank == 0: 939 return StructuredTensor.from_fields(fields={}, shape=(), validate=False) 940 elif rank == 1: 941 nrows = len(pyval) 942 shape = (nrows,) 943 return StructuredTensor.from_fields(fields={}, shape=shape, nrows=nrows) 944 elif rank > 1: 945 ragged_zeros = ragged_factory_ops.constant(_dicts_to_zeros(pyval)) 946 nrows = len(pyval) 947 shape = tensor_shape.TensorShape([len(pyval)] + ([None] * (rank - 1))) 948 return StructuredTensor.from_fields( 949 fields={}, 950 shape=shape, 951 row_partitions=ragged_zeros._nested_row_partitions, # pylint:disable=protected-access 952 nrows=nrows) 953 954 @classmethod 955 def _from_pylist_of_value(cls, pyval, typespec, path_so_far): 956 """Converts python list `pyval` to a Tensor or RaggedTensor with rank>1.""" 957 if typespec is None: 958 try: 959 return ragged_factory_ops.constant(pyval) 960 except Exception as exc: 961 raise ValueError('Error parsing path %r' % (path_so_far,)) from exc 962 elif isinstance(typespec, tensor_spec.TensorSpec): 963 try: 964 result = constant_op.constant(pyval, typespec.dtype) 965 except Exception as exc: 966 raise ValueError('Error parsing path %r' % (path_so_far,)) from exc 967 if not typespec.shape.is_compatible_with(result.shape): 968 raise ValueError('Value at %r does not match typespec: %r vs %r' % 969 (path_so_far, typespec, pyval)) 970 return result 971 elif isinstance(typespec, ragged_tensor.RaggedTensorSpec): 972 # pylint: disable=protected-access 973 try: 974 return ragged_factory_ops.constant( 975 pyval, 976 dtype=typespec._dtype, 977 ragged_rank=typespec._ragged_rank, 978 
row_splits_dtype=typespec._row_splits_dtype, 979 inner_shape=typespec._shape[typespec._ragged_rank + 1:]) 980 except Exception as exc: 981 raise ValueError('Error parsing path %r' % (path_so_far,)) from exc 982 elif isinstance(typespec, StructuredTensorSpec): 983 empty_rank = _pyval_empty_list_depth(pyval) 984 if empty_rank is None: 985 raise ValueError('Value at %r does not match typespec: %r vs %r' % 986 (path_so_far, typespec, pyval)) 987 else: 988 return cls._from_pylist_of_dict(pyval, set(), empty_rank, typespec, 989 path_so_far) 990 else: 991 raise ValueError('Value at %r does not match typespec: %r vs %r' % 992 (path_so_far, typespec, pyval)) 993 994 @classmethod 995 def _from_pyscalar(cls, pyval, typespec, path_so_far): 996 """Converts python scalar value `pyval` to a Tensor.""" 997 if typespec is None: 998 try: 999 return constant_op.constant(pyval) 1000 except Exception as exc: 1001 raise ValueError('Error parsing path %r' % (path_so_far,)) from exc 1002 else: 1003 if not (isinstance(typespec, tensor_spec.TensorSpec) and 1004 typespec.shape.rank == 0): 1005 raise ValueError('Value at %r does not match typespec: %r vs %r' % 1006 (path_so_far, typespec, pyval)) 1007 # TODO(edloper): Check that typespec.shape matches. 1008 return constant_op.constant(pyval, typespec.dtype) 1009 1010 #============================================================================= 1011 # Transforms 1012 #============================================================================= 1013 1014 # TODO(edloper): Add a 'validate' option here? 1015 # TODO(edloper): Unify nomenclature with RaggedTensor. Should RaggedTensor 1016 # have a partition_outer_dimension method? 1017 def partition_outer_dimension(self, row_partition): 1018 """Partitions the outer dimension of this StructuredTensor. 1019 1020 Returns a new `StructuredTensor` with the same values as `self`, where 1021 the outer dimension is partitioned into two (possibly ragged) dimensions. 
1022 Requires that this StructuredTensor have an outer dimension (i.e., 1023 `self.shape.rank > 0`). 1024 1025 >>> st = StructuredTensor.from_pyval( 1026 ... [{'foo': 12}, {'foo': 33}, {'foo': 99}]) 1027 >>> partition = RowPartition.from_row_lengths([2, 0, 1]) 1028 >>> st.partition_outer_dimension(partition) 1029 <StructuredTensor( 1030 fields={ 1031 "foo": <tf.RaggedTensor [[12, 33], [], [99]]>}, 1032 shape=(3, None))> 1033 1034 Args: 1035 row_partition: A `RowPartition`. 1036 1037 Returns: 1038 A `StructuredTensor` with rank `values.rank + 1`. 1039 """ 1040 if not isinstance(row_partition, RowPartition): 1041 raise TypeError('row_partition must be a RowPartition.') 1042 if self.shape.rank == 0: 1043 raise ValueError('Shape %s must have rank at least 1' % self.shape) 1044 return _partition_outer_dimension(self, row_partition) 1045 1046 def merge_dims(self, outer_axis, inner_axis): 1047 """Merges outer_axis...inner_axis into a single dimension. 1048 1049 Returns a copy of this RaggedTensor with the specified range of dimensions 1050 flattened into a single dimension, with elements in row-major order. 1051 1052 >>> st = StructuredTensor.from_pyval( 1053 ... [[{'foo': 12}, {'foo': 33}], [], [{'foo': 99}]]) 1054 >>> st.merge_dims(0, 1) 1055 <StructuredTensor( 1056 fields={ 1057 "foo": tf.Tensor([12 33 99], shape=(3,), dtype=int32)}, 1058 shape=(3,))> 1059 1060 Args: 1061 outer_axis: `int`: The first dimension in the range of dimensions to 1062 merge. May be negative (to index from the last dimension). 1063 inner_axis: `int`: The last dimension in the range of dimensions to merge. 1064 May be negative (to index from the last dimension). 1065 1066 Returns: 1067 A copy of this tensor, with the specified dimensions merged into a 1068 single dimension. The shape of the returned tensor will be 1069 `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N` 1070 is the total number of slices in the merged dimensions. 
1071 """ 1072 outer_axis = array_ops.get_positive_axis( 1073 outer_axis, 1074 self.shape.rank, 1075 axis_name='outer_axis', 1076 ndims_name='rank(self)') 1077 inner_axis = array_ops.get_positive_axis( 1078 inner_axis, 1079 self.shape.rank, 1080 axis_name='inner_axis', 1081 ndims_name='rank(self)') 1082 if not outer_axis <= inner_axis: 1083 raise ValueError('Expected outer_axis (%d) to be less than or equal to ' 1084 'inner_axis (%d)' % (outer_axis, inner_axis)) 1085 return _merge_dims(self, outer_axis, inner_axis) 1086 1087 #============================================================================= 1088 # Composite Tensor 1089 #============================================================================= 1090 1091 @property 1092 def _type_spec(self): 1093 return StructuredTensorSpec.from_value(self) 1094 1095 1096@type_spec.register('tf.StructuredTensorSpec') 1097class StructuredTensorSpec(type_spec.BatchableTypeSpec): 1098 """Type specification for `StructuredTensor`s.""" 1099 1100 __slots__ = ['_shape', '_field_specs'] 1101 1102 def __init__(self, shape, field_specs): 1103 """Build a type specification for a StructuredTensor. 1104 1105 Args: 1106 shape: The shape of the StructuredTensor. shape.rank must not be None. 1107 field_specs: A dictionary mapping from field name to TypeSpec, specifying 1108 the tensor type used to encode each field. These TypeSpecs should 1109 specify the type of the entire field (including outer dimensions which 1110 correspond to `shape`). For example, if `shape=[2, 3]`, and field 'x' 1111 contains an int32 vector of size `10` for each structure, then 1112 `field_specs['x']` should be `tf.TensorSpec([2, 3, 10], tf.int32)`. 1113 """ 1114 shape = tensor_shape.as_shape(shape) 1115 1116 # Perform a few sanity checks on the inputs. 
1117 if shape.rank is None: 1118 raise TypeError("StructuredTensor's shape must have known rank.") 1119 if not isinstance(field_specs, dict): 1120 raise TypeError('field_specs must be a dictionary.') 1121 for key, value in field_specs.items(): 1122 if not isinstance(key, str): 1123 raise TypeError('field_specs must be a dictionary with string keys.') 1124 if not isinstance(value, (StructuredTensorSpec, tensor_spec.TensorSpec, 1125 ragged_tensor.RaggedTensorSpec)): 1126 raise TypeError('field_specs must be a dictionary with ' 1127 'TypeSpec values.') 1128 1129 self._shape = shape 1130 self._field_specs = dict(field_specs) 1131 1132 @property 1133 def value_type(self): 1134 return StructuredTensor 1135 1136 def _to_components(self, value): 1137 nrows = () if value.nrows() is None else value.nrows() 1138 return (value._fields, nrows, value.row_partitions) 1139 1140 def _from_components(self, components): 1141 if isinstance(components, dict): 1142 logging.warning('Loading deprecated encoding for StructuredTensorSpec.') 1143 return StructuredTensor.from_fields(components, self._shape, 1144 validate=False) 1145 elif not isinstance(components[0], dict): 1146 logging.warning('Loading deprecated encoding for StructuredTensorSpec.') 1147 fields = {} 1148 nrows, row_partitions = components 1149 if isinstance(nrows, tuple) and not nrows: 1150 nrows = None # empty rank-0 structured tensor 1151 return StructuredTensor.from_fields(fields, self._shape, nrows=nrows, 1152 row_partitions=row_partitions, 1153 validate=False) 1154 1155 (fields, nrows, row_partitions) = components 1156 if isinstance(nrows, tuple) and not nrows: 1157 nrows = None # empty rank-0 structured tensor 1158 return StructuredTensor(fields, self._shape, nrows, row_partitions, 1159 internal=_structured_tensor_factory_key) 1160 1161 @property 1162 def _component_specs(self): 1163 if self._shape.rank == 0: 1164 nrows_spec = () 1165 else: 1166 nrows_spec = tensor_spec.TensorSpec([], dtypes.int64) 1167 1168 
row_partition_specs = ((row_partition_lib.RowPartitionSpec(),) 1169 * (self._shape.rank - 1)) 1170 return (self._field_specs, nrows_spec, row_partition_specs) 1171 1172 @classmethod 1173 def from_value(cls, value): 1174 field_specs = dict((k, type_spec.type_spec_from_value(v)) 1175 for (k, v) in value._fields.items()) 1176 return cls(value.shape, field_specs) 1177 1178 def _serialize(self): 1179 return (self._shape, self._field_specs) 1180 1181 def _batch(self, batch_size): 1182 # pylint: disable=protected-access 1183 return StructuredTensorSpec( 1184 tensor_shape.TensorShape([batch_size]).concatenate(self._shape), 1185 dict((k, v._batch(batch_size)) for (k, v) in self._field_specs.items())) 1186 1187 def _unbatch(self): 1188 # pylint: disable=protected-access 1189 return StructuredTensorSpec( 1190 self._shape[1:], 1191 dict((k, v._unbatch()) for (k, v) in self._field_specs.items())) 1192 1193 @property 1194 def _flat_tensor_specs(self): 1195 # pylint: disable=protected-access 1196 result = [] 1197 for _, field_spec in sorted(self._field_specs.items(), key=lambda t: t[0]): 1198 result.extend(field_spec._flat_tensor_specs) 1199 return result 1200 1201 def _to_tensor_list(self, value): 1202 return self._to_tensor_list_internal(value, batched=False) 1203 1204 def _to_batched_tensor_list(self, value): 1205 return self._to_tensor_list_internal(value, batched=True) 1206 1207 def _from_compatible_tensor_list(self, tensor_list): 1208 # pylint: disable=protected-access 1209 fields = {} 1210 pos = 0 1211 for field_name, field_spec in sorted( 1212 self._field_specs.items(), key=lambda t: t[0]): 1213 num_tensors_for_field = len(field_spec._flat_tensor_specs) 1214 field_tensors = tensor_list[pos:pos + num_tensors_for_field] 1215 fields[field_name] = field_spec._from_compatible_tensor_list( 1216 field_tensors) 1217 pos += num_tensors_for_field 1218 return StructuredTensor.from_fields(fields, self._shape) 1219 1220 def _to_tensor_list_internal(self, value, batched): 1221 
"""Returns a dict whose entries are each field's (batched) tensor_list. 1222 1223 If a field is a StructuredTensor, then its entry will be a dict, 1224 recursively. 1225 1226 Args: 1227 value: A StructuredTensor (conforming to `self`). 1228 batched: A boolean. if True, produce `batched_tensor_list` for each field 1229 otherwise produce `tensor_list`. 1230 1231 Returns: 1232 A dict. 1233 """ 1234 result = [] 1235 for field_name, field_spec in sorted( 1236 self._field_specs.items(), key=lambda t: t[0]): 1237 # pylint: disable=protected-access 1238 field_value = value._fields[field_name] 1239 if batched: 1240 result.extend(field_spec._to_batched_tensor_list(field_value)) 1241 else: 1242 result.extend(field_spec._to_tensor_list(field_value)) 1243 1244 return result 1245 1246 1247# Regular expression used to determine whether a string is a valid field name. 1248# Note: we plan to relax (or possibly eliminate) this in the future; you 1249# should not rely on the fact that some field names are currently disallowed. 1250_FIELD_NAME_RE = re.compile('^[a-zA-Z][a-zA-Z0-9_]*$') 1251 1252#============================================================================= 1253# Helper funtions 1254#============================================================================= 1255# TODO(edloper): Move some of these helpers to row_partition.py? 


def _convert_to_structured_field_value(value):
  """Converts `value` to a Tensor, RaggedTensor, or StructuredTensor.

  Args:
    value: The value to convert.  Tensors, RaggedTensors, and
      StructuredTensors are returned unchanged.

  Returns:
    A `Tensor`, `RaggedTensor`, or `StructuredTensor`.

  Raises:
    TypeError: If `value` can not be converted to a tensor.
  """
  if isinstance(value,
                (ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)):
    return value
  elif ragged_tensor.is_ragged(value):
    return ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
  else:
    try:
      return ops.convert_to_tensor(value)
    except (ValueError, TypeError) as exc:
      # Chain the original error so the underlying conversion failure stays
      # visible in the traceback.
      raise TypeError(
          'Unexpected type for value in `fields`: %r' % value) from exc


def _find_shape_dtype(fields, nrows, row_partitions):
  """Return a consistent dtype for fields, nrows, & row_partitions.

  Args:
    fields: A dictionary mapping field names to field values.
    nrows: `None` or a value giving the number of rows; only consulted if it
      is a `Tensor`.
    row_partitions: `None` or an iterable of row partitions.

  Returns:
    The single dtype used by the ragged fields' row_splits, the non-scalar
    StructuredTensor fields' nrows, `nrows`, and `row_partitions`; or
    `dtypes.int64` if none of them supplies a dtype.

  Raises:
    ValueError: If more than one distinct dtype is found.
  """
  shape_dtypes = set()
  for value in fields.values():
    if isinstance(value, ragged_tensor.RaggedTensor):
      shape_dtypes.add(value.row_splits.dtype)
    elif isinstance(value, StructuredTensor) and value.rank > 0:
      shape_dtypes.add(value.nrows().dtype)
  if isinstance(nrows, ops.Tensor):
    shape_dtypes.add(nrows.dtype)
  if row_partitions is not None:
    for partition in row_partitions:
      shape_dtypes.add(partition.dtype)
  if len(shape_dtypes) > 1:
    raise ValueError('field values have incompatible row_partition dtypes.')
  elif shape_dtypes:
    return shape_dtypes.pop()
  else:
    return dtypes.int64


def _merge_nrows(nrows, static_nrows, value, dtype, validate):
  """Merges `nrows` with `nrows(value)`.

  Checks that `value` has the expected number of rows (`nrows`), and returns
  `nrows`.  If `validate` is true, then add validation ops that check that
  the `nrows` values match.

  Args:
    nrows: scalar integer Tensor, or `None` if no row count is known yet.
    static_nrows: tf.Dimension: static value of nrows, if known.
    value: Tensor or RaggedTensor or StructuredTensor
    dtype: dtype for `nrows`.
    validate: bool -- whether to add validation ops.

  Returns:
    A tuple `(nrows, static_nrows)`.

  Raises:
    ValueError: If both static row counts are known and they are incompatible.
  """
  static_value_nrows = tensor_shape.dimension_at_index(value.shape, 0)
  if isinstance(value, ops.Tensor):
    value_nrows = array_ops.shape(value, out_type=dtype)[0]
  else:
    value_nrows = value.nrows()
  if nrows is None:
    nrows = value_nrows
  elif (static_value_nrows.value is not None and
        static_nrows.value is not None):
    # Both sizes are known statically, so compare them right away.
    if not static_value_nrows.is_compatible_with(static_nrows):
      raise ValueError('fields have incompatible nrows')
    nrows = value_nrows  # No need to add an assertion op.
  elif validate:
    nrows = control_flow_ops.with_dependencies([
        check_ops.assert_equal(
            nrows, value_nrows, message='fields have incompatible nrows')
    ], nrows)
  return nrows, static_nrows.merge_with(static_value_nrows)


def _merge_row_partitions(row_partitions, value, rank, dtype, validate):
  """Merges `row_partitions` with `row_partitions(value)`.

  Args:
    row_partitions: `None`, or a sequence of row partitions collected so far.
    value: Tensor or RaggedTensor or StructuredTensor.
    rank: The number of outer dimensions that are being described; `rank - 1`
      row partitions are taken from `value`.
    dtype: dtype used when computing the shape of a `Tensor` value.
    validate: Passed through to `merge_precomputed_encodings`.

  Returns:
    A tuple of row partitions.
  """
  if isinstance(value, ops.Tensor):
    value_row_partitions = _row_partitions_for_tensor(value, rank, dtype)

  elif isinstance(value, ragged_tensor.RaggedTensor):
    value_row_partitions = _row_partitions_for_ragged_tensor(value, rank, dtype)

  else:
    assert isinstance(value, StructuredTensor), type(value)
    value_row_partitions = value.row_partitions[:rank - 1]

  assert len(value_row_partitions) == rank - 1
  if row_partitions is None:
    return tuple(value_row_partitions)
  else:
    return tuple([
        p1.merge_precomputed_encodings(p2, validate)
        for (p1, p2) in zip(row_partitions, value_row_partitions)
    ])


def _row_partitions_for_tensor(value, rank, dtype):
  """Returns the row partitions for a tf.Tensor."""
  shape = array_ops.shape(value, out_type=dtype)
  return _row_partitions_for_uniform_shape(shape, rank)


def _row_partitions_for_ragged_tensor(value, rank, dtype):
  """Returns the row partitions for a tf.RaggedTensor."""
  assert rank > 1
  value_row_partitions = value._nested_row_partitions[:rank - 1]  # pylint: disable=protected-access
  if len(value_row_partitions) < (rank - 1):
    # The ragged dimensions do not cover all requested dimensions; take the
    # remaining (uniform) partitions from the flat values.
    value_row_partitions += _row_partitions_for_tensor(
        value.flat_values, rank - len(value_row_partitions), dtype)
  assert len(value_row_partitions) == rank - 1
  return value_row_partitions


def _row_partitions_for_uniform_shape(shape, rank):
  """Returns row partitions for the given shape Tensor.

  Args:
    shape: A vector describing a uniform shape.
    rank: The number of dimensions to generate row partitions for

  Returns:
    A tuple of (rank-1) `RowPartition`s with uniform row length.
  """
  # shape_cumprod[i] is the total number of slices at depth i.
  shape_cumprod = math_ops.cumprod(shape[:rank])
  # pylint: disable=g-complex-comprehension
  return tuple([
      RowPartition.from_uniform_row_length(
          uniform_row_length=shape[i + 1],
          nvals=shape_cumprod[i + 1],
          nrows=shape_cumprod[i]) for i in range(rank - 1)
  ])


def _pyval_field_major_to_node_major(keys, values, depth):
  """Regroup each field (k, v) from dict-of-list to list-of-dict.

  Given a "field-major" encoding of the StructuredTensor (which maps each key to
  a single nested list containing the values for all structs), return a
  corresponding "node-major" encoding, consisting of a nested list of dicts.

  Args:
    keys: The field names (list of string).  Must not be empty.
    values: The field values (list of python values).  Must have the same length
      as `keys`.
    depth: The list depth at which dictionaries should be created.

  Returns:
    A nested list of dict, with depth `depth`.
  """
  assert keys
  if depth == 0:
    return dict(zip(keys, values))
  nvals = len(values[0])
  assert all(len(value) == nvals for value in values[1:])
  # zip(*values) yields, for each position, that position's entry from every
  # field; recurse on each such slice.
  return [
      _pyval_field_major_to_node_major(keys, value_slice, depth - 1)
      for value_slice in zip(*values)
  ]


def _empty_dict_pylist_from_row_partitions(row_partitions, nrows):
  """Returns a python list of empty dicts from the given row partitions.

  Args:
    row_partitions: The row-partitions describing the ragged shape of the
      result.
    nrows: The number of rows in the outermost row-partition.  (Or if
      `len(row_partitions)==0`, then the number of empty dicts to return.)

  Returns:
    A nested python list whose leaves (if any) are empty python dicts.
  """
  if not row_partitions:
    return [{} for _ in range(nrows)]
  else:
    splits = row_partitions[0].row_splits()
    # The last split is the total number of values, i.e. the number of rows
    # of the next-inner level.
    values = _empty_dict_pylist_from_row_partitions(
        row_partitions[1:], splits[-1])
    return [values[splits[i]:splits[i + 1]] for i in range(len(splits) - 1)]


def _pyval_find_struct_keys_and_depth(pyval, keys):
  """Finds the keys & depth of nested dictionaries in `pyval`.

  Args:
    pyval: A nested structure of lists, tuples, and dictionaries.
    keys: (output parameter) A set, which will be updated with any keys that are
      found in the nested dictionaries.

  Returns:
    The nesting depth of dictionaries in `pyval`, or `None` if `pyval` does
    not contain any dictionaries.
  Raises:
    ValueError: If dictionaries have inconsistent depth.
  """
  if isinstance(pyval, dict):
    keys.update(pyval.keys())
    return 0
  elif isinstance(pyval, (list, tuple)):
    depth = None
    for child in pyval:
      child_depth = _pyval_find_struct_keys_and_depth(child, keys)
      if child_depth is not None:
        if depth is None:
          depth = child_depth + 1
        elif depth != child_depth + 1:
          raise ValueError('Inconsistent depth of dictionaries')
    return depth
  else:
    return None


def _pyval_update_fields(pyval, fields, depth):
  """Append the field values from `pyval` to `fields`.

  Args:
    pyval: A python `dict`, or nested list/tuple of `dict`, whose value(s)
      should be appended to `fields`.
    fields: A dictionary mapping string keys to field values.  Field values
      extracted from `pyval` are appended to this dictionary's values.
    depth: The depth at which `pyval` should be appended to the field values.

  Raises:
    ValueError: If `pyval` is not a dict, list, or tuple.
  """
  if not isinstance(pyval, (dict, list, tuple)):
    raise ValueError('Expected dict or nested list/tuple of dict')

  for (key, target) in fields.items():
    # Walk down to the most recently added list at the requested depth.
    for _ in range(1, depth):
      target = target[-1]
    # A dict contributes its value; a list/tuple opens a new (empty) sub-list
    # that the recursive calls below fill in.
    target.append(pyval[key] if isinstance(pyval, dict) else [])

  if isinstance(pyval, (list, tuple)):
    for child in pyval:
      _pyval_update_fields(child, fields, depth + 1)


def _pyval_empty_list_depth(pyval):
  """Find the max depth for nested empty lists.

  Args:
    pyval: A nested python list.

  Returns:
    The maximum depth of empty lists in `pyval`, or None if `pyval` contains
    anything other than nested empty lists.
  """
  if isinstance(pyval, list):
    if not pyval:
      return 1
    depths = [_pyval_empty_list_depth(v) for v in pyval]
    if any(depth is None for depth in depths):
      return None
    else:
      return max(depths) + 1
  else:
    return None


def _replace_row_partitions(value, new_partitions):
  """Updates `value` to use `new_partitions` as its (outer) row partitions.

  This is used to ensure that all fields in a `StructuredTensor` use identical
  `RowPartition` objects for the shared dimensions.  In particular,
  `StructuredTensor.from_fields` first merges all of the row partitions from
  any fields, and then replaces the outer row partitions of all fields with
  the merged row partitions (using this function).

  Args:
    value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    new_partitions: A list of row-partitions that should be used by `value`.
      Must be equivalent to `value`'s current row partitions.

  Returns:
    A value that is equivalent to `value`, where outer row partitions have been
    replaced by `new_partitions`.
  """
  if isinstance(value, ops.Tensor) or not new_partitions:
    return value

  elif isinstance(value, ragged_tensor.RaggedTensor):
    # Replace the outermost partition here, and the rest in the values.
    return ragged_tensor.RaggedTensor._from_row_partition(  # pylint: disable=protected-access
        values=_replace_row_partitions(value.values, new_partitions[1:]),
        row_partition=new_partitions[0])

  else:
    assert isinstance(value, StructuredTensor)
    new_fields = dict((k, _replace_row_partitions(v, new_partitions))
                      for (k, v) in value._fields.items())
    return StructuredTensor(
        fields=new_fields,
        shape=value.shape,
        nrows=value.nrows(),
        row_partitions=new_partitions +
        value.row_partitions[len(new_partitions):],
        internal=_structured_tensor_factory_key)


def _partition_outer_dimension(value, row_partition):
  """Partitions the outer dimension of `value` using `row_partitions`.

  Examples:

    >>> partition = RowPartition.from_row_lengths([2, 0, 1])
    >>> _partition_outer_dimension(tf.constant([1, 2, 3]), partition)
    <tf.RaggedTensor [[1, 2], [], [3]]>

    >>> struct_value = StructuredTensor.from_pyval(
    ...     [{'x': 1}, {'x': 2}, {'x': 3}])
    >>> _partition_outer_dimension(struct_value, partition)
    <StructuredTensor(
     fields={
     "x": <tf.RaggedTensor [[1, 2], [], [3]]>},
     shape=(3, None))>

  Args:
    value: Tensor, RaggedTensor, or StructuredTensor
    row_partition: RowPartition

  Returns:
    A value with the same type as `value`, where
    `result.rank = value.rank + 1`.
  """
  is_ragged = row_partition.uniform_row_length() is None
  if isinstance(value, ops.Tensor) and not is_ragged:
    # A uniform partition of a dense tensor is just a reshape.
    new_shape = array_ops.concat(
        [[row_partition.nrows(),
          row_partition.uniform_row_length()],
         array_ops.shape(value, out_type=row_partition.dtype)[1:]],
        axis=0)
    return array_ops.reshape(value, new_shape)
  elif isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)):
    return ragged_tensor.RaggedTensor._from_row_partition(  # pylint: disable=protected-access
        value, row_partition)
  else:
    assert isinstance(value, StructuredTensor)
    nrows = row_partition.static_nrows
    ncols = row_partition.static_uniform_row_length
    shape = tensor_shape.TensorShape([nrows,
                                      ncols]).concatenate(value.shape[1:])
    fields = dict((k, _partition_outer_dimension(v, row_partition))
                  for (k, v) in value._fields.items())
    return StructuredTensor(
        fields,
        shape,
        row_partition.nrows(), (row_partition,) + value.row_partitions,
        internal=_structured_tensor_factory_key)


def _merge_dims(value, outer_axis, inner_axis):
  """Merges `outer_axis...inner_axis` of `value` into a single dimension.

  Args:
    value: Tensor, RaggedTensor, or StructuredTensor.
    outer_axis: Non-negative python int: the first dimension to merge.
    inner_axis: Non-negative python int: the last dimension to merge.  Must
      not be less than `outer_axis`.

  Returns:
    `value` itself if `outer_axis == inner_axis`; otherwise a value of the
    same kind as `value` with the given dimensions merged.
  """
  if outer_axis == inner_axis:
    # Merging a single dimension with itself is a no-op.  Callers validate
    # `outer_axis <= inner_axis`, so this case must not trip the assertion.
    return value
  assert outer_axis < inner_axis
  if isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)):
    return ragged_tensor.merge_dims(value, outer_axis, inner_axis)
  else:
    assert isinstance(value, StructuredTensor)

    # Build the new fields.
    fields = dict((k, _merge_dims(v, outer_axis, inner_axis))
                  for (k, v) in value._fields.items())

    # Build the new shape.
    # NOTE(review): the slice below stops before `inner_axis`, so the static
    # size of the merged dimension ignores the innermost merged dimension --
    # it looks like this should be `outer_axis:inner_axis + 1`.  Left as-is
    # because it changes the static shape that callers observe; confirm.
    value_shape = value.shape
    shape = (
        value_shape[:outer_axis] +
        [value_shape[outer_axis:inner_axis].num_elements()] +
        value_shape[inner_axis + 1:])

    # Build the new row_partitions & nrows
    if outer_axis == 0:
      if inner_axis == value.shape.rank - 1:
        # Everything was merged into one dimension: no partitions are left.
        partitions = ()
        nrows = value.row_partitions[-1].nvals()
      else:
        partitions = value.row_partitions[inner_axis:]
        nrows = partitions[0].nrows()
    else:
      # Use tf.gather to merge row_splits from the merged row partitions.
      merged_splits = value.row_partitions[outer_axis - 1].row_splits()
      for dim in range(outer_axis, inner_axis):
        merged_splits = array_ops.gather(value.row_partitions[dim].row_splits(),
                                         merged_splits)

      partitions = (
          value.row_partitions[:outer_axis - 1] +
          (RowPartition.from_row_splits(merged_splits),) +
          value.row_partitions[inner_axis:])
      nrows = partitions[0].nrows()

    return StructuredTensor(
        fields,
        shape,
        nrows,
        partitions,
        internal=_structured_tensor_factory_key)


_structured_tensor_factory_key = object()  # unique private object


def _normalize_field_name_to_tuple(name: 'FieldName') -> Sequence[str]:
  """FieldName can be given also as string, this normalizes it to a tuple."""
  if isinstance(name, str):
    return (name,)
  if isinstance(name, list):
    return tuple(name)
  assert isinstance(name, tuple)
  return name


def _dicts_to_zeros(pyval):
  """Replaces each dictionary with a zero in a (nested) pylist."""
  if isinstance(pyval, dict):
    return 0
  return [_dicts_to_zeros(x) for x in pyval]


def _merge_dims_generic(source, outer, inner):
  """Merges outer_axis...inner_axis into a single dimension.

  If outer == inner, this is intended to be a NOOP. If inner < outer, then
  this fails.  If inner >= source.shape.rank, then the behavior is undefined.

  Args:
    source: a tensor, ragged tensor, or structured tensor.
    outer: a python int, indicating the first dimension to compress (must be
      nonnegative).
    inner: a python int, indicating the last dimension to compress (must be
      nonnegative).  It is passed on unchanged as the `inner_axis` of
      `merge_dims`.

  Returns:
    source with outer_axis...inner_axis merged into a single dimension.

  """
  if isinstance(source, StructuredTensor):
    return source.merge_dims(outer, inner)
  else:
    return ragged_tensor.merge_dims(source, outer, inner)