• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for print_model_analysis."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import init_ops
24from tensorflow.python.ops import nn_ops
25from tensorflow.python.ops import variable_scope
26from tensorflow.python.platform import test
27
28
# pylint: disable=bad-whitespace
# pylint: disable=bad-continuation
# Option set intended for the print_model_analysis tests. It is not referenced
# elsewhere in this file. Every numeric threshold is 0 and every "include"
# regex list is ['.*'] while the trim/hide lists are empty, so presumably no
# node is filtered out of the report — confirm against the tfprof options docs.
TEST_OPTIONS = {
    'max_depth': 10000,
    'min_bytes': 0,
    'min_micros': 0,
    'min_params': 0,
    'min_float_ops': 0,
    'order_by': 'name',
    'account_type_regexes': ['.*'],
    'start_name_regexes': ['.*'],
    'trim_name_regexes': [],
    'show_name_regexes': ['.*'],
    'hide_name_regexes': [],
    'account_displayed_op_only': True,
    # Only the 'params' column is requested, written to stdout.
    'select': ['params'],
    'output': 'stdout',
}

# pylint: enable=bad-whitespace
# pylint: enable=bad-continuation
50
51
class PrintModelAnalysisTest(test.TestCase):
  """Test case for print_model_analysis.

  NOTE(review): this class defines no `test*` methods, so the unittest loader
  behind `test.main()` collects no test cases from it, and `_BuildSmallModel`
  is never called. A test exercising the analysis output appears to be
  missing — confirm whether it was removed intentionally.
  """

  def _BuildSmallModel(self):
    """Builds a single strided convolution over an all-zeros image batch.

    The input is `array_ops.zeros([2, 6, 6, 3])`. The kernel is the float32
    variable named 'DW' with shape [6, 6, 3, 6], obtained through
    `variable_scope.get_variable` with a normal initializer (stddev 0.001).
    The convolution uses strides [1, 2, 2, 1] and 'SAME' padding.

    Returns:
      The tensor returned by `nn_ops.conv2d`.
    """
    # The input values are all zeros; presumably only the graph structure
    # (ops and variable sizes) matters to the analysis — confirm.
    image = array_ops.zeros([2, 6, 6, 3])
    kernel = variable_scope.get_variable(
        'DW', [6, 6, 3, 6],
        dtypes.float32,
        initializer=init_ops.random_normal_initializer(stddev=0.001))
    return nn_ops.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
62
63
if __name__ == '__main__':
  # Hand control to TensorFlow's test runner when executed as a script.
  test.main()
66