• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""print_model_analysis test."""
16
17from tensorflow.python.framework import dtypes
18from tensorflow.python.ops import array_ops
19from tensorflow.python.ops import init_ops
20from tensorflow.python.ops import nn_ops
21from tensorflow.python.ops import variable_scope
22from tensorflow.python.platform import test
23
24
# Option dict for the profiler's model analysis. No code visible in this
# module references it. Keys and values are passed through verbatim; their
# meanings are inferred from the key names only.
# NOTE(review): confirm semantics against the profiler options documentation.
TEST_OPTIONS = dict(
    # Presumably: no practical depth limit on the reported tree.
    max_depth=10000,
    # Zero thresholds: presumably no node is filtered out by size or cost.
    min_bytes=0,
    min_micros=0,
    min_params=0,
    min_float_ops=0,
    order_by='name',
    # Regex filters: match everything, trim/hide nothing.
    account_type_regexes=['.*'],
    start_name_regexes=['.*'],
    trim_name_regexes=[],
    show_name_regexes=['.*'],
    hide_name_regexes=[],
    account_displayed_op_only=True,
    select=['params'],
    output='stdout',
)
46
47
class PrintModelAnalysisTest(test.TestCase):
  """Test case for print_model_analysis.

  Only a model-building helper is defined in this class; it declares no
  test methods, so running this module exercises no assertions from it.
  """

  def _BuildSmallModel(self):
    """Builds a single 2-D convolution over a constant all-zeros input.

    The input is `array_ops.zeros([2, 6, 6, 3])`. The filter is obtained via
    `variable_scope.get_variable` under the name 'DW', with shape
    [6, 6, 3, 6], dtype float32, and a random-normal initializer with
    stddev 0.001. The convolution uses strides [1, 2, 2, 1] and 'SAME'
    padding.

    Returns:
      The tensor produced by `nn_ops.conv2d`.
    """
    # The input tensor is created before the variable, as in the original,
    # so that graph construction order is unchanged.
    inputs = array_ops.zeros([2, 6, 6, 3])
    weights = variable_scope.get_variable(
        name='DW',
        shape=[6, 6, 3, 6],
        dtype=dtypes.float32,
        initializer=init_ops.random_normal_initializer(stddev=0.001))
    return nn_ops.conv2d(inputs, weights, strides=[1, 2, 2, 1], padding='SAME')
58
59
if __name__ == '__main__':
  # When executed as a script, hand control to the test runner imported from
  # tensorflow.python.platform.
  test.main()
62