• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for trackable_utils."""
16
17from tensorflow.python.eager import test
18from tensorflow.python.trackable import trackable_utils
19
20
class TrackableUtilsTest(test.TestCase):
  """Tests for `trackable_utils.order_by_dependency`."""

  def _assert_valid_order(self, sorted_arr, dependencies):
    """Asserts `sorted_arr` is a dependency-respecting ordering of the keys.

    Args:
      sorted_arr: The output of `order_by_dependency`, materialized as a list.
      dependencies: The dependency map passed to `order_by_dependency`, mapping
        each key to the keys it depends on.
    """
    # Every key must appear exactly once. A plain set comparison would let
    # duplicated entries slip through unnoticed.
    self.assertCountEqual(sorted_arr, dependencies.keys())
    position = {x: i for i, x in enumerate(sorted_arr)}
    for key, deps in dependencies.items():
      for dep in deps:
        # A dependency must be ordered before every key that depends on it.
        self.assertLess(position[dep], position[key],
                        f"{dep} must appear before {key}")

  def test_order_by_dependency(self):
    """Tests order_by_dependency correctness."""

    # Visual graph (vertical lines point down, so 1 depends on 2):
    #    1
    #  /   \
    # 2 --> 3 <-- 4
    #       |
    #       5
    # One possible order: [5, 3, 4, 2, 1]
    dependencies = {1: [2, 3], 2: [3], 3: [5], 4: [3], 5: []}

    sorted_arr = list(trackable_utils.order_by_dependency(dependencies))
    self._assert_valid_order(sorted_arr, dependencies)

    indices = {x: i for i, x in enumerate(sorted_arr)}
    # Every other key (transitively) depends on 5, and every key except 5
    # depends on 3, so these two positions are forced.
    self.assertEqual(indices[5], 0)
    self.assertEqual(indices[3], 1)
    self.assertGreater(indices[1], indices[2])  # 2 must appear before 1

  def test_order_by_no_dependency(self):
    """Tests that independent keys are each returned exactly once."""
    dependencies = {x: [] for x in range(15)}
    sorted_arr = list(trackable_utils.order_by_dependency(dependencies))
    self._assert_valid_order(sorted_arr, dependencies)

  def test_order_by_dependency_invalid_map(self):
    """Tests that a dependency which is not itself a key is rejected."""
    with self.assertRaisesRegex(
        ValueError, "Found values in the dependency map which are not keys"):
      trackable_utils.order_by_dependency({1: [2]})
50
51
if __name__ == "__main__":
  # Entry point when executed directly: hand control to the test runner
  # provided by `tensorflow.python.eager.test`.
  test.main()
54
55