# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for trackable_utils."""

from tensorflow.python.eager import test
from tensorflow.python.trackable import trackable_utils


class TrackableUtilsTest(test.TestCase):

  def test_order_by_dependency(self):
    """Tests order_by_dependency correctness."""

    # Visual graph (vertical lines point down, so 1 depends on 2):
    #         1
    #        / \
    #       2 --> 3 <-- 4
    #             |
    #             5
    # One possible order: [5, 3, 4, 2, 1]
    dependencies = {1: [2, 3], 2: [3], 3: [5], 4: [3], 5: []}

    sorted_arr = list(trackable_utils.order_by_dependency(dependencies))
    # Every key must appear exactly once; this catches both dropped and
    # duplicated nodes (a plain set comparison or `list.index` would not).
    self.assertCountEqual(sorted_arr, list(dependencies))
    indices = {node: position for position, node in enumerate(sorted_arr)}
    # 5 is the only node with no dependencies, and 3 is the only node whose
    # sole dependency is 5, so their positions are fixed.
    self.assertEqual(indices[5], 0)
    self.assertEqual(indices[3], 1)
    self.assertGreater(indices[1], indices[2])  # 2 must appear before 1
    # Every dependency must be ordered before the node that depends on it.
    for node, deps in dependencies.items():
      for dep in deps:
        self.assertLess(indices[dep], indices[node])

  def test_order_by_no_dependency(self):
    """Tests that independent nodes are each emitted exactly once."""
    sorted_arr = list(trackable_utils.order_by_dependency(
        {x: [] for x in range(15)}))
    # assertCountEqual (unlike a set comparison) also rejects duplicates.
    self.assertCountEqual(sorted_arr, range(15))

  def test_order_by_dependency_invalid_map(self):
    """Tests that a dependency that is not itself a key raises ValueError."""
    with self.assertRaisesRegex(
        ValueError, "Found values in the dependency map which are not keys"):
      trackable_utils.order_by_dependency({1: [2]})


if __name__ == "__main__":
  test.main()