• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for learn.utils.gc."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import re
23
24from six.moves import xrange  # pylint: disable=redefined-builtin
25
26from tensorflow.contrib.learn.python.learn.utils import gc
27from tensorflow.python.framework import test_util
28from tensorflow.python.platform import gfile
29from tensorflow.python.platform import test
30from tensorflow.python.util import compat
31
32
def _create_parser(base_dir):
  """Builds a `gc.get_paths` parser that extracts the export version.

  The returned parser accepts a `gc.Path` whose `path` is exactly
  `<base_dir>/<digits>` and returns a copy with `export_version` set to the
  integer value of those digits. Any other path yields `None`.

  Args:
    base_dir: Directory whose immediate numeric children are exports. May be
      `str` or `bytes` (converted with `compat.as_str_any`).

  Returns:
    A callable `parser(path)` returning a `gc.Path` or `None`.
  """
  # On Windows, paths may use "\\" separators; normalize both the base
  # directory and each candidate path to "/" so one pattern covers both.
  normalize_separators = os.name == "nt"

  def _normalize(raw_path):
    text = compat.as_str_any(raw_path)
    return text.replace("\\", "/") if normalize_separators else text

  # Escape base_dir: temp directories can contain regex metacharacters
  # (".", "+", "(", "[", ...) that would otherwise alter the match.
  pattern = re.compile(r"^" + re.escape(_normalize(base_dir)) + r"/(\d+)$")

  def parser(path):
    match = pattern.match(_normalize(path.path))
    if not match:
      return None
    return path._replace(export_version=int(match.group(1)))

  return parser
49
50
class GcTest(test_util.TensorFlowTestCase):
  """Tests for the export-path filters and directory parsing in `gc`."""

  @staticmethod
  def _foo_paths(*versions):
    """Returns `[gc.Path("/foo", v) for v in versions]`, preserving order."""
    return [gc.Path("/foo", v) for v in versions]

  def testLargestExportVersions(self):
    paths = self._foo_paths(8, 9, 10)
    newest = gc.largest_export_versions(2)
    self.assertEqual(newest(paths), self._foo_paths(9, 10))

  def testLargestExportVersionsDoesNotDeleteZeroFolder(self):
    paths = self._foo_paths(0, 3)
    newest = gc.largest_export_versions(2)
    self.assertEqual(newest(paths), self._foo_paths(0, 3))

  def testModExportVersion(self):
    paths = self._foo_paths(4, 5, 6, 9)
    mod = gc.mod_export_version(2)
    self.assertEqual(mod(paths), self._foo_paths(4, 6))
    mod = gc.mod_export_version(3)
    self.assertEqual(mod(paths), self._foo_paths(6, 9))

  def testOneOfEveryNExportVersions(self):
    paths = self._foo_paths(0, 1, 3, 5, 6, 7, 8, 33)
    one_of = gc.one_of_every_n_export_versions(3)
    self.assertEqual(one_of(paths), self._foo_paths(3, 6, 8, 33))

  def testOneOfEveryNExportVersionsZero(self):
    # Zero is a special case since it gets rolled into the first interval.
    # Test that here.
    paths = self._foo_paths(0, 4, 5)
    one_of = gc.one_of_every_n_export_versions(3)
    self.assertEqual(one_of(paths), self._foo_paths(0, 5))

  def testUnion(self):
    paths = self._foo_paths(*xrange(10))
    f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3))
    self.assertEqual(f(paths), self._foo_paths(0, 3, 6, 7, 8, 9))

  def testNegation(self):
    paths = self._foo_paths(4, 5, 6, 9)
    mod = gc.negation(gc.mod_export_version(2))
    self.assertEqual(mod(paths), self._foo_paths(5, 9))
    mod = gc.negation(gc.mod_export_version(3))
    self.assertEqual(mod(paths), self._foo_paths(4, 5))

  def testPathsWithParse(self):
    base_dir = os.path.join(test.get_temp_dir(), "paths_parse")
    self.assertFalse(gfile.Exists(base_dir))
    for version in xrange(3):
      gfile.MakeDirs(os.path.join(base_dir, "%d" % version))
    # A non-numeric sibling directory that the parser must ignore.
    gfile.MakeDirs(os.path.join(base_dir, "ignore"))

    expected = [
        gc.Path(os.path.join(base_dir, "%d" % version), version)
        for version in xrange(3)
    ]
    self.assertEqual(
        gc.get_paths(base_dir, _create_parser(base_dir)), expected)

  def testMixedStrTypes(self):
    temp_dir = compat.as_bytes(test.get_temp_dir())

    for sub_dir in ["str", b"bytes", u"unicode"]:
      # os.path.join needs both components to share a type: pair bytes
      # sub-directories with the bytes temp dir, text ones with its decoding.
      base_dir = os.path.join(
          temp_dir if isinstance(sub_dir, bytes) else temp_dir.decode(),
          sub_dir)
      self.assertFalse(gfile.Exists(base_dir))
      gfile.MakeDirs(os.path.join(compat.as_str_any(base_dir), "42"))
      paths = gc.get_paths(base_dir, _create_parser(base_dir))
      # The "42" export must actually be discovered and parsed whether
      # base_dir is bytes or text, not merely avoid raising.
      self.assertEqual([42], [p.export_version for p in paths])
157
if __name__ == "__main__":
  # Script entry point: delegate to the TensorFlow test runner
  # (tensorflow.python.platform.test.main).
  test.main()
160