1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Methods to allow dask.DataFrame (deprecated). 17 18This module and all its submodules are deprecated. See 19[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) 20for migration instructions. 21""" 22 23from __future__ import absolute_import 24from __future__ import division 25from __future__ import print_function 26 27import numpy as np 28 29from tensorflow.python.util.deprecation import deprecated 30 31try: 32 # pylint: disable=g-import-not-at-top 33 import dask.dataframe as dd 34 allowed_classes = (dd.Series, dd.DataFrame) 35 HAS_DASK = True 36except ImportError: 37 HAS_DASK = False 38 39 40def _add_to_index(df, start): 41 """New dask.dataframe with values added to index of each subdataframe.""" 42 df = df.copy() 43 df.index += start 44 return df 45 46 47def _get_divisions(df): 48 """Number of rows in each sub-dataframe.""" 49 lengths = df.map_partitions(len).compute() 50 divisions = np.cumsum(lengths).tolist() 51 divisions.insert(0, 0) 52 return divisions 53 54 55def _construct_dask_df_with_divisions(df): 56 """Construct the new task graph and make a new dask.dataframe around it.""" 57 divisions = _get_divisions(df) 58 # pylint: disable=protected-access 59 name = 'csv-index' + df._name 60 dsk = {(name, i): (_add_to_index, (df._name, i), divisions[i]) 61 for i in range(df.npartitions)} 62 # pylint: enable=protected-access 63 from toolz import merge # pylint: disable=g-import-not-at-top 64 if isinstance(df, dd.DataFrame): 65 return dd.DataFrame(merge(dsk, df.dask), name, df.columns, divisions) 66 elif isinstance(df, dd.Series): 67 return dd.Series(merge(dsk, df.dask), name, df.name, divisions) 68 69 70@deprecated(None, 'Please feed input to tf.data to support dask.') 71def extract_dask_data(data): 72 """Extract data from dask.Series or dask.DataFrame for predictors. 73 74 Given a distributed dask.DataFrame or dask.Series containing columns or names 75 for one or more predictors, this operation returns a single dask.DataFrame or 76 dask.Series that can be iterated over. 77 78 Args: 79 data: A distributed dask.DataFrame or dask.Series. 80 81 Returns: 82 A dask.DataFrame or dask.Series that can be iterated over. 83 If the supplied argument is neither a dask.DataFrame nor a dask.Series this 84 operation returns it without modification. 85 """ 86 if isinstance(data, allowed_classes): 87 return _construct_dask_df_with_divisions(data) 88 else: 89 return data 90 91 92@deprecated(None, 'Please feed input to tf.data to support dask.') 93def extract_dask_labels(labels): 94 """Extract data from dask.Series or dask.DataFrame for labels. 95 96 Given a distributed dask.DataFrame or dask.Series containing exactly one 97 column or name, this operation returns a single dask.DataFrame or dask.Series 98 that can be iterated over. 99 100 Args: 101 labels: A distributed dask.DataFrame or dask.Series with exactly one 102 column or name. 103 104 Returns: 105 A dask.DataFrame or dask.Series that can be iterated over. 106 If the supplied argument is neither a dask.DataFrame nor a dask.Series this 107 operation returns it without modification. 108 109 Raises: 110 ValueError: If the supplied dask.DataFrame contains more than one 111 column or the supplied dask.Series contains more than 112 one name. 113 """ 114 if isinstance(labels, dd.DataFrame): 115 ncol = labels.columns 116 elif isinstance(labels, dd.Series): 117 ncol = labels.name 118 if isinstance(labels, allowed_classes): 119 if len(ncol) > 1: 120 raise ValueError('Only one column for labels is allowed.') 121 return _construct_dask_df_with_divisions(labels) 122 else: 123 return labels 124