# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.histogram_ops."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import histogram_ops
from tensorflow.python.platform import test


class BinValuesFixedWidth(test.TestCase):
  """Tests for `histogram_ops.histogram_fixed_width_bins`.

  Each test checks the per-element bin index returned for `values`. Bin
  indices are integers, so results are compared exactly with
  `assertAllEqual` rather than with a tolerance.
  """

  def test_empty_input_gives_all_zero_counts(self):
    """Empty `values` yields an empty int32 bin-index tensor."""
    # Bins will be:
    #   (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
    value_range = [0.0, 5.0]
    values = []
    expected_bins = []
    with self.cached_session():
      bins = histogram_ops.histogram_fixed_width_bins(
          values, value_range, nbins=5)
      self.assertEqual(dtypes.int32, bins.dtype)
      self.assertAllEqual(expected_bins, self.evaluate(bins))

  def test_1d_values_int32_output(self):
    """Out-of-range values are clipped into the first and last bins."""
    # Bins will be:
    #   (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
    value_range = [0.0, 5.0]
    values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
    expected_bins = [0, 0, 1, 2, 4, 4]
    with self.cached_session():
      # dtype=int64 is requested here, yet the result is asserted to be
      # int32: this pins that the `dtype` argument does not change the
      # output dtype of histogram_fixed_width_bins.
      bins = histogram_ops.histogram_fixed_width_bins(
          values, value_range, nbins=5, dtype=dtypes.int64)
      self.assertEqual(dtypes.int32, bins.dtype)
      self.assertAllEqual(expected_bins, self.evaluate(bins))

  def test_1d_float64_values_int32_output(self):
    """float64 inputs still produce int32 bin indices."""
    # Bins will be:
    #   (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
    value_range = np.float64([0.0, 5.0])
    values = np.float64([-1.0, 0.0, 1.5, 2.0, 5.0, 15])
    expected_bins = [0, 0, 1, 2, 4, 4]
    with self.cached_session():
      bins = histogram_ops.histogram_fixed_width_bins(
          values, value_range, nbins=5)
      self.assertEqual(dtypes.int32, bins.dtype)
      self.assertAllEqual(expected_bins, self.evaluate(bins))

  def test_2d_values(self):
    """The output bin indices keep the (2, 3) shape of the input."""
    # Bins will be:
    #   (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
    value_range = [0.0, 5.0]
    values = constant_op.constant(
        [[-1.0, 0.0, 1.5], [2.0, 5.0, 15]], shape=(2, 3))
    expected_bins = [[0, 0, 1], [2, 4, 4]]
    with self.cached_session():
      bins = histogram_ops.histogram_fixed_width_bins(
          values, value_range, nbins=5)
      self.assertEqual(dtypes.int32, bins.dtype)
      self.assertAllEqual(expected_bins, self.evaluate(bins))

  def test_negative_nbins(self):
    """A negative `nbins` raises an error mentioning "must > 0"."""
    value_range = [0.0, 5.0]
    values = []
    # The error may surface at graph-construction time (ValueError) or at
    # evaluation time (InvalidArgumentError), so both are accepted.
    with self.assertRaisesRegex((errors.InvalidArgumentError, ValueError),
                                "must > 0"):
      with self.session():
        bins = histogram_ops.histogram_fixed_width_bins(
            values, value_range, nbins=-1)
        self.evaluate(bins)


class HistogramFixedWidthTest(test.TestCase):
  """Tests for `histogram_ops.histogram_fixed_width`.

  Each test checks the per-bin counts. Counts are integers, so results are
  compared exactly with `assertAllEqual` rather than with a tolerance.
  """

  def setUp(self):
    self.rng = np.random.RandomState(0)

  @test_util.run_deprecated_v1
  def test_with_invalid_value_range(self):
    """`value_range` must be a rank-1 tensor of length 2."""
    values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
    with self.assertRaisesRegex(ValueError,
                                "Shape must be rank 1 but is rank 0"):
      histogram_ops.histogram_fixed_width(values, 1.0)
    with self.assertRaisesRegex(ValueError, "Dimension must be 2 but is 3"):
      histogram_ops.histogram_fixed_width(values, [1.0, 2.0, 3.0])

  @test_util.run_deprecated_v1
  def test_with_invalid_nbins(self):
    """`nbins` must be a positive scalar."""
    values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
    with self.assertRaisesRegex(ValueError,
                                "Shape must be rank 0 but is rank 1"):
      histogram_ops.histogram_fixed_width(values, [1.0, 5.0], nbins=[1, 2])
    with self.assertRaisesRegex(ValueError, "Requires nbins > 0"):
      histogram_ops.histogram_fixed_width(values, [1.0, 5.0], nbins=-5)

  def test_empty_input_gives_all_zero_counts(self):
    """Empty `values` yields `nbins` zero counts of dtype int32."""
    # Bins will be:
    #   (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
    value_range = [0.0, 5.0]
    values = []
    expected_bin_counts = [0, 0, 0, 0, 0]
    with self.session():
      hist = histogram_ops.histogram_fixed_width(values, value_range, nbins=5)
      self.assertEqual(dtypes.int32, hist.dtype)
      self.assertAllEqual(expected_bin_counts, self.evaluate(hist))

  def test_1d_values_int64_output(self):
    """The `dtype` argument selects an int64 count tensor."""
    # Bins will be:
    #   (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
    value_range = [0.0, 5.0]
    values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
    expected_bin_counts = [2, 1, 1, 0, 2]
    with self.session():
      hist = histogram_ops.histogram_fixed_width(
          values, value_range, nbins=5, dtype=dtypes.int64)
      self.assertEqual(dtypes.int64, hist.dtype)
      self.assertAllEqual(expected_bin_counts, self.evaluate(hist))

  def test_1d_float64_values(self):
    """float64 inputs are counted into int32 bins by default."""
    # Bins will be:
    #   (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
    value_range = np.float64([0.0, 5.0])
    values = np.float64([-1.0, 0.0, 1.5, 2.0, 5.0, 15])
    expected_bin_counts = [2, 1, 1, 0, 2]
    with self.session():
      hist = histogram_ops.histogram_fixed_width(values, value_range, nbins=5)
      self.assertEqual(dtypes.int32, hist.dtype)
      self.assertAllEqual(expected_bin_counts, self.evaluate(hist))

  def test_2d_values(self):
    """2-D inputs give the same counts as their flattened 1-D equivalent."""
    # Bins will be:
    #   (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
    value_range = [0.0, 5.0]
    values = [[-1.0, 0.0, 1.5], [2.0, 5.0, 15]]
    expected_bin_counts = [2, 1, 1, 0, 2]
    with self.session():
      hist = histogram_ops.histogram_fixed_width(values, value_range, nbins=5)
      self.assertEqual(dtypes.int32, hist.dtype)
      self.assertAllEqual(expected_bin_counts, self.evaluate(hist))

  @test_util.run_deprecated_v1
  def test_shape_inference(self):
    """A static `nbins` fixes the output length; a fed one leaves it unknown."""
    value_range = [0.0, 5.0]
    values = [[-1.0, 0.0, 1.5], [2.0, 5.0, 15]]
    expected_bin_counts = [2, 1, 1, 0, 2]
    placeholder = array_ops.placeholder(dtypes.int32)
    with self.session():
      hist = histogram_ops.histogram_fixed_width(values, value_range, nbins=5)
      self.assertAllEqual(hist.shape.as_list(), (5,))
      self.assertEqual(dtypes.int32, hist.dtype)
      self.assertAllEqual(expected_bin_counts, self.evaluate(hist))

      # With nbins supplied through a placeholder, only the rank is known
      # statically; the single dimension stays unknown until it is fed.
      hist = histogram_ops.histogram_fixed_width(
          values, value_range, nbins=placeholder)
      self.assertEqual(hist.shape.ndims, 1)
      self.assertIs(hist.shape.dims[0].value, None)
      self.assertEqual(dtypes.int32, hist.dtype)
      self.assertAllEqual(expected_bin_counts, hist.eval({placeholder: 5}))


if __name__ == '__main__':
  test.main()