1#!/usr/bin/python3 2 3# Copyright 2017, The Android Open Source Project 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16"""VTS testcase generator 17 18Implements VTS test backend. Shares most logic with the CTS test 19generator. Invoked by ml/nn/runtime/test/specs/generate_vts_tests.sh; 20See that script for details on how this script is used. 21 22""" 23 24from __future__ import absolute_import 25from __future__ import division 26from __future__ import print_function 27import argparse 28from functools import reduce 29import math 30import numpy as np 31import os 32import re 33import struct 34import sys 35import contextlib 36import pprint 37 38# Stuff from test generator 39import test_generator as tg 40from test_generator import ActivationConverter 41from test_generator import BoolScalar 42from test_generator import Configuration 43from test_generator import DataTypeConverter 44from test_generator import DataLayoutConverter 45from test_generator import Example 46from test_generator import Float16Scalar 47from test_generator import Float32Scalar 48from test_generator import Float32Vector 49from test_generator import IgnoredOutput 50from test_generator import Input 51from test_generator import Int32Scalar 52from test_generator import Int32Vector 53from test_generator import Internal 54from test_generator import Model 55from test_generator import Operand 56from test_generator import Output 57from test_generator import Parameter 58from test_generator import ParameterAsInputConverter 59from test_generator import RelaxedModeConverter 60from test_generator import SmartOpen 61from test_generator import SymmPerChannelQuantParams 62 63# Dumping methods that shared with CTS generator 64from cts_generator import DumpCtsExample 65from cts_generator import DumpCtsIsIgnored 66 67# Take a model from command line 68def ParseCmdLine(): 69 parser = argparse.ArgumentParser() 70 parser.add_argument("spec", help="the spec file") 71 parser.add_argument( 72 "-m", "--model", help="the output model file", default="-") 73 parser.add_argument( 74 "-e", "--example", help="the output example file", default="-") 75 parser.add_argument( 76 "-t", "--test", help="the output test file", default="-") 77 args = parser.parse_args() 78 tg.FileNames.InitializeFileLists( 79 args.spec, args.model, args.example, args.test) 80 81# Generate operands in VTS format 82def generate_vts_operands(model): 83 # Dump operand definitions 84 op_def = """\ 85 {{ 86 .type = OperandType::{operand_type}, 87 .dimensions = {shape}, 88 .numberOfConsumers = {no_consumers}, 89 .scale = {scale}, 90 .zeroPoint = {zero_point}, 91 .lifetime = OperandLifeTime::{lifetime}, 92 .location = {{.poolIndex = 0, .offset = {offset}, .length = {length}}},{extraParams} 93 }}""" 94 offset = 0 95 op_definitions = [] 96 extra_params_definitions = [] 97 for index, o in enumerate(model.operands): 98 length = o.type.GetByteSize() if isinstance(o, Parameter) else 0 99 add_extra_params = o.type.extraParams is not None and not o.type.extraParams.hide 100 op = { 101 "operand_type": o.type.type, 102 "shape": o.type.GetDimensionsString(), 103 "no_consumers": len(o.outs), 104 "scale": tg.PrettyPrintAsFloat(o.type.scale), 105 "zero_point": str(int(o.type.zeroPoint)), 106 "lifetime": o.lifetime, 107 "offset": offset if isinstance(o, Parameter) else 0, 108 "length": length, 109 "extraParams": "" if not add_extra_params else "\n .extraParams = std::move(extraParams%d)," % (index,), 110 } 111 offset += length 112 op_definitions.append(op_def.format(**op)) 113 114 extra_params_def = """\ 115 Operand::ExtraParams extraParams{index}; 116 extraParams{index}.{setMethodName}({param}); 117""" 118 119 if add_extra_params: 120 ep = o.type.extraParams 121 op = { 122 "index": index, 123 "setMethodName": ep.GetVtsSetter(), 124 "param": ep.GetVtsConstructor(), 125 } 126 extra_params_definitions.append(extra_params_def.format(**op)) 127 128 op_vec = """{0}\ 129 const std::vector<Operand> operands = {{ 130{1} 131 }};""".format(",\n".join(extra_params_definitions), ",\n".join(op_definitions)) 132 return op_vec 133 134# Generate VTS operand values 135def generate_vts_operand_values(operands): 136 weights = [o for o in operands if isinstance(o, Parameter)] 137 binit = [] 138 for w in weights: 139 ty = w.type.type 140 if ty == "TENSOR_QUANT8_ASYMM": 141 binit += w.value 142 elif ty == "TENSOR_QUANT8_SYMM_PER_CHANNEL" or ty == "TENSOR_QUANT8_SYMM": 143 binit += [struct.pack("b", value)[0] for value in w.value] 144 elif ty == "BOOL" or ty == "TENSOR_BOOL8": 145 binit += [1 if x else 0 for x in w.value] 146 elif ty == "TENSOR_FLOAT16" or ty == "FLOAT16": 147 for f in w.value: 148 # The pack format for float16 is not available until Python 3.6. 149 binit += [int(x) for x in np.float16(f).tostring()] 150 elif ty in {"TENSOR_FLOAT32", "FLOAT32", "TENSOR_INT32", "INT32", "TENSOR_QUANT16_ASYMM"}: 151 if ty in ["TENSOR_FLOAT32", "FLOAT32"]: 152 fmt = "f" 153 elif ty in ["TENSOR_INT32", "INT32"]: 154 fmt = "i" 155 elif ty == "TENSOR_QUANT16_ASYMM": 156 fmt = "H" 157 for f in w.value: 158 binit += [int(x) for x in struct.pack(fmt, f)] 159 else: 160 assert 0 and "Unsupported VTS operand type" 161 162 init_defs = ", ".join([str(x) for x in binit]) 163 if (init_defs != ""): 164 init_defs = "\n %s\n " % init_defs 165 byte_vec_fmt = """{%s}""" % init_defs 166 return byte_vec_fmt 167 168# Generate VTS operations 169def generate_vts_operation(op, model): 170 op_fmt = """\ 171 {{ 172 .type = OperationType::{op_code}, 173 .inputs = {{{ins}}}, 174 .outputs = {{{outs}}}, 175 }}""" 176 op_content = { 177 'op_code': op.optype, 178 'ins': tg.GetJointStr(model.GetIndexOfOperands(op.ins)), 179 'outs': tg.GetJointStr(model.GetIndexOfOperands(op.outs)) 180 } 181 return op_fmt.format(**op_content) 182 183def generate_vts_operations(model): 184 vts_ops = [generate_vts_operation(op, model) for op in model.operations] 185 return ",\n".join(vts_ops) 186 187def generate_vts_model(model, model_file): 188 operand_values_fmt = "" 189 if Configuration.useSHM(): 190 # Boilerplate code for passing weights in shared memory 191 operand_values_fmt = """\ 192 std::vector<uint8_t> operandValues = {{}}; 193 const uint8_t data[] = {operand_values}; 194 195 // Allocate segment of android shared memory, wrapped in hidl_memory. 196 // This object will be automatically freed when sharedMemory is destroyed. 197 hidl_memory sharedMemory = allocateSharedMemory(sizeof(data)); 198 199 // Mmap ashmem into usable address and hold it within the mappedMemory object. 200 // MappedMemory will automatically munmap the memory when it is destroyed. 201 sp<IMemory> mappedMemory = mapMemory(sharedMemory); 202 203 if (mappedMemory != nullptr) {{ 204 // Retrieve the mmapped pointer. 205 uint8_t* mappedPointer = 206 static_cast<uint8_t*>(static_cast<void*>(mappedMemory->getPointer())); 207 208 if (mappedPointer != nullptr) {{ 209 // Acquire the write lock for the shared memory segment, upload the data, 210 // and release the lock. 211 mappedMemory->update(); 212 std::copy(data, data + sizeof(data), mappedPointer); 213 mappedMemory->commit(); 214 }} 215 }} 216 217 const std::vector<hidl_memory> pools = {{sharedMemory}}; 218""" 219 else: 220 # Passing weights via operandValues 221 operand_values_fmt = """\ 222 std::vector<uint8_t> operandValues = {operand_values}; 223 const std::vector<hidl_memory> pools = {{}}; 224""" 225 226 operand_values_val = { 227 'operand_values': generate_vts_operand_values(model.operands) 228 } 229 operand_values = operand_values_fmt.format(**operand_values_val) 230 # operand_values = operand_values_fmt 231 model_fmt = """\ 232// Create the model 233Model {create_test_model_name}() {{ 234{operand_decls} 235 236 const std::vector<Operation> operations = {{ 237{operations} 238 }}; 239 240 const std::vector<uint32_t> inputIndexes = {{{input_indices}}}; 241 const std::vector<uint32_t> outputIndexes = {{{output_indices}}}; 242{operand_values} 243 return {{ 244 .operands = operands, 245 .operations = operations, 246 .inputIndexes = inputIndexes, 247 .outputIndexes = outputIndexes, 248 .operandValues = operandValues, 249 .pools = pools,{relaxed_field} 250 }}; 251}} 252""" 253 model_dict = { 254 "create_test_model_name": str(model.createTestFunctionName), 255 "operations": generate_vts_operations(model), 256 "operand_decls": generate_vts_operands(model), 257 "operand_values": operand_values, 258 "output_indices": tg.GetJointStr(model.GetOutputsIndex()), 259 "input_indices": tg.GetJointStr(model.GetInputsIndex()), 260 "relaxed_field": 261 "\n .relaxComputationFloat32toFloat16 = true," if (model.isRelaxed) else "" 262 } 263 print(model_fmt.format(**model_dict), file = model_file) 264 265def generate_vts(model, model_file): 266 assert model.compiled 267 generate_vts_model(model, model_file) 268 DumpCtsIsIgnored(model, model_file) 269 270def generate_vts_test(example, test_file): 271 testTemplate = """\ 272TEST_F({test_case_name}, {test_name}) {{ 273 generated_tests::Execute(device, 274 {namespace}::{create_model_name}, 275 {namespace}::{is_ignored_name}, 276 {namespace}::get_{examples_name}(){test_dynamic_output_shape});\n}} 277 278TEST_F(ValidationTest, {test_name}) {{ 279 const Model model = {namespace}::{create_model_name}(); 280 const std::vector<Request> requests = createRequests({namespace}::get_{examples_name}()); 281 validateEverything(model, requests); 282}}\n 283""" 284 if example.model.hasDynamicOutputShape: 285 print("#ifdef NN_TEST_DYNAMIC_OUTPUT_SHAPE", file=test_fd) 286 print(testTemplate.format( 287 test_case_name="DynamicOutputShapeTest" if example.model.hasDynamicOutputShape \ 288 else "NeuralnetworksHidlTest", 289 test_name=str(example.testName), 290 namespace=tg.FileNames.specName, 291 create_model_name=str(example.model.createTestFunctionName), 292 is_ignored_name=str(example.model.isIgnoredFunctionName), 293 examples_name=str(example.examplesName), 294 test_dynamic_output_shape=", true" if example.model.hasDynamicOutputShape else "" 295 ), file=test_fd) 296 if example.model.hasDynamicOutputShape: 297 print("#endif", file=test_fd) 298 299def InitializeFiles(model_fd, example_fd, test_fd): 300 fileHeader = "// clang-format off\n// Generated file (from: {spec_file}). Do not edit" 301 testFileHeader = """\ 302// Generated from: {spec_file}. 303namespace {spec_name} {{ 304// Generated {spec_name} test 305#include "{example_file}" 306// Generated model constructor 307#include "{model_file}" 308}} // namespace {spec_name}\n""" 309 # This regex is to remove prefix and get relative path for #include 310 pathRegex = r".*frameworks/ml/nn/(runtime/test/generated/)?" 311 specFileBase = os.path.basename(tg.FileNames.specFile) 312 print(fileHeader.format(spec_file=specFileBase), file=model_fd) 313 print(fileHeader.format(spec_file=specFileBase), file=example_fd) 314 print(testFileHeader.format( 315 spec_file=specFileBase, 316 model_file=re.sub(pathRegex, "", tg.FileNames.modelFile), 317 example_file=re.sub(pathRegex, "", tg.FileNames.exampleFile), 318 spec_name=tg.FileNames.specName), file=test_fd) 319 320if __name__ == "__main__": 321 ParseCmdLine() 322 while tg.FileNames.NextFile(): 323 print("Generating test(s) from spec: %s" % tg.FileNames.specFile, file=sys.stderr) 324 exec (open(tg.FileNames.specFile, "r").read()) 325 print("Output VTS model: %s" % tg.FileNames.modelFile, file=sys.stderr) 326 print("Output example:" + tg.FileNames.exampleFile, file=sys.stderr) 327 with SmartOpen(tg.FileNames.modelFile) as model_fd, \ 328 SmartOpen(tg.FileNames.exampleFile) as example_fd, \ 329 SmartOpen(tg.FileNames.testFile, mode="a") as test_fd: 330 InitializeFiles(model_fd, example_fd, test_fd) 331 Example.DumpAllExamples( 332 DumpModel=generate_vts, model_fd=model_fd, 333 DumpExample=DumpCtsExample, example_fd=example_fd, 334 DumpTest=generate_vts_test, test_fd=test_fd) 335