# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Computes a header file to be used with SELECTIVE_REGISTRATION.

See the executable wrapper, print_selective_registration_header.py, for more
information.
"""

import json
import os
import sys

from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import _pywrap_kernel_registry

# Usually, we use each graph node to induce registration of an op and
# corresponding kernel; nodes without a corresponding kernel (perhaps due to
# attr types) generate a warning but are otherwise ignored. Ops in this set are
# registered even if there's no corresponding kernel.
OPS_WITHOUT_KERNEL_ALLOWLIST = frozenset([
    # AccumulateNV2 is rewritten away by AccumulateNV2RemovePass; see
    # core/common_runtime/accumulate_n_optimizer.cc.
    'AccumulateNV2'
])
FLEX_PREFIX = b'Flex'
FLEX_PREFIX_LENGTH = len(FLEX_PREFIX)


def _get_ops_from_ops_list(input_file):
  """Gets the ops and kernels needed from the ops list file.

  Args:
    input_file: path to a JSON file holding a list of [op, kernel] pairs,
      e.g. '[["Transpose", "TransposeCpuOp"]]'.

  Returns:
    A set of (op_name, kernel_class_name or None) tuples; an empty/falsy
    kernel name is normalized to None.

  Raises:
    ValueError: if the input file is empty.
  """
  # Use a context manager so the file handle is closed deterministically
  # instead of relying on garbage collection.
  with gfile.GFile(input_file, 'r') as f:
    ops_list_str = f.read()
  if not ops_list_str:
    # ValueError (a subclass of Exception) is the conventional type for a
    # malformed input; existing `except Exception` callers still catch it.
    raise ValueError('Input file should not be empty')
  ops_list = json.loads(ops_list_str)
  return {(op, kernel if kernel else None) for op, kernel in ops_list}


def _get_ops_from_graphdef(graph_def):
  """Gets the ops and kernels needed from the tensorflow model.

  Collects ops both from the main graph's nodes and from every function in
  the graph's function library.

  Args:
    graph_def: a GraphDef proto.

  Returns:
    A set of (op_name, kernel_class_name or None) tuples.
  """
  ops = set()
  ops.update(_get_ops_from_nodedefs(graph_def.node))

  for function in graph_def.library.function:
    ops.update(_get_ops_from_nodedefs(function.node_def))
  return ops


def _get_ops_from_nodedefs(node_defs):
  """Gets the ops and kernels needed from the list of NodeDef.

  NOTE: mutates the input — nodes without a device assignment are pinned to
  '/cpu:0' in place so kernel lookup has a concrete device to resolve against.

  Args:
    node_defs: an iterable of NodeDef protos.

  Returns:
    A set of (op_name, kernel_class_name or None) tuples. Ops with no kernel
    are included only if they appear in OPS_WITHOUT_KERNEL_ALLOWLIST; other
    kernel-less ops are logged and skipped.
  """
  ops = set()
  for node_def in node_defs:
    if not node_def.device:
      node_def.device = '/cpu:0'
    kernel_class = _pywrap_kernel_registry.TryFindKernelClass(
        node_def.SerializeToString())
    op = str(node_def.op)
    if kernel_class or op in OPS_WITHOUT_KERNEL_ALLOWLIST:
      # kernel_class is bytes; decode() already yields str, no str() needed.
      op_and_kernel = (op,
                       kernel_class.decode('utf-8') if kernel_class else None)
      ops.add(op_and_kernel)
    else:
      tf_logging.warning('Warning: no kernel found for op %s', op)
  return ops


def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
  """Gets the ops and kernels needed from the model files.

  Args:
    proto_fileformat: one of 'ops_list', 'rawproto', or 'textproto'.
    proto_files: list of paths to model/ops-list files to read.
    default_ops_str: comma-separated 'op:kernel' pairs always included, or
      'all' (handled by the caller), or empty/None for no defaults.

  Returns:
    A sorted list of (op_name, kernel_class_name or None) tuples.
  """
  ops = set()

  for proto_file in proto_files:
    tf_logging.info('Loading proto file %s', proto_file)
    # Load ops list file.
    if proto_fileformat == 'ops_list':
      ops = ops.union(_get_ops_from_ops_list(proto_file))
      continue

    # Load GraphDef; close the file handle deterministically.
    with gfile.GFile(proto_file, 'rb') as f:
      file_data = f.read()
    if proto_fileformat == 'rawproto':
      graph_def = graph_pb2.GraphDef.FromString(file_data)
    else:
      assert proto_fileformat == 'textproto'
      graph_def = text_format.Parse(file_data, graph_pb2.GraphDef())
    ops = ops.union(_get_ops_from_graphdef(graph_def))

  # Add default ops; set.add is idempotent, so no membership test is needed.
  if default_ops_str and default_ops_str != 'all':
    for s in default_ops_str.split(','):
      op, kernel = s.split(':')
      ops.add((op, kernel))

  # sorted() already returns a list.
  return sorted(ops)


def get_header_from_ops_and_kernels(ops_and_kernels,
                                    include_all_ops_and_kernels):
  """Returns a header for use with tensorflow SELECTIVE_REGISTRATION.

  Args:
    ops_and_kernels: a set of (op_name, kernel_class_name) pairs to include.
    include_all_ops_and_kernels: if True, ops_and_kernels is ignored and all op
      kernels are included.

  Returns:
    the string of the header that should be written as ops_to_register.h.
  """
  ops = {op for op, _ in ops_and_kernels}
  result_list = []

  def append(s):
    result_list.append(s)

  _, script_name = os.path.split(sys.argv[0])
  append('// This file was autogenerated by %s' % script_name)
  append('#ifndef OPS_TO_REGISTER')
  append('#define OPS_TO_REGISTER')

  if include_all_ops_and_kernels:
    # Short-circuit: register everything unconditionally.
    append('#define SHOULD_REGISTER_OP(op) true')
    append('#define SHOULD_REGISTER_OP_KERNEL(clz) true')
    append('#define SHOULD_REGISTER_OP_GRADIENT true')
  else:
    # Emit constexpr C++ helpers that do whitespace-insensitive string
    # comparison and linear search at compile time.
    line = """
    namespace {
    constexpr const char* skip(const char* x) {
      return (*x) ? (*x == ' ' ? skip(x + 1) : x) : x;
    }

    constexpr bool isequal(const char* x, const char* y) {
      return (*skip(x) && *skip(y))
                 ? (*skip(x) == *skip(y) && isequal(skip(x) + 1, skip(y) + 1))
                 : (!*skip(x) && !*skip(y));
    }

    template<int N>
    struct find_in {
      static constexpr bool f(const char* x, const char* const y[N]) {
        return isequal(x, y[0]) || find_in<N - 1>::f(x, y + 1);
      }
    };

    template<>
    struct find_in<0> {
      static constexpr bool f(const char* x, const char* const y[]) {
        return false;
      }
    };
    }  // end namespace
    """
    line += 'constexpr const char* kNecessaryOpKernelClasses[] = {\n'
    for _, kernel_class in ops_and_kernels:
      if kernel_class is None:
        continue
      line += '"%s",\n' % kernel_class
    line += '};'
    append(line)
    append('#define SHOULD_REGISTER_OP_KERNEL(clz) '
           '(find_in<sizeof(kNecessaryOpKernelClasses) '
           '/ sizeof(*kNecessaryOpKernelClasses)>::f(clz, '
           'kNecessaryOpKernelClasses))')
    append('')

    append('constexpr inline bool ShouldRegisterOp(const char op[]) {')
    append('  return false')
    for op in sorted(ops):
      append('     || isequal(op, "%s")' % op)
    append('  ;')
    append('}')
    append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)')
    append('')

  # Gradients are only needed when the SymbolicGradient op itself is present.
  append('#define SHOULD_REGISTER_OP_GRADIENT ' +
         ('true' if 'SymbolicGradient' in ops else 'false'))

  append('#endif')
  return '\n'.join(result_list)


def get_header(graphs,
               proto_fileformat='rawproto',
               default_ops='NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'):
  """Computes a header for use with tensorflow SELECTIVE_REGISTRATION.

  Args:
    graphs: a list of paths to GraphDef files to include.
    proto_fileformat: optional format of proto file, either 'textproto',
      'rawproto' (default) or ops_list. The ops_list is the file contain the
      list of ops in JSON format, Ex: "[["Transpose", "TransposeCpuOp"]]".
    default_ops: optional comma-separated string of operator:kernel pairs to
      always include implementation for. Pass 'all' to have all operators and
      kernels included. Default: 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'.

  Returns:
    the string of the header that should be written as ops_to_register.h,
    or the integer 1 if no ops could be read (kept for backward
    compatibility with existing callers that check for that sentinel).
  """
  ops_and_kernels = get_ops_and_kernels(proto_fileformat, graphs, default_ops)
  if not ops_and_kernels:
    print('Error reading graph!')
    return 1

  return get_header_from_ops_and_kernels(ops_and_kernels, default_ops == 'all')