# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests that the TensorFlow parts of the known anomaly example run."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.timeseries.examples import known_anomaly

from tensorflow.python.platform import test

# Anomaly locations that `known_anomaly.train_and_evaluate_exogenous` is
# expected to return. Shared so that both estimator variants are checked
# against the same list.
_EXPECTED_ANOMALY_LOCATIONS = [25, 50, 75, 100, 125, 150, 175, 249]


class KnownAnomalyExampleTest(test.TestCase):
  """Runs `known_anomaly.train_and_evaluate_exogenous` with each estimator."""

  def test_shapes_and_variance_structural_ar(self):
    """Trains the autoregressive estimator for one step and checks shapes."""
    (times, observed, all_times, mean, upper_limit, lower_limit,
     anomaly_locations) = known_anomaly.train_and_evaluate_exogenous(
         train_steps=1, estimator_fn=known_anomaly.autoregressive_estimator)
    self.assertAllEqual(anomaly_locations, _EXPECTED_ANOMALY_LOCATIONS)
    # The mean and both limits are reported for every entry of `all_times`.
    self.assertAllEqual(all_times.shape, mean.shape)
    self.assertAllEqual(all_times.shape, upper_limit.shape)
    self.assertAllEqual(all_times.shape, lower_limit.shape)
    # Each observed value is paired with one entry of `times`.
    self.assertAllEqual(times.shape, observed.shape)

  def test_shapes_and_variance_structural_ssm(self):
    """Trains the state space estimator and checks shapes and limit widths."""
    (times, observed, all_times, mean, upper_limit, lower_limit,
     anomaly_locations) = known_anomaly.train_and_evaluate_exogenous(
         train_steps=50, estimator_fn=known_anomaly.state_space_estimator)
    self.assertAllEqual(anomaly_locations, _EXPECTED_ANOMALY_LOCATIONS)
    # 200 observed points; 300 points in total once predictions are included.
    self.assertAllEqual([200], times.shape)
    self.assertAllEqual([200], observed.shape)
    self.assertAllEqual([300], all_times.shape)
    self.assertAllEqual([300], mean.shape)
    self.assertAllEqual([300], upper_limit.shape)
    self.assertAllEqual([300], lower_limit.shape)
    # Check that initial predictions are relatively confident: the
    # upper/lower gap at index 210 stays within 3x the gap at index 200.
    self.assertLess(upper_limit[210] - lower_limit[210],
                    3.0 * (upper_limit[200] - lower_limit[200]))
    # Check that post-changepoint predictions are less confident: the gap at
    # index 290 exceeds 3x the gap at index 240.
    self.assertGreater(upper_limit[290] - lower_limit[290],
                       3.0 * (upper_limit[240] - lower_limit[240]))


if __name__ == "__main__":
  test.main()