• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test data utilities (deprecated).
16
17This module and all its submodules are deprecated. See
18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
19for migration instructions.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import numpy as np
27from tensorflow.contrib.learn.python.learn.datasets import base
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30
31
def get_quantile_based_buckets(feature_values, num_buckets):
  """Computes `num_buckets` evenly spaced interior quantiles of the values.

  The requested percentiles are 100 * k / (num_buckets + 1) for
  k = 1, ..., num_buckets, so the 0th and 100th percentiles are never part
  of the result.

  Args:
    feature_values: array-like of numeric values (converted with `np.array`).
    num_buckets: number of quantile boundaries to compute.

  Returns:
    A Python list holding one quantile per requested percentile, in
    increasing percentile order.
  """
  values = np.array(feature_values)
  percentile_points = [
      100. * k / (num_buckets + 1.) for k in range(1, num_buckets + 1)
  ]
  boundaries = np.percentile(values, percentile_points)
  return list(boundaries)
37
38
def prepare_iris_data_for_logistic_regression():
  """Converts the iris data into a two-class logistic regression problem.

  Only the examples whose target is 0 or 1 are kept; examples of any other
  class are dropped.

  Returns:
    A `base.Dataset` whose `data` and `target` hold only the selected rows
    of `base.load_iris()`.
  """
  full_iris = base.load_iris()
  labels = full_iris.target
  keep = (labels == 0) | (labels == 1)
  # np.nonzero(cond) is what single-argument np.where(cond) returns.
  row_indices = np.nonzero(keep)
  return base.Dataset(
      data=full_iris.data[row_indices], target=labels[row_indices])
44
45
def iris_input_multiclass_fn():
  """Input function that feeds the full iris dataset as constant tensors.

  Returns:
    A `(features, labels)` tuple. `features` is a dict mapping `'feature'`
    to a float32 constant built from `iris.data`. `labels` is an int32
    constant built from `iris.target` with shape `(num_examples, 1)`.
  """
  iris = base.load_iris()
  # Derive the example count from the loaded data instead of hard-coding
  # 150, so the label shape always matches the data actually loaded.
  num_examples = len(iris.target)
  features = {
      'feature': constant_op.constant(iris.data, dtype=dtypes.float32)
  }
  labels = constant_op.constant(
      iris.target, shape=(num_examples, 1), dtype=dtypes.int32)
  return features, labels
53
54
def iris_input_logistic_fn():
  """Input function that feeds the two-class iris subset as constant tensors.

  Uses `prepare_iris_data_for_logistic_regression()`, which keeps only the
  examples with target 0 or 1.

  Returns:
    A `(features, labels)` tuple. `features` is a dict mapping `'feature'`
    to a float32 constant built from the filtered data. `labels` is an int32
    constant built from the filtered targets with shape `(num_examples, 1)`.
  """
  iris = prepare_iris_data_for_logistic_regression()
  # The row count depends on the filtering done in another function, so
  # derive it from the data rather than hard-coding 100.
  num_examples = len(iris.target)
  features = {
      'feature': constant_op.constant(iris.data, dtype=dtypes.float32)
  }
  labels = constant_op.constant(
      iris.target, shape=(num_examples, 1), dtype=dtypes.int32)
  return features, labels
62