1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15from __future__ import absolute_import 16from __future__ import division 17from __future__ import print_function 18 19import os 20import tempfile 21 22import tensorflow.compat.v1 as tf 23from tensorflow.lite.toco import model_flags_pb2 24from tensorflow.lite.toco import toco_flags_pb2 25from tensorflow.lite.toco import types_pb2 26from tensorflow.python.platform import googletest 27from tensorflow.python.platform import resource_loader 28 29 30def TensorName(x): 31 """Get the canonical (non foo:0 name).""" 32 return x.name.split(":")[0] 33 34 35class TocoFromProtosTest(googletest.TestCase): 36 37 def _run(self, sess, in_tensor, out_tensor, should_succeed): 38 """Use toco binary to check conversion from graphdef to tflite. 39 40 Args: 41 sess: Active TensorFlow session containing graph. 42 in_tensor: TensorFlow tensor to use as input. 43 out_tensor: TensorFlow tensor to use as output. 44 should_succeed: Whether this is a valid conversion. 45 """ 46 # Build all protos and extract graphdef 47 graph_def = sess.graph_def 48 toco_flags = toco_flags_pb2.TocoFlags() 49 toco_flags.input_format = toco_flags_pb2.TENSORFLOW_GRAPHDEF 50 toco_flags.output_format = toco_flags_pb2.TFLITE 51 toco_flags.inference_input_type = types_pb2.FLOAT 52 toco_flags.inference_type = types_pb2.FLOAT 53 toco_flags.allow_custom_ops = True 54 model_flags = model_flags_pb2.ModelFlags() 55 input_array = model_flags.input_arrays.add() 56 input_array.name = TensorName(in_tensor) 57 input_array.shape.dims.extend(map(int, in_tensor.shape)) 58 model_flags.output_arrays.append(TensorName(out_tensor)) 59 # Shell out to run toco (in case it crashes) 60 with tempfile.NamedTemporaryFile() as fp_toco, \ 61 tempfile.NamedTemporaryFile() as fp_model, \ 62 tempfile.NamedTemporaryFile() as fp_input, \ 63 tempfile.NamedTemporaryFile() as fp_output: 64 fp_model.write(model_flags.SerializeToString()) 65 fp_toco.write(toco_flags.SerializeToString()) 66 fp_input.write(graph_def.SerializeToString()) 67 fp_model.flush() 68 fp_toco.flush() 69 fp_input.flush() 70 tflite_bin = resource_loader.get_path_to_datafile("toco_from_protos.par") 71 cmdline = " ".join([ 72 tflite_bin, fp_model.name, fp_toco.name, fp_input.name, fp_output.name 73 ]) 74 exitcode = os.system(cmdline) 75 if exitcode == 0: 76 stuff = fp_output.read() 77 self.assertEqual(stuff is not None, should_succeed) 78 else: 79 self.assertFalse(should_succeed) 80 81 def test_toco(self): 82 """Run a couple of TensorFlow graphs against TOCO through the python bin.""" 83 with tf.Session() as sess: 84 img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) 85 val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) 86 out = tf.identity(val, name="out") 87 out2 = tf.sin(val, name="out2") 88 # This is a valid model 89 self._run(sess, img, out, True) 90 # This uses an invalid function. 91 # TODO(aselle): Check to make sure a warning is included. 92 self._run(sess, img, out2, True) 93 # This is an identity graph, which doesn't work 94 self._run(sess, img, img, False) 95 96 97if __name__ == "__main__": 98 googletest.main() 99