• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementations of different data feeders to provide data for TF trainer (deprecated).
16
17This module and all its submodules are deprecated. See
18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
19for migration instructions.
20"""
21
22# TODO(ipolosukhin): Replace this module with feed-dict queue runners & queues.
23
24from __future__ import absolute_import
25from __future__ import division
26from __future__ import print_function
27
28import itertools
29import math
30
31import numpy as np
32import six
33from six.moves import xrange  # pylint: disable=redefined-builtin
34
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import tensor_util
37from tensorflow.python.ops import array_ops
38from tensorflow.python.platform import tf_logging as logging
39from tensorflow.python.util.deprecation import deprecated
40
41# pylint: disable=g-multiple-import,g-bad-import-order
42from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels
43from .dask_io import HAS_DASK, extract_dask_data, extract_dask_labels
44
45# pylint: enable=g-multiple-import,g-bad-import-order
46
47
48def _get_in_out_shape(x_shape, y_shape, n_classes, batch_size=None):
49  """Returns shape for input and output of the data feeder."""
50  x_is_dict, y_is_dict = isinstance(
51      x_shape, dict), y_shape is not None and isinstance(y_shape, dict)
52  if y_is_dict and n_classes is not None:
53    assert isinstance(n_classes, dict)
54
55  if batch_size is None:
56    batch_size = list(x_shape.values())[0][0] if x_is_dict else x_shape[0]
57  elif batch_size <= 0:
58    raise ValueError('Invalid batch_size %d.' % batch_size)
59
60  if x_is_dict:
61    input_shape = {}
62    for k, v in list(x_shape.items()):
63      input_shape[k] = [batch_size] + (list(v[1:]) if len(v) > 1 else [1])
64  else:
65    x_shape = list(x_shape[1:]) if len(x_shape) > 1 else [1]
66    input_shape = [batch_size] + x_shape
67
68  if y_shape is None:
69    return input_shape, None, batch_size
70
71  def out_el_shape(out_shape, num_classes):
72    out_shape = list(out_shape[1:]) if len(out_shape) > 1 else []
73    # Skip first dimension if it is 1.
74    if out_shape and out_shape[0] == 1:
75      out_shape = out_shape[1:]
76    if num_classes is not None and num_classes > 1:
77      return [batch_size] + out_shape + [num_classes]
78    else:
79      return [batch_size] + out_shape
80
81  if not y_is_dict:
82    output_shape = out_el_shape(y_shape, n_classes)
83  else:
84    output_shape = dict([(k,
85                          out_el_shape(v, n_classes[k]
86                                       if n_classes is not None and
87                                       k in n_classes else None))
88                         for k, v in list(y_shape.items())])
89
90  return input_shape, output_shape, batch_size
91
92
def _data_type_filter(x, y):
  """Filters features and labels into an acceptable format.

  The Dask extractors are applied first (when `HAS_DASK`), then the pandas
  extractors (when `HAS_PANDAS`). Labels are only passed through the label
  extractors when `y` is not `None`.

  Args:
    x: Features.
    y: Labels, or `None`.

  Returns:
    Tuple `(x, y)` of the possibly-converted features and labels.
  """
  extractor_pairs = []
  if HAS_DASK:
    extractor_pairs.append((extract_dask_data, extract_dask_labels))
  if HAS_PANDAS:
    extractor_pairs.append((extract_pandas_data, extract_pandas_labels))

  for extract_features, extract_labels in extractor_pairs:
    x = extract_features(x)
    if y is not None:
      y = extract_labels(y)
  return x, y
104
105
106def _is_iterable(x):
107  return hasattr(x, 'next') or hasattr(x, '__next__')
108
109
@deprecated(None, 'Please use tensorflow/transform or tf.data.')
def setup_train_data_feeder(x,
                            y,
                            n_classes,
                            batch_size=None,
                            shuffle=True,
                            epochs=None):
  """Create data feeder, to sample inputs from dataset.

  If `x` and `y` are iterators, use `StreamingDataFeeder`.

  Args:
    x: numpy, pandas or Dask matrix or dictionary of aforementioned. Also
      supports iterables.
    y: numpy, pandas or Dask array or dictionary of aforementioned. Also
      supports iterables.
    n_classes: number of classes. Must be None or same type as y. In case `y`
      is a `dict` (or an iterable which returns dicts), `n_classes` must be a
      dict such that `n_classes[key]` is the number of classes for `y[key]`.
    batch_size: size to split data into parts. Must be >= 1.
    shuffle: Whether to shuffle the inputs.
    epochs: Number of epochs to run.

  Returns:
    DataFeeder object that returns training data.

  Raises:
    ValueError: if one of `x` and `y` is iterable and the other is not.
  """
  x, y = _data_type_filter(x, y)

  if _is_iterable(x):
    if y is not None and not _is_iterable(y):
      raise ValueError('Both x and y should be iterators for '
                       'streaming learning to work.')
    return StreamingDataFeeder(x, y, n_classes, batch_size)
  # The documented contract covers the mirror case too: an iterator `y` with
  # an in-memory `x` cannot be fed by any of the feeders.
  if y is not None and _is_iterable(y):
    raise ValueError('Both x and y should be iterators for '
                     'streaming learning to work.')

  data_feeder_cls = DataFeeder
  if HAS_DASK:
    # pylint: disable=g-import-not-at-top
    import dask.dataframe as dd
    dask_types = (dd.Series, dd.DataFrame)
    if isinstance(x, dask_types) and (y is None or isinstance(y, dask_types)):
      data_feeder_cls = DaskDataFeeder

  return data_feeder_cls(
      x, y, n_classes, batch_size, shuffle=shuffle, epochs=epochs)
160
161
162def _batch_data(x, batch_size=None):
163  if (batch_size is not None) and (batch_size <= 0):
164    raise ValueError('Invalid batch_size %d.' % batch_size)
165
166  x_first_el = six.next(x)
167  x = itertools.chain([x_first_el], x)
168
169  chunk = dict([(k, []) for k in list(x_first_el.keys())]) if isinstance(
170      x_first_el, dict) else []
171  chunk_filled = False
172  for data in x:
173    if isinstance(data, dict):
174      for k, v in list(data.items()):
175        chunk[k].append(v)
176        if (batch_size is not None) and (len(chunk[k]) >= batch_size):
177          chunk[k] = np.matrix(chunk[k])
178          chunk_filled = True
179      if chunk_filled:
180        yield chunk
181        chunk = dict([(k, []) for k in list(x_first_el.keys())]) if isinstance(
182            x_first_el, dict) else []
183        chunk_filled = False
184    else:
185      chunk.append(data)
186      if (batch_size is not None) and (len(chunk) >= batch_size):
187        yield np.matrix(chunk)
188        chunk = []
189
190  if isinstance(x_first_el, dict):
191    for k, v in list(data.items()):
192      chunk[k] = np.matrix(chunk[k])
193    yield chunk
194  else:
195    yield np.matrix(chunk)
196
197
@deprecated(None, 'Please use tensorflow/transform or tf.data.')
def setup_predict_data_feeder(x, batch_size=None):
  """Returns an iterable for feeding into predict step.

  Args:
    x: numpy, pandas, Dask array or dictionary of aforementioned. Also supports
      iterable.
    batch_size: Size of batches to split data into. If `None`, returns one
      batch of full size.

  Returns:
    List or iterator (or dictionary thereof) of parts of data to predict on.

  Raises:
    ValueError: if `batch_size` <= 0.
  """
  # Validate up front: for iterator input `_batch_data` is a generator and
  # would otherwise only raise once the caller starts iterating.
  if batch_size is not None and batch_size <= 0:
    raise ValueError('Invalid batch_size %d.' % batch_size)
  if HAS_DASK:
    x = extract_dask_data(x)
  if HAS_PANDAS:
    x = extract_pandas_data(x)
  if _is_iterable(x):
    return _batch_data(x, batch_size)
  # Rank-1 input becomes a single-feature column.
  if len(x.shape) == 1:
    x = np.reshape(x, (-1, 1))
  if batch_size is None:
    return [x]
  return [x[start:start + batch_size]
          for start in xrange(0, len(x), batch_size)]
228
229
@deprecated(None, 'Please use tensorflow/transform or tf.data.')
def setup_processor_data_feeder(x):
  """Sets up processor iterable.

  When pandas is available, `x` is passed through `extract_pandas_matrix`;
  otherwise it is returned unchanged.

  Args:
    x: numpy, pandas or iterable.

  Returns:
    Iterable of data to process.
  """
  return extract_pandas_matrix(x) if HAS_PANDAS else x
243
244
@deprecated(None, 'Please convert numpy dtypes explicitly.')
def check_array(array, dtype):
  """Checks array on dtype and converts it if different.

  Only `np.ndarray` and `list` inputs are converted. Other array-likes (e.g.
  `h5py.Dataset`) are returned untouched, to avoid copying them and loading
  the whole dataset into memory.

  Args:
    array: Input array.
    dtype: Expected dtype.

  Returns:
    Original array or converted `np.ndarray`. Data is copied only when a
    conversion is required.
  """
  if isinstance(array, (np.ndarray, list)):
    # `np.asarray` copies only when needed. NumPy 1.x's
    # `np.array(..., copy=False)` meant the same thing, but since NumPy 2.0
    # it raises ValueError whenever a copy is required (always for lists,
    # and for any dtype change).
    array = np.asarray(array, dtype=dtype)
  return array
261
262
def _access(data, iloc):
  """Accesses an element from collection, using integer location based indexing.

  pandas `Series`/`DataFrame` objects are indexed through `.iloc`, because
  their `[]` operator does not select rows by position. Everything else is
  indexed with plain `data[iloc]`.

  Args:
    data: array-like. The collection to access.
    iloc: `int` or `list` of `int`s. Location(s) to access in `data`.

  Returns:
    The element of `data` found at location(s) `iloc`.
  """
  indexer = data
  if HAS_PANDAS:
    import pandas as pd  # pylint: disable=g-import-not-at-top
    if isinstance(data, (pd.Series, pd.DataFrame)):
      indexer = data.iloc
  return indexer[iloc]
278
279
def _check_dtype(dtype):
  """Returns `dtype` unchanged, logging a warning when it is float64.

  Args:
    dtype: A value accepted by `dtypes.as_dtype` (e.g. a numpy dtype).

  Returns:
    The same `dtype` object that was passed in.
  """
  if dtypes.as_dtype(dtype) == dtypes.float64:
    # `warning` instead of the deprecated `warn` alias.
    logging.warning(
        'float64 is not supported by many models, consider casting to float32.')
  return dtype
285
286
class DataFeeder(object):
  """Data feeder is an example class to sample data for TF trainer.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Samples mini-batches of rows (in shuffled or sequential order) from
  in-memory `x`/`y` and returns them as feed dicts keyed by the names of the
  placeholders created by `input_builder` or supplied via `set_placeholders`.
  """

  @deprecated(None, 'Please use tensorflow/transform or tf.data.')
  def __init__(self,
               x,
               y,
               n_classes,
               batch_size=None,
               shuffle=True,
               random_state=None,
               epochs=None):
    """Initializes a DataFeeder instance.

    Args:
      x: One feature sample which can either Nd numpy matrix of shape
        `[n_samples, n_features, ...]` or dictionary of Nd numpy matrix.
      y: label vector, either floats for regression or class id for
        classification. If matrix, will consider as a sequence of labels.
        Can be `None` for unsupervised setting. Also supports dictionary of
        labels.
      n_classes: Number of classes, 0 and 1 are considered regression, `None`
        will pass through the input labels without one-hot conversion. Also, if
        `y` is `dict`, then `n_classes` must be `dict` such that
        `n_classes[key] = n_classes for label y[key]`, `None` otherwise.
      batch_size: Mini-batch size to accumulate samples in one mini batch.
      shuffle: Whether to shuffle `x`.
      random_state: Numpy `RandomState` object to reproduce sampling.
      epochs: Number of times to iterate over input data before raising
        `StopIteration` exception.

    Attributes:
      x: Input features (ndarray or dictionary of ndarrays).
      y: Input label (ndarray or dictionary of ndarrays).
      n_classes: Number of classes (if `None`, pass through indices without
        one-hot conversion).
      batch_size: Mini-batch size to accumulate.
      input_shape: Shape of the input (or dictionary of shapes).
      output_shape: Shape of the output (or dictionary of shapes).
      input_dtype: DType of input (or dictionary of dtypes).
      output_dtype: DType of output (or dictionary of dtypes).
    """
    x_is_dict = isinstance(x, dict)
    y_is_dict = y is not None and isinstance(y, dict)
    if isinstance(y, list):
      y = np.array(y)

    if x_is_dict:
      self._x = dict((k, check_array(v, v.dtype)) for k, v in x.items())
    else:
      self._x = check_array(x, x.dtype)
    if y is None:
      self._y = None
    elif y_is_dict:
      self._y = dict((k, check_array(v, v.dtype)) for k, v in y.items())
    else:
      self._y = check_array(y, y.dtype)

    # A non-None n_classes means raw target indices are converted to one-hot;
    # non-dict labels are normalized to an integer (classification) or
    # float32 (regression) dtype.
    if n_classes is not None and not y_is_dict:
      y_dtype = np.int64 if n_classes > 1 else np.float32
      self._y = None if y is None else check_array(y, dtype=y_dtype)

    self.n_classes = n_classes
    self.max_epochs = epochs

    if x_is_dict:
      x_shape = dict((k, v.shape) for k, v in self._x.items())
    else:
      x_shape = self._x.shape
    if y is None:
      y_shape = None
    elif y_is_dict:
      y_shape = dict((k, v.shape) for k, v in self._y.items())
    else:
      y_shape = self._y.shape

    self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
        x_shape, y_shape, n_classes, batch_size)

    # Input dtype matches dtype of x.
    if x_is_dict:
      self._input_dtype = dict(
          (k, _check_dtype(v.dtype)) for k, v in self._x.items())
    else:
      self._input_dtype = _check_dtype(self._x.dtype)

    # Output dtype matches dtype of y; np.float32 when there are no labels.
    if y_is_dict:
      self._output_dtype = dict(
          (k, _check_dtype(v.dtype)) for k, v in self._y.items())
    elif y is not None:
      self._output_dtype = _check_dtype(self._y.dtype)
    else:
      self._output_dtype = np.float32

    # Dict labels that have an n_classes entry are fed as float32.
    if n_classes is not None and y_is_dict:
      for key in list(n_classes.keys()):
        if key in self._output_dtype:
          self._output_dtype[key] = np.float32

    self._shuffle = shuffle
    self.random_state = (
        np.random.RandomState(42) if random_state is None else random_state)

    if x_is_dict:
      num_samples = list(self._x.values())[0].shape[0]
    elif tensor_util.is_tensor(self._x):
      # shape will be a Dimension, extract an int
      num_samples = self._x.shape[0].value
    else:
      num_samples = self._x.shape[0]

    # np.arange (unlike np.array(range(0))) stays an integer array when there
    # are no samples, so it can always be used for indexing.
    if self._shuffle:
      self.indices = self.random_state.permutation(num_samples)
    else:
      self.indices = np.arange(num_samples)
    self.offset = 0
    self.epoch = 0
    self._epoch_placeholder = None
    # Set by input_builder() or set_placeholders(). Initialized here so that
    # the feed function reports a missing input placeholder through its
    # assertion instead of an AttributeError.
    self._input_placeholder = None
    self._output_placeholder = None

  @property
  def x(self):
    return self._x

  @property
  def y(self):
    return self._y

  @property
  def shuffle(self):
    return self._shuffle

  @property
  def input_dtype(self):
    return self._input_dtype

  @property
  def output_dtype(self):
    return self._output_dtype

  @property
  def batch_size(self):
    return self._batch_size

  def make_epoch_variable(self):
    """Adds a placeholder variable for the epoch to the graph.

    Returns:
      The epoch placeholder.
    """
    self._epoch_placeholder = array_ops.placeholder(
        dtypes.int32, [1], name='epoch')
    return self._epoch_placeholder

  def input_builder(self):
    """Builds inputs in the graph.

    Returns:
      Two placeholders for inputs and outputs.
    """

    def get_placeholder(shape, dtype, name_prepend):
      # The batch dimension is left unknown (None) so partial batches fit.
      if shape is None:
        return None
      if isinstance(shape, dict):
        placeholder = {}
        for key in list(shape.keys()):
          placeholder[key] = array_ops.placeholder(
              dtypes.as_dtype(dtype[key]), [None] + shape[key][1:],
              name=name_prepend + '_' + key)
      else:
        placeholder = array_ops.placeholder(
            dtypes.as_dtype(dtype), [None] + shape[1:], name=name_prepend)
      return placeholder

    self._input_placeholder = get_placeholder(self.input_shape,
                                              self._input_dtype, 'input')
    self._output_placeholder = get_placeholder(self.output_shape,
                                               self._output_dtype, 'output')
    return self._input_placeholder, self._output_placeholder

  def set_placeholders(self, input_placeholder, output_placeholder):
    """Sets placeholders for this data feeder.

    Args:
      input_placeholder: Placeholder for `x` variable. Should match shape
        of the examples in the x dataset.
      output_placeholder: Placeholder for `y` variable. Should match
        shape of the examples in the y dataset. Can be `None`.
    """
    self._input_placeholder = input_placeholder
    self._output_placeholder = output_placeholder

  def get_feed_params(self):
    """Function returns a `dict` with data feed params while training.

    Returns:
      A `dict` with data feed params while training.
    """
    return {
        'epoch': self.epoch,
        'offset': self.offset,
        'batch_size': self._batch_size
    }

  def get_feed_dict_fn(self):
    """Returns a function that samples data into given placeholders.

    Returns:
      A function that when called samples a random subset of batch size
      from `x` and `y`.
    """
    x_is_dict = isinstance(self._x, dict)
    y_is_dict = self._y is not None and isinstance(self._y, dict)

    # Assign input features from random indices.
    def extract(data, indices):
      # Rank-1 data is reshaped to a single-feature column.
      if len(data.shape) == 1:
        return np.array(_access(data, indices)).reshape((indices.shape[0], 1))
      return _access(data, indices)

    # assign labels from random indices
    def assign_label(data, shape, dtype, n_classes, indices):
      # Copy the shape: it is self.output_shape (or one of its entries), and
      # writing the batch size into it would clobber the feeder's attribute.
      shape = list(shape)
      shape[0] = indices.shape[0]
      out = np.zeros(shape, dtype=dtype)
      for i in xrange(out.shape[0]):
        sample = indices[i]
        # n_classes of None (raw target indices) or <= 1 (regression) copies
        # the label through unchanged.
        if n_classes is None or n_classes <= 1:
          out[i] = _access(data, sample)
        elif len(shape) == 2:
          # Direct indexing: ndarray.itemset was removed in NumPy 2.0.
          out[i, int(_access(data, sample))] = 1.0
        else:
          # Sequence labels: one-hot encode each position of the sequence.
          for idx, value in enumerate(_access(data, sample)):
            out[i, idx, value] = 1.0
      return out

    def _feed_dict_fn():
      """Function that samples data into given placeholders."""
      if self.max_epochs is not None and self.epoch + 1 > self.max_epochs:
        raise StopIteration
      assert self._input_placeholder is not None
      feed_dict = {}
      if self._epoch_placeholder is not None:
        feed_dict[self._epoch_placeholder.name] = [self.epoch]

      # Take next batch of indices.
      if x_is_dict:
        x_len = list(self._x.values())[0].shape[0]
      else:
        x_len = self._x.shape[0]
      end = min(x_len, self.offset + self._batch_size)
      batch_indices = self.indices[self.offset:end]

      # adding input placeholder
      if x_is_dict:
        feed_dict.update(
            dict((self._input_placeholder[k].name, extract(v, batch_indices))
                 for k, v in self._x.items()))
      else:
        feed_dict[self._input_placeholder.name] = extract(
            self._x, batch_indices)

      # Move the offset; after the last batch start a new epoch with fresh
      # (reshuffled when shuffling) indices.
      self.offset += self._batch_size
      if self.offset >= x_len:
        if self._shuffle:
          self.indices = self.random_state.permutation(x_len)
        else:
          self.indices = np.arange(x_len)
        self.offset = 0
        self.epoch += 1

      # return early if there are no labels
      if self._output_placeholder is None:
        return feed_dict

      # adding output placeholders
      if y_is_dict:
        for k, v in self._y.items():
          n_classes = (self.n_classes[k] if k in self.n_classes else
                       None) if self.n_classes is not None else None
          shape, dtype = self.output_shape[k], self._output_dtype[k]
          feed_dict[self._output_placeholder[k].name] = assign_label(
              v, shape, dtype, n_classes, batch_indices)
      else:
        feed_dict[self._output_placeholder.name] = assign_label(
            self._y, self.output_shape, self._output_dtype, self.n_classes,
            batch_indices)

      return feed_dict

    return _feed_dict_fn
577
578
class StreamingDataFeeder(DataFeeder):
  """Data feeder for TF trainer that reads data from iterator.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Streaming data feeder allows reading data as it comes in from disk or
  somewhere else. It is common to have these iterators rotate infinitely over
  the dataset, to allow control of how much to learn on the trainer side.
  """

  def __init__(self, x, y, n_classes, batch_size):
    """Initializes a StreamingDataFeeder instance.

    Args:
      x: iterator each element of which returns one feature sample. Sample can
        be a Nd numpy matrix or dictionary of Nd numpy matrices.
      y: iterator each element of which returns one label sample. Sample can be
        a Nd numpy matrix or dictionary of Nd numpy matrices with 1 or many
        classes regression values.
      n_classes: indicator of how many classes the corresponding label sample
        has for the purposes of one-hot conversion of label. In case where `y`
        is a dictionary, `n_classes` must be dictionary (with same keys as `y`)
        of how many classes there are in each label in `y`. If key is
        present in `y` and missing in `n_classes`, the value is assumed `None`
        and no one-hot conversion will be applied to the label with that key.
      batch_size: Mini batch size to accumulate samples in one batch. If set
        `None`, then assumes that iterator to return already batched element.

    Attributes:
      x: input features (or dictionary of input features).
      y: input label (or dictionary of output features).
      n_classes: number of classes.
      batch_size: mini batch size to accumulate.
      input_shape: shape of the input (can be dictionary depending on `x`).
      output_shape: shape of the output (can be dictionary depending on `y`).
      input_dtype: dtype of input (can be dictionary depending on `x`).
      output_dtype: dtype of output (can be dictionary depending on `y`).
    """
    # pylint: disable=invalid-name,super-init-not-called
    # Peek at the first element to infer shapes and dtypes, then chain it back
    # in front of the remaining stream so no sample is lost.
    x_first_el = six.next(x)
    self._x = itertools.chain([x_first_el], x)
    if y is not None:
      y_first_el = six.next(y)
      self._y = itertools.chain([y_first_el], y)
    else:
      y_first_el = None
      self._y = None
    self.n_classes = n_classes

    x_is_dict = isinstance(x_first_el, dict)
    y_is_dict = y is not None and isinstance(y_first_el, dict)
    if y_is_dict and n_classes is not None:
      assert isinstance(n_classes, dict)

    # extract shapes for first_elements; a leading 1 stands in for the
    # sample dimension expected by _get_in_out_shape.
    if x_is_dict:
      x_first_el_shape = dict(
          (k, [1] + list(v.shape)) for k, v in x_first_el.items())
    else:
      x_first_el_shape = [1] + list(x_first_el.shape)

    if y_is_dict:
      y_first_el_shape = dict(
          (k, [1] + list(v.shape)) for k, v in y_first_el.items())
    elif y is None:
      y_first_el_shape = None
    else:
      y_first_el_shape = (
          [1] + list(y_first_el[0].shape
                     if isinstance(y_first_el, list) else y_first_el.shape))

    self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
        x_first_el_shape, y_first_el_shape, n_classes, batch_size)

    # Input dtype of x_first_el.
    if x_is_dict:
      self._input_dtype = dict(
          (k, _check_dtype(v.dtype)) for k, v in x_first_el.items())
    else:
      self._input_dtype = _check_dtype(x_first_el.dtype)

    # Output dtype of y_first_el.
    def check_y_dtype(el):
      # ndarray -> its dtype; list -> dtype of its first element; anything
      # else -> numpy dtype of its Python type.
      if isinstance(el, np.ndarray):
        return el.dtype
      elif isinstance(el, list):
        return check_y_dtype(el[0])
      else:
        return _check_dtype(np.dtype(type(el)))

    # Output types are floats, due to both softmaxes and regression req.
    if n_classes is not None and (y is None or not y_is_dict) and n_classes > 0:
      self._output_dtype = np.float32
    elif y_is_dict:
      self._output_dtype = dict(
          (k, check_y_dtype(v)) for k, v in y_first_el.items())
    elif y is None:
      self._output_dtype = None
    else:
      self._output_dtype = check_y_dtype(y_first_el)

  def get_feed_params(self):
    """Function returns a `dict` with data feed params while training.

    Returns:
      A `dict` with data feed params while training.
    """
    return {'batch_size': self._batch_size}

  def get_feed_dict_fn(self):
    """Returns a function, that will sample data and provide it to placeholders.

    Returns:
      A function that when called samples a random subset of batch size
      from x and y.
    """
    self.stopped = False

    def _feed_dict_fn():
      """Samples data and provides it to placeholders.

      Returns:
        `dict` of input and output tensors.
      """

      def init_array(shape, dtype):
        """Initialize array of given shape or dict of shapes and dtype."""
        if shape is None:
          return None
        elif isinstance(shape, dict):
          return dict((k, np.zeros(shape[k], dtype[k])) for k in shape.keys())
        else:
          return np.zeros(shape, dtype=dtype)

      def put_data_array(dest, index, source=None, n_classes=None):
        """Puts one sample into row `index` of `dest`.

        With `source=None` the array is truncated to its first `index` rows
        (used when the input stream ends mid-batch).
        """
        if source is None:
          return dest[:index]
        if n_classes is not None and n_classes > 1:
          # Use the rank of this destination array: self.output_shape is a
          # dict of shapes (not a shape) when the labels are dicts. Direct
          # indexing replaces ndarray.itemset, removed in NumPy 2.0.
          if len(dest.shape) == 2:
            dest[index, source] = 1.0
          else:
            for idx, value in enumerate(source):
              dest[index, idx, value] = 1.0
        elif len(dest.shape) > 1:
          dest[index, :] = source
        else:
          dest[index] = source[0] if isinstance(source, list) else source
        return dest

      def put_data_array_or_dict(holder, index, data=None, n_classes=None):
        """Puts data array or data dictionary into container."""
        if holder is None:
          return None
        if isinstance(holder, dict):
          if data is None:
            data = {k: None for k in holder.keys()}
          assert isinstance(data, dict)
          for k in holder.keys():
            num_classes = n_classes[k] if (n_classes is not None and
                                           k in n_classes) else None
            holder[k] = put_data_array(holder[k], index, data[k], num_classes)
        else:
          holder = put_data_array(holder, index, data, n_classes)
        return holder

      if self.stopped:
        raise StopIteration

      inp = init_array(self.input_shape, self._input_dtype)
      out = init_array(self.output_shape, self._output_dtype)

      for i in xrange(self._batch_size):
        # Add handling when queue ends.
        try:
          next_inp = six.next(self._x)
          inp = put_data_array_or_dict(inp, i, next_inp, None)
        except StopIteration:
          self.stopped = True
          if i == 0:
            raise
          # Return the partial batch collected so far.
          inp = put_data_array_or_dict(inp, i, None, None)
          out = put_data_array_or_dict(out, i, None, None)
          break

        if self._y is not None:
          next_out = six.next(self._y)
          out = put_data_array_or_dict(out, i, next_out, self.n_classes)

      # creating feed_dict
      if isinstance(inp, dict):
        feed_dict = dict((self._input_placeholder[k].name, inp[k])
                         for k in self._input_placeholder.keys())
      else:
        feed_dict = {self._input_placeholder.name: inp}
      if self._y is not None:
        if isinstance(out, dict):
          feed_dict.update(
              dict((self._output_placeholder[k].name, out[k])
                   for k in self._output_placeholder.keys()))
        else:
          feed_dict.update({self._output_placeholder.name: out})

      return feed_dict

    return _feed_dict_fn
789
790
class DaskDataFeeder(object):
  """Data feeder that reads data from dask.Series and dask.DataFrame.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Numpy arrays can be serialized to disk and it's possible to do random seeks
  into them. DaskDataFeeder will remove requirement to have full dataset in the
  memory and still do random seeks for sampling of batches.
  """

  @deprecated(None, 'Please feed input to tf.data to support dask.')
  def __init__(self,
               x,
               y,
               n_classes,
               batch_size,
               shuffle=True,
               random_state=None,
               epochs=None):
    """Initializes a DaskDataFeeder instance.

    Args:
      x: iterator that returns for each element, returns features.
      y: iterator that returns for each element, returns 1 or many classes /
        regression values.
      n_classes: indicator of how many classes the label has.
      batch_size: Mini batch size to accumulate.
      shuffle: Whether to shuffle the inputs.
      random_state: random state for RNG. Note that it will mutate so use a
        int value for this if you want consistent sized batches.
      epochs: Number of epochs to run.

    Attributes:
      x: input features.
      y: input label.
      n_classes: number of classes.
      batch_size: mini batch size to accumulate.
      input_shape: shape of the input.
      output_shape: shape of the output.
      input_dtype: dtype of input.
      output_dtype: dtype of output.

    Raises:
      ValueError: if `x` or `y` are `dict`, as they are not supported currently.
    """

    if isinstance(x, dict) or isinstance(y, dict):
      raise ValueError(
          'DaskDataFeeder does not support dictionaries at the moment.')

    # pylint: disable=invalid-name,super-init-not-called
    import dask.dataframe as dd  # pylint: disable=g-import-not-at-top
    # TODO(terrytangyuan): check x and y dtypes in dask_io like pandas
    self._x = x
    self._y = y
    # save column names
    self._x_columns = list(x.columns)
    # six.string_types so Python 2 unicode column names count as names too.
    if isinstance(y.columns[0], six.string_types):
      self._y_columns = list(y.columns)
    else:
      # deal with cases where two DFs have overlapped default numeric colnames
      self._y_columns = len(self._x_columns) + 1
      self._y = self._y.rename(columns={y.columns[0]: self._y_columns})

    # TODO(terrytangyuan): deal with unsupervised cases
    # combine into a data frame
    self.df = dd.multi.concat([self._x, self._y], axis=1)
    self.n_classes = n_classes

    # `.iloc` for positional access: plain `[0]` on the computed pandas Series
    # is a label lookup when the column labels are integers.
    x_count = x.count().compute().iloc[0]
    x_shape = (x_count, len(self._x.columns))
    y_shape = (x_count, len(self._y.columns))
    # TODO(terrytangyuan): Add support for shuffle and epochs.
    self._shuffle = shuffle
    self.epochs = epochs
    self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
        x_shape, y_shape, n_classes, batch_size)
    self.sample_fraction = self._batch_size / float(x_count)
    self._input_dtype = _check_dtype(self._x.dtypes.iloc[0])
    # Labels are assumed to be a single column; select its dtype by position.
    # Indexing `dtypes` with `self._y_columns` returns a Series rather than a
    # dtype when `_y_columns` is a list of names.
    self._output_dtype = _check_dtype(self._y.dtypes.iloc[0])
    if random_state is None:
      self.random_state = 66
    else:
      self.random_state = random_state

  def get_feed_params(self):
    """Function returns a `dict` with data feed params while training.

    Returns:
      A `dict` with data feed params while training.
    """
    return {'batch_size': self._batch_size}

  def get_feed_dict_fn(self, input_placeholder, output_placeholder):
    """Returns a function, that will sample data and provide it to placeholders.

    Args:
      input_placeholder: tf.placeholder for input features mini batch.
      output_placeholder: tf.placeholder for output labels.

    Returns:
      A function that when called samples a random subset of batch size
      from x and y.
    """
    # The one-hot width depends only on the label maximum, which is the same
    # for every batch: compute it once instead of on every call.
    out_max = self._y.max().compute().values[0]

    def _feed_dict_fn():
      """Samples data and provides it to placeholders."""
      # TODO(ipolosukhin): option for with/without replacement (dev version of
      # dask)
      # NOTE(review): with an int random_state every call presumably draws the
      # same split (hence the same batch); confirm against dask's
      # random_split and consider passing an np.random.RandomState instead.
      sample = self.df.random_split(
          [self.sample_fraction, 1 - self.sample_fraction],
          random_state=self.random_state)
      inp = extract_pandas_matrix(sample[0][self._x_columns].compute()).tolist()
      out = extract_pandas_matrix(sample[0][self._y_columns].compute())
      # convert to correct dtype
      inp = np.array(inp, dtype=self._input_dtype)
      # one-hot encode out for each class for cross entropy loss
      if HAS_PANDAS:
        import pandas as pd  # pylint: disable=g-import-not-at-top
        if not isinstance(out, pd.Series):
          out = out.flatten()
      encoded_out = np.zeros((out.size, out_max + 1), dtype=self._output_dtype)
      encoded_out[np.arange(out.size), out] = 1
      return {input_placeholder.name: inp, output_placeholder.name: encoded_out}

    return _feed_dict_fn
920