• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15r"""Computes a header file to be used with SELECTIVE_REGISTRATION.
16
17See the executable wrapper, print_selective_registration_header.py, for more
18information.
19"""
20
21import json
22import os
23import sys
24
25from google.protobuf import text_format
26from tensorflow.core.framework import graph_pb2
27from tensorflow.python.platform import gfile
28from tensorflow.python.platform import tf_logging
29from tensorflow.python.util import _pywrap_kernel_registry
30
31# Usually, we use each graph node to induce registration of an op and
32# corresponding kernel; nodes without a corresponding kernel (perhaps due to
33# attr types) generate a warning but are otherwise ignored. Ops in this set are
34# registered even if there's no corresponding kernel.
OPS_WITHOUT_KERNEL_ALLOWLIST = frozenset([
    # AccumulateNV2 is rewritten away by AccumulateNV2RemovePass; see
    # core/common_runtime/accumulate_n_optimizer.cc.
    'AccumulateNV2'
])
# Byte prefix marking Flex (TF-in-TFLite) ops; the precomputed length is
# presumably used elsewhere to strip the prefix — not visible in this chunk.
FLEX_PREFIX = b'Flex'
FLEX_PREFIX_LENGTH = len(FLEX_PREFIX)
42
43
def _get_ops_from_ops_list(input_file):
  """Gets the ops and kernels needed from the ops list file.

  Args:
    input_file: path to a JSON file containing a list of [op, kernel] pairs,
      e.g. '[["Transpose", "TransposeCpuOp"]]'.

  Returns:
    A set of (op_name, kernel_class_name or None) tuples.

  Raises:
    ValueError: if the input file is empty.
  """
  ops = set()
  # Use a context manager so the file handle is released promptly (the
  # original left the GFile unclosed).
  with gfile.GFile(input_file, 'r') as f:
    ops_list_str = f.read()
  if not ops_list_str:
    raise ValueError('Input file should not be empty')
  for op, kernel in json.loads(ops_list_str):
    # Normalize falsy kernel entries (e.g. '' or null) to None.
    ops.add((op, kernel or None))
  return ops
55
56
def _get_ops_from_graphdef(graph_def):
  """Gets the ops and kernels needed from the tensorflow model."""
  # Collect ops from the main graph, then from every function in the
  # graph's function library.
  needed = set(_get_ops_from_nodedefs(graph_def.node))
  for func in graph_def.library.function:
    needed |= _get_ops_from_nodedefs(func.node_def)
  return needed
65
66
def _get_ops_from_nodedefs(node_defs):
  """Gets the ops and kernels needed from the list of NodeDef.

  Args:
    node_defs: an iterable of NodeDef protos. Nodes without a device are
      mutated in place to default to '/cpu:0' before the kernel lookup.

  Returns:
    A set of (op_name, kernel_class_name or None) tuples. Ops with no
    matching kernel are skipped (with a warning) unless they appear in
    OPS_WITHOUT_KERNEL_ALLOWLIST.
  """
  ops = set()
  for node_def in node_defs:
    if not node_def.device:
      # Kernel lookup needs a device; assume CPU when unspecified.
      node_def.device = '/cpu:0'
    kernel_class = _pywrap_kernel_registry.TryFindKernelClass(
        node_def.SerializeToString())
    op = str(node_def.op)
    if kernel_class or op in OPS_WITHOUT_KERNEL_ALLOWLIST:
      # kernel_class comes back as bytes; decode() already yields str, so no
      # extra str() wrapper is needed.
      ops.add((op, kernel_class.decode('utf-8') if kernel_class else None))
    else:
      tf_logging.warning('Warning: no kernel found for op %s', op)
  return ops
83
84
def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
  """Gets the ops and kernels needed from the model files.

  Args:
    proto_fileformat: 'ops_list', 'rawproto' or 'textproto' (see get_header).
    proto_files: list of paths of model files to load.
    default_ops_str: comma-separated 'op:kernel' pairs to always include;
      'all' and empty/None add nothing here ('all' is handled by callers).

  Returns:
    A sorted list of (op_name, kernel_class_name or None) tuples.
  """
  ops = set()

  for proto_file in proto_files:
    tf_logging.info('Loading proto file %s', proto_file)
    # Load ops list file.
    if proto_fileformat == 'ops_list':
      ops |= _get_ops_from_ops_list(proto_file)
      continue

    # Load GraphDef; the context manager closes the file promptly (the
    # original left the GFile unclosed).
    with gfile.GFile(proto_file, 'rb') as f:
      file_data = f.read()
    if proto_fileformat == 'rawproto':
      graph_def = graph_pb2.GraphDef.FromString(file_data)
    else:
      assert proto_fileformat == 'textproto'
      graph_def = text_format.Parse(file_data, graph_pb2.GraphDef())
    ops |= _get_ops_from_graphdef(graph_def)

  # Add default ops; set.add is already a no-op for duplicates, so no
  # membership pre-check is needed.
  if default_ops_str and default_ops_str != 'all':
    for pair in default_ops_str.split(','):
      op, kernel = pair.split(':')
      ops.add((op, kernel))

  # sorted() already returns a list; wrapping it in list() was redundant.
  return sorted(ops)
114
115
def get_header_from_ops_and_kernels(ops_and_kernels,
                                    include_all_ops_and_kernels):
  """Returns a header for use with tensorflow SELECTIVE_REGISTRATION.

  Args:
    ops_and_kernels: a set of (op_name, kernel_class_name) pairs to include.
    include_all_ops_and_kernels: if True, ops_and_kernels is ignored and all op
      kernels are included.

  Returns:
    the string of the header that should be written as ops_to_register.h.
  """
  ops = {op for op, _ in ops_and_kernels}
  script_name = os.path.split(sys.argv[0])[1]
  lines = [
      '// This file was autogenerated by %s' % script_name,
      '#ifndef OPS_TO_REGISTER',
      '#define OPS_TO_REGISTER',
  ]

  if include_all_ops_and_kernels:
    # Everything is registered; the macros are unconditionally true.
    lines.extend([
        '#define SHOULD_REGISTER_OP(op) true',
        '#define SHOULD_REGISTER_OP_KERNEL(clz) true',
        '#define SHOULD_REGISTER_OP_GRADIENT true',
    ])
  else:
    # Constexpr helpers for compile-time string comparison inside the
    # generated header.
    helpers = """
    namespace {
      constexpr const char* skip(const char* x) {
        return (*x) ? (*x == ' ' ? skip(x + 1) : x) : x;
      }

      constexpr bool isequal(const char* x, const char* y) {
        return (*skip(x) && *skip(y))
                   ? (*skip(x) == *skip(y) && isequal(skip(x) + 1, skip(y) + 1))
                   : (!*skip(x) && !*skip(y));
      }

      template<int N>
      struct find_in {
        static constexpr bool f(const char* x, const char* const y[N]) {
          return isequal(x, y[0]) || find_in<N - 1>::f(x, y + 1);
        }
      };

      template<>
      struct find_in<0> {
        static constexpr bool f(const char* x, const char* const y[]) {
          return false;
        }
      };
    }  // end namespace
    """
    # One '"Kernel",\n' entry per pair that actually names a kernel class.
    kernel_entries = ''.join('"%s",\n' % kernel
                             for _, kernel in ops_and_kernels
                             if kernel is not None)
    lines.append(helpers +
                 'constexpr const char* kNecessaryOpKernelClasses[] = {\n' +
                 kernel_entries + '};')
    lines.append('#define SHOULD_REGISTER_OP_KERNEL(clz) '
                 '(find_in<sizeof(kNecessaryOpKernelClasses) '
                 '/ sizeof(*kNecessaryOpKernelClasses)>::f(clz, '
                 'kNecessaryOpKernelClasses))')
    lines.append('')

    # Op registration: a constexpr OR-chain over the needed op names.
    lines.append('constexpr inline bool ShouldRegisterOp(const char op[]) {')
    lines.append('  return false')
    lines.extend('     || isequal(op, "%s")' % op for op in sorted(ops))
    lines.append('  ;')
    lines.append('}')
    lines.append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)')
    lines.append('')

    # Gradients are only needed when SymbolicGradient itself is registered.
    lines.append('#define SHOULD_REGISTER_OP_GRADIENT ' +
                 ('true' if 'SymbolicGradient' in ops else 'false'))

  lines.append('#endif')
  return '\n'.join(lines)
198
199
def get_header(graphs,
               proto_fileformat='rawproto',
               default_ops='NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'):
  """Computes a header for use with tensorflow SELECTIVE_REGISTRATION.

  Args:
    graphs: a list of paths to GraphDef files to include.
    proto_fileformat: optional format of proto file, either 'textproto',
      'rawproto' (default) or ops_list. The ops_list is the file contain the
      list of ops in JSON format, Ex: "[["Transpose", "TransposeCpuOp"]]".
    default_ops: optional comma-separated string of operator:kernel pairs to
      always include implementation for. Pass 'all' to have all operators and
      kernels included. Default: 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'.

  Returns:
    the string of the header that should be written as ops_to_register.h.
  """
  include_everything = default_ops == 'all'
  ops_and_kernels = get_ops_and_kernels(proto_fileformat, graphs, default_ops)
  if ops_and_kernels:
    return get_header_from_ops_and_kernels(ops_and_kernels, include_everything)
  # NOTE: on failure this returns the int 1 (not a string); callers appear to
  # rely on this error convention, so it is kept as-is.
  print('Error reading graph!')
  return 1
223