• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright (c) 2019 Guo Yejun
#
# This file is part of FFmpeg.
#
# FFmpeg is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# FFmpeg is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with FFmpeg; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
# ==============================================================================
19
20import tensorflow as tf
21import numpy as np
22import sys, struct
23import convert_header as header
24
# The only name this module exports for `from ... import *`.
__all__ = ["convert_from_tensorflow"]
26
class Operand(object):
    """A tensor passed between layers of the converted model.

    Every new Operand takes the next value of the class-wide ``index``
    counter. ``iotype`` accumulates IOTYPE_INPUT / IOTYPE_OUTPUT bits, and
    ``used_count`` counts how many times the operand was registered as an
    input.
    """
    IOTYPE_INPUT = 1
    IOTYPE_OUTPUT = 2
    IOTYPE_INTERMEDIATE = IOTYPE_INPUT | IOTYPE_OUTPUT
    # presumably TensorFlow DataType enum values (DT_FLOAT / DT_UINT8) -- confirm
    DTYPE_FLOAT = 1
    DTYPE_UINT8 = 4
    index = 0  # next index to hand out; shared by every instance
    # Display names used by __str__; shared by all instances instead of being
    # rebuilt per instance (still reachable as self.iotype2str / self.dtype2str).
    iotype2str = {IOTYPE_INPUT: 'in', IOTYPE_OUTPUT: 'out', IOTYPE_INTERMEDIATE: 'inout'}
    dtype2str = {DTYPE_FLOAT: 'DT_FLOAT', DTYPE_UINT8: 'DT_UINT8'}

    def __init__(self, name, dtype, dims):
        """
        Args:
            name: name of the graph node producing this tensor.
            dtype: numeric data type code.
            dims: list of dimension sizes (-1 where unknown).
        """
        self.name = name
        self.dtype = dtype
        self.dims = dims
        self.iotype = 0
        self.used_count = 0
        self.index = Operand.index
        Operand.index = Operand.index + 1

    def add_iotype(self, iotype):
        """OR ``iotype`` into this operand's role; count every input use."""
        self.iotype = self.iotype | iotype
        if iotype == Operand.IOTYPE_INPUT:
            self.used_count = self.used_count + 1

    def __str__(self):
        """Human-readable summary for debugging.

        Unknown iotype/dtype codes are shown as their raw value instead of
        raising KeyError.
        """
        iotype = self.iotype2str.get(self.iotype, self.iotype)
        dtype = self.dtype2str.get(self.dtype, self.dtype)
        dims = ','.join(str(d) for d in self.dims)
        return "{}: (name: {}, iotype: {}, dtype: {}, dims: ({}) used_count: {})".format(
            self.index, self.name, iotype, dtype, dims, self.used_count)

    def __lt__(self, other):
        """Order operands by index (used when sorting for serialization)."""
        return self.index < other.index
57
class TFConverter:
    """Convert a frozen TensorFlow graph into FFmpeg's native DNN model file.

    ``run()`` preprocesses the graph (output detection, Identity removal,
    edge and conv2d-scope discovery) and then writes the model file: the
    header from ``convert_header``, every converted layer, every operand,
    and finally the layer count and operand count.
    """

    def __init__(self, graph_def, nodes, outfile, dump4tb):
        """
        Args:
            graph_def: the parsed GraphDef; only used by dump_for_tensorboard().
            nodes: sequence of NodeDef messages; remove_identity() edits it in place.
            outfile: path of the model file to write.
            dump4tb: when true, also save the graph for TensorBoard under /tmp/graph.
        """
        self.graph_def = graph_def
        self.nodes = nodes
        self.outfile = outfile
        self.dump4tb = dump4tb
        self.layer_number = 0
        self.output_names = []
        self.name_node_dict = {}
        self.edges = {}  # input name -> nodes that list it as an input
        self.conv_activations = {'Relu':0, 'Tanh':1, 'Sigmoid':2, 'None':3, 'LeakyRelu':4}
        self.conv_paddings = {'VALID':0, 'SAME':1}
        self.converted_nodes = set()
        self.conv2d_scope_names = set()
        self.conv2d_scopename_inputname_dict = {}
        self.op2code = {'Conv2D':1, 'DepthToSpace':2, 'MirrorPad':3, 'Maximum':4, 'MathBinary':5, 'MathUnary':6}
        self.mathbin2code = {'Sub':0, 'Add':1, 'Mul':2, 'RealDiv':3, 'Minimum':4}
        self.mathun2code  = {'Abs':0}
        self.mirrorpad_mode = {'CONSTANT':0, 'REFLECT':1, 'SYMMETRIC':2}
        self.name_operand_dict = {}


    def add_operand(self, name, type):
        """Register the output of node ``name`` as an operand with role ``type``.

        On first use an Operand is created. Its dtype comes from the node's
        'dtype' attr, falling back to 'T'. Its dims come from the 'shape' attr
        when present; missing dims stay -1. ``type`` (Operand.IOTYPE_INPUT or
        IOTYPE_OUTPUT) is then ORed into the operand's iotype.

        Returns:
            The operand's index. Indices are 0..N-1 within this converter.
        """
        node = self.name_node_dict[name]
        if name not in self.name_operand_dict:
            dtype = node.attr['dtype'].type
            if dtype == 0:
                dtype = node.attr['T'].type
            dims = [-1, -1, -1, -1]
            if 'shape' in node.attr:
                # tolerate shapes of unknown rank or with fewer than 4 dims
                for i, dim in enumerate(list(node.attr['shape'].shape.dim)[:4]):
                    dims[i] = dim.size
            operand = Operand(name, dtype, dims)
            # Number operands per converter instead of relying on Operand's
            # process-wide counter, which is never reset. The file records
            # len(name_operand_dict) as the operand count, so indices must run 0..N-1.
            operand.index = len(self.name_operand_dict)
            self.name_operand_dict[name] = operand
        self.name_operand_dict[name].add_iotype(type)
        return self.name_operand_dict[name].index


    def dump_for_tensorboard(self):
        """Import graph_def into the default graph and save it to /tmp/graph.

        Uses tf.get_default_graph / tf.summary.FileWriter.
        """
        graph = tf.get_default_graph()
        tf.import_graph_def(self.graph_def, name="")
        tf.summary.FileWriter('/tmp/graph', graph)
        print('graph saved, run "tensorboard --logdir=/tmp/graph" to see it')


    def get_conv2d_params(self, conv2d_scope_name):
        """Return (kernel, bias, dilation, activation) nodes of a conv2d scope.

        ``dilation`` is None when '<scope>/dilation_rate' is absent.
        ``activation`` is the first consumer of '<scope>/BiasAdd' if its op is
        in conv_activations, otherwise None.
        """
        knode = self.name_node_dict[conv2d_scope_name + '/kernel']
        bnode = self.name_node_dict[conv2d_scope_name + '/bias']

        if conv2d_scope_name + '/dilation_rate' in self.name_node_dict:
            dnode = self.name_node_dict[conv2d_scope_name + '/dilation_rate']
        else:
            dnode = None

        # the BiasAdd name is possible be changed into the output name,
        # if activation is None, and BiasAdd.next is the last op which is Identity
        if conv2d_scope_name + '/BiasAdd' in self.edges:
            anode = self.edges[conv2d_scope_name + '/BiasAdd'][0]
            if anode.op not in self.conv_activations:
                anode = None
        else:
            anode = None
        return knode, bnode, dnode, anode


    @staticmethod
    def _const_float32_values(tensor):
        """Return the flattened float32 contents of a Const TensorProto.

        ``tensor_content`` is used when it is set. Otherwise ``float_val`` is
        used: a list shorter than the declared shape is extended by repeating
        its last value, and an empty list means all zeros. This mirrors how
        TensorFlow's make_ndarray decodes compressed constants
        (NOTE(review): confirm against the TF version in use).
        """
        num_elements = 1
        for dim in tensor.tensor_shape.dim:
            num_elements *= dim.size
        if tensor.tensor_content:
            return np.frombuffer(tensor.tensor_content, dtype=np.float32)
        values = np.array(list(tensor.float_val), dtype=np.float32)
        if values.size == 0:
            return np.zeros(num_elements, dtype=np.float32)
        if values.size < num_elements:
            values = np.pad(values, (0, num_elements - values.size), 'edge')
        return values


    @staticmethod
    def _read_kernel(knode):
        """Return (kernel, filter_height, filter_width, in_channels, out_channels).

        The kernel Const is laid out (height, width, in, out). It is returned
        transposed to (out, height, width, in), the order written to the file.
        """
        ktensor = knode.attr['value'].tensor
        filter_height = ktensor.tensor_shape.dim[0].size
        filter_width = ktensor.tensor_shape.dim[1].size
        in_channels = ktensor.tensor_shape.dim[2].size
        out_channels = ktensor.tensor_shape.dim[3].size
        kernel = TFConverter._const_float32_values(ktensor)
        kernel = kernel.reshape(filter_height, filter_width, in_channels, out_channels)
        kernel = np.transpose(kernel, [3, 0, 1, 2])
        return kernel, filter_height, filter_width, in_channels, out_channels


    def dump_complex_conv2d_to_file(self, node, f):
        """Write a Conv2D built as a layer scope as one conv layer with bias.

        The layer scope holds kernel/bias Consts plus optional dilation and
        activation nodes.
        """
        assert(node.op == 'Conv2D')
        self.layer_number = self.layer_number + 1
        self.converted_nodes.add(node.name)

        scope_name = TFConverter.get_scope_name(node.name)
        #knode for kernel, bnode for bias, dnode for dilation, anode for activation
        knode, bnode, dnode, anode = self.get_conv2d_params(scope_name)

        if dnode is not None:
            dilation = struct.unpack('i', dnode.attr['value'].tensor.tensor_content[0:4])[0]
        else:
            dilation = 1

        if anode is not None:
            activation = anode.op
        else:
            activation = 'None'

        padding = node.attr['padding'].s.decode("utf-8")
        # conv2d with dilation > 1 generates tens of nodes, not easy to parse them, so use this tricky method.
        if dilation > 1 and scope_name + '/stack' in self.name_node_dict:
            if self.name_node_dict[scope_name + '/stack'].op == "Const":
                padding = 'SAME'
        padding = self.conv_paddings[padding]

        kernel, filter_height, _, in_channels, out_channels = self._read_kernel(knode)

        has_bias = 1
        np.array([self.op2code[node.op], dilation, padding, self.conv_activations[activation], in_channels, out_channels, filter_height, has_bias], dtype=np.uint32).tofile(f)
        kernel.tofile(f)

        # handles biases given via tensor_content as well as float_val
        # (a single value, or a compressed repeated value)
        self._const_float32_values(bnode.attr['value'].tensor).tofile(f)

        input_name = self.conv2d_scopename_inputname_dict[scope_name]
        input_operand_index = self.add_operand(input_name, Operand.IOTYPE_INPUT)

        if anode is not None:
            output_operand_index = self.add_operand(anode.name, Operand.IOTYPE_OUTPUT)
        else:
            output_operand_index = self.add_operand(self.edges[bnode.name][0].name, Operand.IOTYPE_OUTPUT)
        np.array([input_operand_index, output_operand_index], dtype=np.uint32).tofile(f)


    def dump_simple_conv2d_to_file(self, node, f):
        """Write a bare Conv2D (one Const input as kernel) as a conv layer.

        The layer has no bias, dilation 1 and no activation.
        """
        assert(node.op == 'Conv2D')
        self.layer_number = self.layer_number + 1
        self.converted_nodes.add(node.name)

        node0 = self.name_node_dict[node.input[0]]
        node1 = self.name_node_dict[node.input[1]]
        if node0.op == 'Const':
            knode = node0
            input_name = node.input[1]
        else:
            knode = node1
            input_name = node.input[0]

        kernel, filter_height, _, in_channels, out_channels = self._read_kernel(knode)

        has_bias = 0
        dilation = 1
        padding = node.attr['padding'].s.decode("utf-8")
        np.array([self.op2code[node.op], dilation, self.conv_paddings[padding], self.conv_activations['None'],
                  in_channels, out_channels, filter_height, has_bias], dtype=np.uint32).tofile(f)
        kernel.tofile(f)

        input_operand_index = self.add_operand(input_name, Operand.IOTYPE_INPUT)
        output_operand_index = self.add_operand(node.name, Operand.IOTYPE_OUTPUT)
        np.array([input_operand_index, output_operand_index], dtype=np.uint32).tofile(f)


    def dump_depth2space_to_file(self, node, f):
        """Write a DepthToSpace layer: op code, block size, input/output indices."""
        assert(node.op == 'DepthToSpace')
        self.layer_number = self.layer_number + 1
        block_size = node.attr['block_size'].i
        np.array([self.op2code[node.op], block_size], dtype=np.uint32).tofile(f)
        self.converted_nodes.add(node.name)
        input_operand_index = self.add_operand(node.input[0], Operand.IOTYPE_INPUT)
        output_operand_index = self.add_operand(node.name, Operand.IOTYPE_OUTPUT)
        np.array([input_operand_index, output_operand_index], dtype=np.uint32).tofile(f)


    def dump_mirrorpad_to_file(self, node, f):
        """Write a MirrorPad layer: op code, mode, paddings bytes, input/output indices.

        The paddings are the raw tensor_content of the input[1] Const. Its
        dtype is not checked here (presumably int32 -- confirm).
        """
        assert(node.op == 'MirrorPad')
        self.layer_number = self.layer_number + 1
        mode = node.attr['mode'].s
        mode = self.mirrorpad_mode[mode.decode("utf-8")]
        np.array([self.op2code[node.op], mode], dtype=np.uint32).tofile(f)
        pnode = self.name_node_dict[node.input[1]]
        self.converted_nodes.add(pnode.name)
        paddings = pnode.attr['value'].tensor.tensor_content
        f.write(paddings)
        self.converted_nodes.add(node.name)
        input_operand_index = self.add_operand(node.input[0], Operand.IOTYPE_INPUT)
        output_operand_index = self.add_operand(node.name, Operand.IOTYPE_OUTPUT)
        np.array([input_operand_index, output_operand_index], dtype=np.uint32).tofile(f)


    def dump_maximum_to_file(self, node, f):
        """Write a Maximum layer against a scalar.

        The scalar is the first float_val of the input[1] Const.
        """
        assert(node.op == 'Maximum')
        self.layer_number = self.layer_number + 1
        ynode = self.name_node_dict[node.input[1]]
        y = ynode.attr['value'].tensor.float_val[0]
        np.array([self.op2code[node.op]], dtype=np.uint32).tofile(f)
        np.array([y], dtype=np.float32).tofile(f)
        self.converted_nodes.add(node.name)
        input_operand_index = self.add_operand(node.input[0], Operand.IOTYPE_INPUT)
        output_operand_index = self.add_operand(node.name, Operand.IOTYPE_OUTPUT)
        np.array([input_operand_index, output_operand_index], dtype=np.uint32).tofile(f)


    def dump_mathbinary_to_file(self, node, f):
        """Write a binary math layer (Sub/Add/Mul/RealDiv/Minimum).

        Layout: op code and math code, then for each of the two inputs a
        broadcast flag. Flag 1 means a float32 scalar (taken from a Const
        input) follows; flag 0 means a uint32 operand index follows. The
        output operand index comes last.
        """
        self.layer_number = self.layer_number + 1
        self.converted_nodes.add(node.name)
        i0_node = self.name_node_dict[node.input[0]]
        i1_node = self.name_node_dict[node.input[1]]
        np.array([self.op2code['MathBinary'], self.mathbin2code[node.op]], dtype=np.uint32).tofile(f)
        if i0_node.op == 'Const':
            scalar = i0_node.attr['value'].tensor.float_val[0]
            np.array([1], dtype=np.uint32).tofile(f)            # broadcast: 1
            np.array([scalar], dtype=np.float32).tofile(f)
            np.array([0], dtype=np.uint32).tofile(f)            # broadcast: 0
            input_operand_index = self.add_operand(i1_node.name, Operand.IOTYPE_INPUT)
            np.array([input_operand_index], dtype=np.uint32).tofile(f)
        elif i1_node.op == 'Const':
            scalar = i1_node.attr['value'].tensor.float_val[0]
            np.array([0], dtype=np.uint32).tofile(f)
            input_operand_index = self.add_operand(i0_node.name, Operand.IOTYPE_INPUT)
            np.array([input_operand_index], dtype=np.uint32).tofile(f)
            np.array([1], dtype=np.uint32).tofile(f)
            np.array([scalar], dtype=np.float32).tofile(f)
        else:
            np.array([0], dtype=np.uint32).tofile(f)
            input_operand_index = self.add_operand(i0_node.name, Operand.IOTYPE_INPUT)
            np.array([input_operand_index], dtype=np.uint32).tofile(f)
            np.array([0], dtype=np.uint32).tofile(f)
            input_operand_index = self.add_operand(i1_node.name, Operand.IOTYPE_INPUT)
            np.array([input_operand_index], dtype=np.uint32).tofile(f)
        output_operand_index = self.add_operand(node.name, Operand.IOTYPE_OUTPUT)
        np.array([output_operand_index], dtype=np.uint32).tofile(f)


    def dump_mathunary_to_file(self, node, f):
        """Write a unary math layer (Abs): op code, math code, input and output indices."""
        self.layer_number = self.layer_number + 1
        self.converted_nodes.add(node.name)
        i0_node = self.name_node_dict[node.input[0]]
        np.array([self.op2code['MathUnary'], self.mathun2code[node.op]], dtype=np.uint32).tofile(f)
        input_operand_index = self.add_operand(i0_node.name, Operand.IOTYPE_INPUT)
        np.array([input_operand_index], dtype=np.uint32).tofile(f)
        output_operand_index = self.add_operand(node.name, Operand.IOTYPE_OUTPUT)
        np.array([output_operand_index],dtype=np.uint32).tofile(f)


    def dump_layers_to_file(self, f):
        """Write every supported, not-yet-converted node as a layer.

        Inside a conv2d layer scope only the Conv2D node is written; other
        nodes there, and nodes of unsupported op types, are skipped.
        """
        for node in self.nodes:
            if node.name in self.converted_nodes:
                continue

            # conv2d with dilation generates very complex nodes, so handle it in special
            if self.in_conv2d_scope(node.name):
                if node.op == 'Conv2D':
                    self.dump_complex_conv2d_to_file(node, f)
                continue

            if node.op == 'Conv2D':
                self.dump_simple_conv2d_to_file(node, f)
            elif node.op == 'DepthToSpace':
                self.dump_depth2space_to_file(node, f)
            elif node.op == 'MirrorPad':
                self.dump_mirrorpad_to_file(node, f)
            elif node.op == 'Maximum':
                self.dump_maximum_to_file(node, f)
            elif node.op in self.mathbin2code:
                self.dump_mathbinary_to_file(node, f)
            elif node.op in self.mathun2code:
                self.dump_mathunary_to_file(node, f)


    def dump_operands_to_file(self, f):
        """Write every operand in index order.

        Per operand: index and name byte-length (uint32), the UTF-8 name,
        iotype and dtype (uint32), then four dims (int32, -1 = unknown).
        """
        for operand in sorted(self.name_operand_dict.values()):
            encoded_name = operand.name.encode('utf-8')
            # the length must count the bytes actually written, not characters
            np.array([operand.index, len(encoded_name)], dtype=np.uint32).tofile(f)
            f.write(encoded_name)
            np.array([operand.iotype, operand.dtype], dtype=np.uint32).tofile(f)
            # dims may be -1: int32 yields the same 0xFFFFFFFF bytes that the old
            # uint32 wrap-around produced, without NumPy >= 2 raising OverflowError.
            np.array([operand.dims[0], operand.dims[1], operand.dims[2], operand.dims[3]], dtype=np.int32).tofile(f)


    def dump_to_file(self):
        """Write the model file.

        Order: header string, (major, minor) version, layers, operands, then
        layer count and operand count as uint32.
        """
        with open(self.outfile, 'wb') as f:
            f.write(header.str.encode('utf-8'))
            np.array([header.major, header.minor], dtype=np.uint32).tofile(f)
            self.dump_layers_to_file(f)
            self.dump_operands_to_file(f)
            np.array([self.layer_number, len(self.name_operand_dict)], dtype=np.uint32).tofile(f)


    def generate_name_node_dict(self):
        """Index the graph's nodes by name."""
        for node in self.nodes:
            self.name_node_dict[node.name] = node


    def generate_output_names(self):
        """Record as outputs the nodes that no other node lists as an input."""
        used_names = {inp for node in self.nodes for inp in node.input}
        self.output_names.extend(node.name for node in self.nodes if node.name not in used_names)


    def remove_identity(self):
        """Remove Identity nodes and rewire their consumers.

        An Identity that is a graph output keeps its name: the node it
        forwards is renamed to the output name. Consumers of the old name are
        redirected to the new one. Other Identity nodes, including chains of
        them, are bypassed by pointing consumers at the ultimate input.
        """
        id_nodes = []
        id_dict = {}
        renamed = {}
        for node in self.nodes:
            if node.op == 'Identity':
                name = node.name
                input = node.input[0]
                id_nodes.append(node)
                # do not change the output name
                if name in self.output_names:
                    self.name_node_dict[input].name = name
                    self.name_node_dict[name] = self.name_node_dict[input]
                    del self.name_node_dict[input]
                    # other consumers of `input` must follow the rename
                    renamed[input] = name
                else:
                    id_dict[name] = input

        for idnode in id_nodes:
            self.nodes.remove(idnode)

        for node in self.nodes:
            for i, input in enumerate(node.input):
                # follow Identity -> Identity chains to the real producer
                while input in id_dict:
                    input = id_dict[input]
                input = renamed.get(input, input)
                if input != node.input[i]:
                    node.input[i] = input


    def generate_edges(self):
        """Map each input name to the list of nodes consuming it, in node order."""
        for node in self.nodes:
            for input in node.input:
                self.edges.setdefault(input, []).append(node)


    @staticmethod
    def get_scope_name(name):
        """Return ``name`` up to its last '/', or "" if it has no scope."""
        index = name.rfind('/')
        if index == -1:
            return ""
        return name[0:index]


    def in_conv2d_scope(self, name):
        """True if node ``name`` lies in, or below, a detected conv2d layer scope."""
        inner_scope = TFConverter.get_scope_name(name)
        if inner_scope == "":
            return False
        # match whole scope components: 'conv1' must not claim 'conv10/...'
        return any(inner_scope == scope or inner_scope.startswith(scope + '/')
                   for scope in self.conv2d_scope_names)


    def generate_conv2d_scope_info(self):
        """Find conv2d layer scopes and the external input feeding each one.

        A scope counts when it contains a Conv2D node and a '<scope>/kernel'
        node. Its input is an input of a Conv2D or Shape node in the scope
        that comes from outside the scope.
        """
        # mostly, conv2d is a sub block in graph, get the scope name
        for node in self.nodes:
            if node.op == 'Conv2D':
                scope = TFConverter.get_scope_name(node.name)
                # for the case tf.nn.conv2d is called directly
                if scope == '':
                    continue
                # for the case tf.nn.conv2d is called within a scope
                if scope + '/kernel' not in self.name_node_dict:
                    continue
                self.conv2d_scope_names.add(scope)

        # get the input name to the conv2d sub block
        for node in self.nodes:
            scope = TFConverter.get_scope_name(node.name)
            if scope in self.conv2d_scope_names:
                if node.op == 'Conv2D' or node.op == 'Shape':
                    for inp in node.input:
                        if TFConverter.get_scope_name(inp) != scope:
                            self.conv2d_scopename_inputname_dict[scope] = inp


    def run(self):
        """Preprocess the graph, optionally dump it for TensorBoard, then write the model."""
        self.generate_name_node_dict()
        self.generate_output_names()
        self.remove_identity()
        self.generate_edges()
        self.generate_conv2d_scope_info()

        if self.dump4tb:
            self.dump_for_tensorboard()

        self.dump_to_file()
449
450
def convert_from_tensorflow(infile, outfile, dump4tb):
    """Convert the serialized TensorFlow GraphDef in ``infile`` to a native model.

    Args:
        infile: path of the binary GraphDef (.pb) to read.
        outfile: path of the native model file to write.
        dump4tb: when true, also save the graph for TensorBoard.
    """
    with open(infile, 'rb') as pb_file:
        serialized = pb_file.read()
    # the file holds a GraphDef in binary protobuf form
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    TFConverter(graph_def, graph_def.node, outfile, dump4tb).run()
460