# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the disk-usage byte limit of tfdbg's dumping wrapper and hook."""
import os
import shutil
import tempfile

from tensorflow.python.client import session
from tensorflow.python.debug.wrappers import dumping_wrapper
from tensorflow.python.debug.wrappers import hooks
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import monitored_session

# Environment variable holding the tfdbg disk-usage byte limit.
_DISK_BYTES_LIMIT_ENV_VAR = "TFDBG_DISK_BYTES_LIMIT"


@test_util.run_v1_only("Sessions are not available in TF 2.x")
class DumpingDebugWrapperDiskUsageLimitTest(test_util.TensorFlowTestCase):
  """Checks that dumping stops with a ValueError once the byte limit is hit.

  Both the `DumpingDebugWrapperSession` and the `DumpingDebugHook` paths are
  exercised, each with a run pattern that stays under the limit and one that
  exceeds it.
  """

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    # For efficient testing, set the disk usage bytes limit to a small
    # number (10). Remember any pre-existing value so that tearDownClass can
    # restore it and the override does not leak into other tests running in
    # the same process.
    cls._original_disk_bytes_limit = os.environ.get(_DISK_BYTES_LIMIT_ENV_VAR)
    os.environ[_DISK_BYTES_LIMIT_ENV_VAR] = "10"

  @classmethod
  def tearDownClass(cls):
    if cls._original_disk_bytes_limit is None:
      os.environ.pop(_DISK_BYTES_LIMIT_ENV_VAR, None)
    else:
      os.environ[_DISK_BYTES_LIMIT_ENV_VAR] = cls._original_disk_bytes_limit
    super().tearDownClass()

  def setUp(self):
    # The base class performs per-test initialization (e.g. graph reset);
    # skipping it lets state accumulate across test methods.
    super().setUp()
    self.session_root = tempfile.mkdtemp()
    # Registered before the session's close() so that, because cleanups run
    # in LIFO order, the directory is removed only after the session closes.
    self.addCleanup(shutil.rmtree, self.session_root, ignore_errors=True)

    self.v = variables.Variable(10.0, dtype=dtypes.float32, name="v")
    self.delta = constant_op.constant(1.0, dtype=dtypes.float32, name="delta")
    self.eta = constant_op.constant(-1.4, dtype=dtypes.float32, name="eta")
    self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")
    self.dec_v = state_ops.assign_add(self.v, self.eta, name="dec_v")

    self.sess = session.Session()
    self.addCleanup(self.sess.close)
    self.sess.run(self.v.initializer)

  def testWrapperSessionNotExceedingLimit(self):
    def _watch_fn(fetches, feeds):
      del fetches, feeds
      return "DebugIdentity", r"(.*delta.*|.*inc_v.*)", r".*"
    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root,
        watch_fn=_watch_fn, log_usage=False)
    # A single run dumps two float32 scalars (delta and inc_v), which is
    # expected to stay within the 10-byte limit.
    sess.run(self.inc_v)

  def testWrapperSessionExceedingLimit(self):
    def _watch_fn(fetches, feeds):
      del fetches, feeds
      return "DebugIdentity", r".*delta.*", r".*"
    sess = dumping_wrapper.DumpingDebugWrapperSession(
        self.sess, session_root=self.session_root,
        watch_fn=_watch_fn, log_usage=False)
    # Due to the watch function, each run should dump only 1 tensor,
    # which has a size of 4 bytes, which corresponds to the dumped 'delta:0'
    # tensor of scalar shape and float32 dtype.
    # 1st run should pass, after which the disk usage is at 4 bytes.
    sess.run(self.inc_v)
    # 2nd run should also pass, after which 8 bytes are used.
    sess.run(self.inc_v)
    # 3rd run should fail, because the total byte count (12) exceeds the
    # limit (10).
    with self.assertRaises(ValueError):
      sess.run(self.inc_v)

  def testHookNotExceedingLimit(self):
    def _watch_fn(fetches, feeds):
      del fetches, feeds
      return "DebugIdentity", r".*delta.*", r".*"
    dumping_hook = hooks.DumpingDebugHook(
        self.session_root, watch_fn=_watch_fn, log_usage=False)
    mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
    # One 4-byte dump of 'delta:0' stays within the 10-byte limit.
    mon_sess.run(self.inc_v)

  def testHookExceedingLimit(self):
    def _watch_fn(fetches, feeds):
      del fetches, feeds
      return "DebugIdentity", r".*delta.*", r".*"
    dumping_hook = hooks.DumpingDebugHook(
        self.session_root, watch_fn=_watch_fn, log_usage=False)
    mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
    # Like in `testWrapperSessionExceedingLimit`, the first two calls
    # should be within the byte limit, but the third one should error
    # out due to exceeding the limit.
    mon_sess.run(self.inc_v)
    mon_sess.run(self.inc_v)
    with self.assertRaises(ValueError):
      mon_sess.run(self.inc_v)


if __name__ == "__main__":
  googletest.main()