• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests that the TensorFlow parts of the known anomaly example run."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.timeseries.examples import known_anomaly
22
23from tensorflow.python.platform import test
24
25
class KnownAnomalyExampleTest(test.TestCase):
  """Smoke tests for `known_anomaly.train_and_evaluate_exogenous`.

  Each test trains one of the example's estimators and inspects the seven
  values the example returns: `times`, `observed`, `all_times`, `mean`,
  `upper_limit`, `lower_limit` and `anomaly_locations`.
  """

  # Anomaly positions both estimator variants are expected to report.
  _EXPECTED_ANOMALY_LOCATIONS = [25, 50, 75, 100, 125, 150, 175, 249]

  def test_shapes_and_variance_structural_ar(self):
    """Trains the autoregressive estimator for one step and checks shapes."""
    results = known_anomaly.train_and_evaluate_exogenous(
        estimator_fn=known_anomaly.autoregressive_estimator, train_steps=1)
    (times, observed, all_times, mean, upper_limit, lower_limit,
     anomaly_locations) = results

    self.assertAllEqual(anomaly_locations, self._EXPECTED_ANOMALY_LOCATIONS)
    # Every predicted series must line up with the full time axis.
    for predicted in (mean, upper_limit, lower_limit):
      self.assertAllEqual(all_times.shape, predicted.shape)
    self.assertAllEqual(times.shape, observed.shape)

  def test_shapes_and_variance_structural_ssm(self):
    """Trains the state space estimator and checks shapes and interval widths."""
    results = known_anomaly.train_and_evaluate_exogenous(
        estimator_fn=known_anomaly.state_space_estimator, train_steps=50)
    (times, observed, all_times, mean, upper_limit, lower_limit,
     anomaly_locations) = results

    self.assertAllEqual(anomaly_locations, self._EXPECTED_ANOMALY_LOCATIONS)
    for observed_series in (times, observed):
      self.assertAllEqual([200], observed_series.shape)
    for full_series in (all_times, mean, upper_limit, lower_limit):
      self.assertAllEqual([300], full_series.shape)

    def interval_width(index):
      """Returns the gap between the upper and lower limits at `index`."""
      return upper_limit[index] - lower_limit[index]

    # Initial predictions should be relatively confident: the interval at
    # index 210 must stay under three times its width at index 200.
    self.assertLess(interval_width(210), 3.0 * interval_width(200))
    # Post-changepoint predictions should be less confident: the interval at
    # index 290 must exceed three times its width at index 240.
    self.assertGreater(interval_width(290), 3.0 * interval_width(240))
59
# When executed directly as a script, hand control to the TensorFlow test
# runner imported above as `test`.
if __name__ == "__main__":
  test.main()
62