# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions used by multiple tflite test files."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python import schema_util
from tensorflow.lite.tools import visualize


def get_ops_list(model_data):
  """Returns a set of ops in the tflite model data.

  Every operator in every subgraph of the model is inspected. Builtin
  operators are reported by the name returned from
  `visualize.BuiltinCodeToName`; custom operators are reported by their
  custom code, decoded as UTF-8.

  Args:
    model_data: The serialized TFLite flatbuffer model, passed straight to
      `schema_fb.Model.GetRootAsModel`.

  Returns:
    A set of op-name strings (duplicates across operators collapse).
  """
  model = schema_fb.Model.GetRootAsModel(model_data, 0)
  op_set = set()

  for subgraph_idx in range(model.SubgraphsLength()):
    subgraph = model.Subgraphs(subgraph_idx)
    for op_idx in range(subgraph.OperatorsLength()):
      op = subgraph.Operators(op_idx)
      # Operators refer to the model-level OperatorCode table by index.
      opcode = model.OperatorCodes(op.OpcodeIndex())
      builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
      if builtin_code == schema_fb.BuiltinOperator.CUSTOM:
        # Custom ops are identified by their custom-code bytes, not by a
        # builtin enum value.
        opname = opcode.CustomCode().decode("utf-8")
        op_set.add(opname)
      else:
        op_set.add(visualize.BuiltinCodeToName(builtin_code))
  return op_set


def get_output_shapes(model_data):
  """Returns a list of output shapes in the tflite model data.

  Args:
    model_data: The serialized TFLite flatbuffer model, passed straight to
      `schema_fb.Model.GetRootAsModel`.

  Returns:
    A flat list with one entry per subgraph output, in subgraph order and
    then output order. Each entry is that output tensor's shape as a list of
    Python ints; a tensor with no shape vector yields `[]`.
  """
  model = schema_fb.Model.GetRootAsModel(model_data, 0)

  output_shapes = []
  for subgraph_idx in range(model.SubgraphsLength()):
    subgraph = model.Subgraphs(subgraph_idx)
    for output_idx in range(subgraph.OutputsLength()):
      # Outputs() yields an index into this subgraph's tensor table.
      output_tensor_idx = subgraph.Outputs(output_idx)
      output_tensor = subgraph.Tensors(output_tensor_idx)
      # Read the shape element-wise rather than via ShapeAsNumpy(): the
      # flatbuffers-generated ShapeAsNumpy() returns the int 0 (not an
      # empty array) when the shape vector is absent, so `.tolist()` would
      # raise AttributeError. ShapeLength() is 0 in that case, giving [].
      output_shapes.append(
          [output_tensor.Shape(i) for i in range(output_tensor.ShapeLength())])

  return output_shapes