1#!/usr/bin/python3 2 3# Copyright 2018, The Android Open Source Project 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16 17"""CTS testcase generator 18 19Implements CTS test backend. Invoked by ml/nn/runtime/test/specs/generate_tests.sh; 20See that script for details on how this script is used. 21 22""" 23 24from __future__ import absolute_import 25from __future__ import division 26from __future__ import print_function 27import argparse 28import math 29import os 30import re 31import sys 32import traceback 33 34# Stuff from test generator 35import test_generator as tg 36from test_generator import ActivationConverter 37from test_generator import BoolScalar 38from test_generator import Configuration 39from test_generator import DataTypeConverter 40from test_generator import DataLayoutConverter 41from test_generator import Example 42from test_generator import Float16Scalar 43from test_generator import Float32Scalar 44from test_generator import Float32Vector 45from test_generator import GetJointStr 46from test_generator import IgnoredOutput 47from test_generator import Input 48from test_generator import Int32Scalar 49from test_generator import Int32Vector 50from test_generator import Internal 51from test_generator import Model 52from test_generator import Operand 53from test_generator import Output 54from test_generator import Parameter 55from test_generator import ParameterAsInputConverter 56from test_generator import RelaxedModeConverter 57from test_generator import SmartOpen 58from test_generator import SymmPerChannelQuantParams 59 60def IndentedPrint(s, indent=2, *args, **kwargs): 61 print('\n'.join([" " * indent + i for i in s.split('\n')]), *args, **kwargs) 62 63# Take a model from command line 64def ParseCmdLine(): 65 parser = argparse.ArgumentParser() 66 parser.add_argument("spec", help="the spec file/directory") 67 parser.add_argument( 68 "-m", "--model", help="the output model file/directory", default="-") 69 parser.add_argument( 70 "-e", "--example", help="the output example file/directory", default="-") 71 parser.add_argument( 72 "-t", "--test", help="the output test file/directory", default="-") 73 parser.add_argument( 74 "-c", "--cts", help="the CTS TestGeneratedOneFile.cpp", default="-") 75 parser.add_argument( 76 "-f", "--force", help="force to regenerate all spec files", action="store_true") 77 # for slicing tool 78 parser.add_argument( 79 "-l", "--log", help="the optional log file", default="") 80 args = parser.parse_args() 81 tg.FileNames.InitializeFileLists( 82 args.spec, args.model, args.example, args.test, args.cts, args.log) 83 Configuration.force_regenerate = args.force 84 85def NeedRegenerate(): 86 if not all(os.path.exists(f) for f in \ 87 [tg.FileNames.modelFile, tg.FileNames.exampleFile, tg.FileNames.testFile]): 88 return True 89 specTime = os.path.getmtime(tg.FileNames.specFile) + 10 90 modelTime = os.path.getmtime(tg.FileNames.modelFile) 91 exampleTime = os.path.getmtime(tg.FileNames.exampleFile) 92 testTime = os.path.getmtime(tg.FileNames.testFile) 93 if all(t > specTime for t in [modelTime, exampleTime, testTime]): 94 return False 95 return True 96 97# Write headers for generated files, which are boilerplate codes only related to filenames 98def InitializeFiles(model_fd, example_fd, test_fd): 99 fileHeader = "// clang-format off\n// Generated file (from: {spec_file}). Do not edit" 100 testFileHeader = """\ 101#include "../../TestGenerated.h"\n 102namespace {spec_name} {{ 103// Generated {spec_name} test 104#include "{example_file}" 105// Generated model constructor 106#include "{model_file}" 107}} // namespace {spec_name}\n""" 108 # This regex is to remove prefix and get relative path for #include 109 pathRegex = r".*((frameworks/ml/nn/(runtime/test/)?)|(vendor/google/[a-z]*/test/))" 110 specFileBase = os.path.basename(tg.FileNames.specFile) 111 print(fileHeader.format(spec_file=specFileBase), file=model_fd) 112 print(fileHeader.format(spec_file=specFileBase), file=example_fd) 113 print(fileHeader.format(spec_file=specFileBase), file=test_fd) 114 print(testFileHeader.format( 115 model_file=re.sub(pathRegex, "", tg.FileNames.modelFile), 116 example_file=re.sub(pathRegex, "", tg.FileNames.exampleFile), 117 spec_name=tg.FileNames.specName), file=test_fd) 118 119# Dump is_ignored function for IgnoredOutput 120def DumpCtsIsIgnored(model, model_fd): 121 isIgnoredTemplate = """\ 122inline bool {is_ignored_name}(int i) {{ 123 static std::set<int> ignore = {{{ignored_index}}}; 124 return ignore.find(i) != ignore.end();\n}}\n""" 125 print(isIgnoredTemplate.format( 126 ignored_index=tg.GetJointStr(model.GetIgnoredOutputs(), method=lambda x: str(x.index)), 127 is_ignored_name=str(model.isIgnoredFunctionName)), file=model_fd) 128 129# Dump Model file for Cts tests 130def DumpCtsModel(model, model_fd): 131 assert model.compiled 132 if model.dumped: 133 return 134 print("void %s(Model *model) {"%(model.createFunctionName), file=model_fd) 135 136 # Phase 0: types 137 for t in model.GetTypes(): 138 if t.scale == 0.0 and t.zeroPoint == 0 and t.extraParams is None: 139 typeDef = "OperandType %s(Type::%s, %s);"%(t, t.type, t.GetDimensionsString()) 140 else: 141 if t.extraParams is None or t.extraParams.hide: 142 typeDef = "OperandType %s(Type::%s, %s, %s, %d);"%( 143 t, t.type, t.GetDimensionsString(), tg.PrettyPrintAsFloat(t.scale), t.zeroPoint) 144 else: 145 typeDef = "OperandType %s(Type::%s, %s, %s, %d, %s);"%( 146 t, t.type, t.GetDimensionsString(), tg.PrettyPrintAsFloat(t.scale), t.zeroPoint, 147 t.extraParams.GetConstructor()) 148 149 IndentedPrint(typeDef, file=model_fd) 150 151 # Phase 1: add operands 152 print(" // Phase 1, operands", file=model_fd) 153 for op in model.operands: 154 IndentedPrint("auto %s = model->addOperand(&%s);"%(op, op.type), file=model_fd) 155 156 # Phase 2: operations 157 print(" // Phase 2, operations", file=model_fd) 158 for p in model.GetParameters(): 159 paramDef = "static %s %s[] = %s;\nmodel->setOperandValue(%s, %s, sizeof(%s) * %d);"%( 160 p.type.GetCppTypeString(), p.initializer, p.GetListInitialization(), p, 161 p.initializer, p.type.GetCppTypeString(), p.type.GetNumberOfElements()) 162 IndentedPrint(paramDef, file=model_fd) 163 for op in model.operations: 164 IndentedPrint("model->addOperation(ANEURALNETWORKS_%s, {%s}, {%s});"%( 165 op.optype, tg.GetJointStr(op.ins), tg.GetJointStr(op.outs)), file=model_fd) 166 167 # Phase 3: add inputs and outputs 168 print (" // Phase 3, inputs and outputs", file=model_fd) 169 IndentedPrint("model->identifyInputsAndOutputs(\n {%s},\n {%s});"%( 170 tg.GetJointStr(model.GetInputs()), tg.GetJointStr(model.GetOutputs())), file=model_fd) 171 172 # Phase 4: set relaxed execution if needed 173 if (model.isRelaxed): 174 print (" // Phase 4: set relaxed execution", file=model_fd) 175 print (" model->relaxComputationFloat32toFloat16(true);", file=model_fd) 176 177 print (" assert(model->isValid());", file=model_fd) 178 print ("}\n", file=model_fd) 179 DumpCtsIsIgnored(model, model_fd) 180 model.dumped = True 181 182def DumpMixedType(operands, feedDict): 183 supportedTensors = [ 184 "DIMENSIONS", 185 "TENSOR_FLOAT32", 186 "TENSOR_INT32", 187 "TENSOR_QUANT8_ASYMM", 188 "TENSOR_OEM_BYTE", 189 "TENSOR_QUANT16_SYMM", 190 "TENSOR_FLOAT16", 191 "TENSOR_BOOL8", 192 "TENSOR_QUANT8_SYMM_PER_CHANNEL", 193 "TENSOR_QUANT16_ASYMM", 194 "TENSOR_QUANT8_SYMM", 195 ] 196 typedMap = {t: [] for t in supportedTensors} 197 FeedAndGet = lambda op, d: op.Feed(d).GetListInitialization() 198 # group the operands by type 199 for operand in operands: 200 try: 201 typedMap[operand.type.type].append(FeedAndGet(operand, feedDict)) 202 typedMap["DIMENSIONS"].append("{%d, {%s}}"%( 203 operand.index, GetJointStr(operand.dimensions))) 204 except KeyError as e: 205 traceback.print_exc() 206 sys.exit("Cannot dump tensor of type {}".format(operand.type.type)) 207 mixedTypeTemplate = """\ 208{{ // See tools/test_generator/include/TestHarness.h:MixedTyped 209 // int -> Dimensions map 210 .operandDimensions = {{{dimensions_map}}}, 211 // int -> FLOAT32 map 212 .float32Operands = {{{float32_map}}}, 213 // int -> INT32 map 214 .int32Operands = {{{int32_map}}}, 215 // int -> QUANT8_ASYMM map 216 .quant8AsymmOperands = {{{uint8_map}}}, 217 // int -> QUANT16_SYMM map 218 .quant16SymmOperands = {{{int16_map}}}, 219 // int -> FLOAT16 map 220 .float16Operands = {{{float16_map}}}, 221 // int -> BOOL8 map 222 .bool8Operands = {{{bool8_map}}}, 223 // int -> QUANT8_SYMM_PER_CHANNEL map 224 .quant8ChannelOperands = {{{int8_map}}}, 225 // int -> QUANT16_ASYMM map 226 .quant16AsymmOperands = {{{uint16_map}}}, 227 // int -> QUANT8_SYMM map 228 .quant8SymmOperands = {{{quant8_symm_map}}}, 229}}""" 230 return mixedTypeTemplate.format( 231 dimensions_map=tg.GetJointStr(typedMap.get("DIMENSIONS", [])), 232 float32_map=tg.GetJointStr(typedMap.get("TENSOR_FLOAT32", [])), 233 int32_map=tg.GetJointStr(typedMap.get("TENSOR_INT32", [])), 234 uint8_map=tg.GetJointStr(typedMap.get("TENSOR_QUANT8_ASYMM", []) + 235 typedMap.get("TENSOR_OEM_BYTE", [])), 236 int16_map=tg.GetJointStr(typedMap.get("TENSOR_QUANT16_SYMM", [])), 237 float16_map=tg.GetJointStr(typedMap.get("TENSOR_FLOAT16", [])), 238 int8_map=tg.GetJointStr(typedMap.get("TENSOR_QUANT8_SYMM_PER_CHANNEL", [])), 239 bool8_map=tg.GetJointStr(typedMap.get("TENSOR_BOOL8", [])), 240 uint16_map=tg.GetJointStr(typedMap.get("TENSOR_QUANT16_ASYMM", [])), 241 quant8_symm_map=tg.GetJointStr(typedMap.get("TENSOR_QUANT8_SYMM", [])) 242 ) 243 244# Dump Example file for Cts tests 245def DumpCtsExample(example, example_fd): 246 print("std::vector<MixedTypedExample>& get_%s() {" % (example.examplesName), file=example_fd) 247 print("static std::vector<MixedTypedExample> %s = {" % (example.examplesName), file=example_fd) 248 for inputFeedDict, outputFeedDict in example.feedDicts: 249 print ('// Begin of an example', file = example_fd) 250 print ('{\n.operands = {', file = example_fd) 251 inputs = DumpMixedType(example.model.GetInputs(), inputFeedDict) 252 outputs = DumpMixedType(example.model.GetOutputs(), outputFeedDict) 253 print ('//Input(s)\n%s,' % inputs , file = example_fd) 254 print ('//Output(s)\n%s' % outputs, file = example_fd) 255 print ('},', file = example_fd) 256 if example.expectedMultinomialDistributionTolerance is not None: 257 print ('.expectedMultinomialDistributionTolerance = %f' % 258 example.expectedMultinomialDistributionTolerance, file = example_fd) 259 print ('}, // End of an example', file = example_fd) 260 print("};", file=example_fd) 261 print("return %s;" % (example.examplesName), file=example_fd) 262 print("};\n", file=example_fd) 263 264# Dump Test file for Cts tests 265def DumpCtsTest(example, test_fd): 266 testTemplate = """\ 267TEST_F({test_case_name}, {test_name}) {{ 268 execute({namespace}::{create_model_name}, 269 {namespace}::{is_ignored_name}, 270 {namespace}::get_{examples_name}(){log_file});\n}}\n""" 271 if example.model.version is not None: 272 testTemplate += """\ 273TEST_AVAILABLE_SINCE({version}, {test_name}, {namespace}::{create_model_name})\n""" 274 print(testTemplate.format( 275 test_case_name="DynamicOutputShapeTest" if example.model.hasDynamicOutputShape \ 276 else "GeneratedTests", 277 test_name=str(example.testName), 278 namespace=tg.FileNames.specName, 279 create_model_name=str(example.model.createFunctionName), 280 is_ignored_name=str(example.model.isIgnoredFunctionName), 281 examples_name=str(example.examplesName), 282 version=example.model.version, 283 log_file=tg.FileNames.logFile), file=test_fd) 284 285if __name__ == '__main__': 286 ParseCmdLine() 287 while tg.FileNames.NextFile(): 288 if Configuration.force_regenerate or NeedRegenerate(): 289 print("Generating test(s) from spec: %s" % tg.FileNames.specFile, file=sys.stderr) 290 exec(open(tg.FileNames.specFile, "r").read()) 291 print("Output CTS model: %s" % tg.FileNames.modelFile, file=sys.stderr) 292 print("Output example:%s" % tg.FileNames.exampleFile, file=sys.stderr) 293 print("Output CTS test: %s" % tg.FileNames.testFile, file=sys.stderr) 294 with SmartOpen(tg.FileNames.modelFile) as model_fd, \ 295 SmartOpen(tg.FileNames.exampleFile) as example_fd, \ 296 SmartOpen(tg.FileNames.testFile) as test_fd: 297 InitializeFiles(model_fd, example_fd, test_fd) 298 Example.DumpAllExamples( 299 DumpModel=DumpCtsModel, model_fd=model_fd, 300 DumpExample=DumpCtsExample, example_fd=example_fd, 301 DumpTest=DumpCtsTest, test_fd=test_fd) 302 else: 303 print("Skip file: %s" % tg.FileNames.specFile, file=sys.stderr) 304 with SmartOpen(tg.FileNames.ctsFile, mode="a") as cts_fd: 305 print("#include \"../generated/tests/%s.cpp\""%os.path.basename(tg.FileNames.specFile), 306 file=cts_fd) 307