1#!/usr/bin/python3 2 3# Copyright 2017, The Android Open Source Project 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16"""VTS testcase generator 17 18Implements VTS test backend. Shares most logic with the CTS test 19generator. Invoked by ml/nn/runtime/test/specs/generate_vts_tests.sh; 20See that script for details on how this script is used. 21 22""" 23 24from __future__ import absolute_import 25from __future__ import division 26from __future__ import print_function 27import argparse 28from functools import reduce 29import math 30import os 31import struct 32import sys 33import contextlib 34import test_generator 35import pprint 36# Stuff from test generator 37from test_generator import Configuration 38from test_generator import Example 39from test_generator import Float32Scalar 40from test_generator import IgnoredOutput 41from test_generator import Input 42from test_generator import Int32Scalar 43from test_generator import Internal 44from test_generator import Model 45from test_generator import Operand 46from test_generator import Output 47from test_generator import Parameter 48from test_generator import smart_open 49 50# Take a model from command line 51def import_source(): 52 parser = argparse.ArgumentParser() 53 parser.add_argument("spec", help="the spec file") 54 parser.add_argument( 55 "-m", "--model", help="the output model file", default="-") 56 parser.add_argument( 57 "-e", "--example", help="the output example file", default="-") 58 args = parser.parse_args() 59 60 if os.path.exists(args.spec): 61 test_generator.FileNames.SpecFile = os.path.basename(args.spec) 62 exec (open(args.spec).read()) 63 64 return (args.model, args.example) 65 66# Generate operands in VTS format 67def generate_vts_operands(): 68 # Dump operand definitions 69 op_def = """\ 70 {{ 71 .type = OperandType::{operand_type}, 72 .dimensions = {shape}, 73 .numberOfConsumers = {no_consumers}, 74 .scale = {scale}, 75 .zeroPoint = {zero_point}, 76 .lifetime = OperandLifeTime::{lifetime}, 77 .location = {{.poolIndex = 0, .offset = {offset}, .length = {length}}}, 78 }}""" 79 offset = 0 80 op_definitions = [] 81 for o in Operand.operands.objects(): 82 ty = o.type 83 no_consumers = len(o.outs) if o.traversable() else 0 84 lifetime = o.lifetime() 85 length = ty.get_size() if o.is_weight() else 0 86 real_shape, scale, zero_point = ty.get_parsed_shape() 87 scale = float(scale) 88 zero_point = int(zero_point) 89 op = { 90 "operand_type": ty.get_element_type(), 91 "shape": "{%s}" % real_shape, 92 "no_consumers": no_consumers, 93 "scale": test_generator.pretty_print_as_float(scale), 94 "zero_point": str(int(zero_point)), 95 "lifetime": lifetime, 96 "offset": offset if o.is_weight() else 0, 97 "length": length 98 } 99 offset += length 100 op_definitions.append(op_def.format(**op)) 101 102 op_vec = """\ 103 const std::vector<Operand> operands = {{ 104{0} 105 }};""".format(",\n".join(op_definitions)) 106 return op_vec 107 108# Generate VTS operand values 109def generate_vts_operand_values(): 110 weights = [o for o in Operand.operands.objects() if o.is_weight()] 111 binit = [] 112 for w in weights: 113 ty = w.type.get_element_type() 114 if ty == "TENSOR_QUANT8_ASYMM": 115 binit += w.initializer 116 elif ty in {"TENSOR_FLOAT32", "FLOAT32", "TENSOR_INT32", "INT32"}: 117 fmt = "f" if (ty == "TENSOR_FLOAT32" or ty == "FLOAT32") else "i" 118 for f in w.initializer: 119 binit += [int(x) for x in struct.pack(fmt, f)] 120 else: 121 assert 0 and "Unsupported VTS operand type" 122 123 init_defs = ", ".join([str(x) for x in binit]) 124 if (init_defs != ""): 125 init_defs = "\n %s\n " % init_defs 126 byte_vec_fmt = """{%s}""" % init_defs 127 return byte_vec_fmt 128 129# Generate VTS operations 130class VTSOps(object): 131 vts_ops = [] 132 def generate_vts_operation(op): 133 try: 134 opcode =op.optype 135 except AttributeError: # not an op, but things like weights 136 return 137 op_fmt = """\ 138 {{ 139 .type = OperationType::{op_code}, 140 .inputs = {{{ins}}}, 141 .outputs = {{{outs}}}, 142 }}""" 143 op_content = { 144 'op_code': op.optype, 145 'op_type': op.type.get_element_type(), 146 'ins': ", ".join([str(x.ID()) for x in op.ins]), 147 'outs': ", ".join([str(x.ID()) for x in op.outs]), 148 } 149 VTSOps.vts_ops.append(op_fmt.format(**op_content)) 150 return True 151 152def generate_vts_operations(model_file): 153 test_generator.TopologicalSort(lambda x: VTSOps.generate_vts_operation(x)) 154 return ",\n".join(VTSOps.vts_ops) 155 156 157def generate_vts_model(model_file): 158 operand_values_fmt = "" 159 if Configuration.useSHM(): 160 # Boilerplate code for passing weights in shared memory 161 operand_values_fmt = """\ 162 std::vector<uint8_t> operandValues = {{}}; 163 const uint8_t data[] = {operand_values}; 164 165 // Allocate segment of android shared memory, wrapped in hidl_memory. 166 // This object will be automatically freed when sharedMemory is destroyed. 167 hidl_memory sharedMemory = allocateSharedMemory(sizeof(data)); 168 169 // Mmap ashmem into usable address and hold it within the mappedMemory object. 170 // MappedMemory will automatically munmap the memory when it is destroyed. 171 sp<IMemory> mappedMemory = mapMemory(sharedMemory); 172 173 if (mappedMemory != nullptr) {{ 174 // Retrieve the mmapped pointer. 175 uint8_t* mappedPointer = 176 static_cast<uint8_t*>(static_cast<void*>(mappedMemory->getPointer())); 177 178 if (mappedPointer != nullptr) {{ 179 // Acquire the write lock for the shared memory segment, upload the data, 180 // and release the lock. 181 mappedMemory->update(); 182 std::copy(data, data + sizeof(data), mappedPointer); 183 mappedMemory->commit(); 184 }} 185 }} 186 187 const std::vector<hidl_memory> pools = {{sharedMemory}}; 188""" 189 else: 190 # Passing weights via operandValues 191 operand_values_fmt = """\ 192 std::vector<uint8_t> operandValues = {operand_values}; 193 const std::vector<hidl_memory> pools = {{}}; 194""" 195 196 operand_values_val = { 197 'operand_values': generate_vts_operand_values() 198 } 199 operand_values = operand_values_fmt.format(**operand_values_val) 200 # operand_values = operand_values_fmt 201 model_fmt = """\ 202// Generated code. Do not edit 203// Create the model 204Model createTestModel() {{ 205{operand_decls} 206 207 const std::vector<Operation> operations = {{ 208{operations} 209 }}; 210 211 const std::vector<uint32_t> inputIndexes = {{{input_indices}}}; 212 const std::vector<uint32_t> outputIndexes = {{{output_indices}}}; 213{operand_values} 214 return {{ 215 .operands = operands, 216 .operations = operations, 217 .inputIndexes = inputIndexes, 218 .outputIndexes = outputIndexes, 219 .operandValues = operandValues, 220 .pools = pools,{relaxed_field} 221 }}; 222}} 223""" 224 model = { 225 "operations": generate_vts_operations(sys.stdout), 226 "operand_decls": generate_vts_operands(), 227 "operand_values": operand_values, 228 "output_indices": ", ".join([str(i.ID()) for i in Output.get_outputs()]), 229 "input_indices": ", ".join([str(i.ID()) for i in Input.get_inputs(True)]), 230 "relaxed_field": 231 "\n .relaxComputationFloat32toFloat16 = true," if (Model.isRelaxed()) else "" 232 } 233 print(model_fmt.format(**model), file = model_file) 234 235def generate_vts(model_file): 236 generate_vts_model(model_file) 237 print (IgnoredOutput.gen_ignored(), file=model_file) 238 239if __name__ == "__main__": 240 (model, example) = import_source() 241 print("Output VTS model: %s" % model, file=sys.stderr) 242 print("Output example:" + example, file=sys.stderr) 243 244 with smart_open(model) as model_file: 245 generate_vts(model_file) 246 with smart_open(example) as example_file: 247 Example.dump(example_file) 248