1#!/usr/bin/python3 2 3# Copyright 2018, The Android Open Source Project 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16 17"""Example generator 18 19Compiles spec files and generates the corresponding C++ TestModel definitions. 20Invoked by ml/nn/runtime/test/specs/generate_all_tests.sh; 21See that script for details on how this script is used. 22 23""" 24 25from __future__ import absolute_import 26from __future__ import division 27from __future__ import print_function 28import os 29import sys 30import traceback 31 32import test_generator as tg 33 34MIN_PYTHON_VERSION = (3, 7) 35assert sys.version_info >= MIN_PYTHON_VERSION, "requires Python 3.7 or newer" 36 37# See ToCpp() 38COMMENT_KEY = "__COMMENT__" 39 40# Take a model from command line 41def ParseCmdLine(): 42 parser = tg.ArgumentParser() 43 parser.add_argument("-e", "--example", help="the output example file or directory") 44 args = tg.ParseArgs(parser) 45 tg.FileNames.InitializeFileLists(args.spec, args.example) 46 47# Write headers for generated files, which are boilerplate codes only related to filenames 48def InitializeFiles(example_fd): 49 specFileBase = os.path.basename(tg.FileNames.specFile) 50 fileHeader = """\ 51// Generated from {spec_file} 52// DO NOT EDIT 53// clang-format off 54#include "TestHarness.h" 55using namespace test_helper; // NOLINT(google-build-using-namespace) 56""" 57 if example_fd is not None: 58 print(fileHeader.format(spec_file=specFileBase), file=example_fd) 59 60def IndentedStr(s, indent): 61 return ("\n" + " " * indent).join(s.split('\n')) 62 63def ToCpp(var, indent=0): 64 """Get the C++-style representation of a Python object. 65 66 For Python dictionary, it will be mapped to C++ struct aggregate initialization: 67 { 68 .key0 = value0, 69 .key1 = value1, 70 ... 71 } 72 73 For Python list, it will be mapped to C++ list initalization: 74 {value0, value1, ...} 75 76 In both cases, value0, value1, ... are stringified by invoking this method recursively. 77 """ 78 if isinstance(var, dict): 79 if not var: 80 return "{}" 81 comment = var.get(COMMENT_KEY) 82 comment = "" if comment is None else " // %s" % comment 83 str_pair = lambda k, v: " .%s = %s" % (k, ToCpp(v, indent + 4)) 84 agg_init = "{%s\n%s\n}" % (comment, 85 ",\n".join(str_pair(k, var[k]) 86 for k in var.keys() 87 if k != COMMENT_KEY)) 88 return IndentedStr(agg_init, indent) 89 elif isinstance(var, (list, tuple)): 90 return "{%s}" % (", ".join(ToCpp(i, indent) for i in var)) 91 elif type(var) is bool: 92 return "true" if var else "false" 93 elif type(var) is float: 94 return tg.PrettyPrintAsFloat(var) 95 else: 96 return str(var) 97 98def GetSymmPerChannelQuantParams(extraParams): 99 """Get the dictionary that corresponds to test_helper::TestSymmPerChannelQuantParams.""" 100 if extraParams is None or extraParams.hide: 101 return {} 102 else: 103 return {"scales": extraParams.scales, "channelDim": extraParams.channelDim} 104 105def GetOperandStruct(operand): 106 """Get the dictionary that corresponds to test_helper::TestOperand.""" 107 return { 108 COMMENT_KEY: operand.name, 109 "type": "TestOperandType::" + operand.type.type, 110 "dimensions": operand.type.dimensions, 111 "numberOfConsumers": len(operand.outs), 112 "scale": operand.type.scale, 113 "zeroPoint": operand.type.zeroPoint, 114 "lifetime": "TestOperandLifeTime::" + operand.lifetime, 115 "channelQuant": GetSymmPerChannelQuantParams(operand.type.extraParams), 116 "isIgnored": isinstance(operand, tg.IgnoredOutput), 117 "data": "TestBuffer::createFromVector<{cpp_type}>({data})".format( 118 cpp_type=operand.type.GetCppTypeString(), 119 data=operand.GetListInitialization(), 120 ) 121 } 122 123def GetOperationStruct(operation): 124 """Get the dictionary that corresponds to test_helper::TestOperation.""" 125 return { 126 "type": "TestOperationType::" + operation.optype, 127 "inputs": [op.model_index for op in operation.ins], 128 "outputs": [op.model_index for op in operation.outs], 129 } 130 131def GetSubgraphStruct(subgraph): 132 """Get the dictionary that corresponds to test_helper::TestSubgraph.""" 133 return { 134 COMMENT_KEY: subgraph.name, 135 "operands": [GetOperandStruct(op) for op in subgraph.operands], 136 "operations": [GetOperationStruct(op) for op in subgraph.operations], 137 "inputIndexes": [op.model_index for op in subgraph.GetInputs()], 138 "outputIndexes": [op.model_index for op in subgraph.GetOutputs()], 139 } 140 141def GetModelStruct(example): 142 """Get the dictionary that corresponds to test_helper::TestModel.""" 143 return { 144 "main": GetSubgraphStruct(example.model), 145 "referenced": [GetSubgraphStruct(model) for model in example.model.GetReferencedModels()], 146 "isRelaxed": example.model.isRelaxed, 147 "expectedMultinomialDistributionTolerance": 148 example.expectedMultinomialDistributionTolerance, 149 "expectFailure": example.expectFailure, 150 "minSupportedVersion": "TestHalVersion::%s" % ( 151 example.model.version if example.model.version is not None else "UNKNOWN"), 152 } 153 154def DumpExample(example, example_fd): 155 assert example.model.compiled 156 template = """\ 157namespace generated_tests::{spec_name} {{ 158 159const TestModel& get_{example_name}() {{ 160 static TestModel model = {aggregate_init}; 161 return model; 162}} 163 164const auto dummy_{example_name} = TestModelManager::get().add("{test_name}", get_{example_name}()); 165 166}} // namespace generated_tests::{spec_name} 167""" 168 print(template.format( 169 spec_name=tg.FileNames.specName, 170 test_name=str(example.testName), 171 example_name=str(example.examplesName), 172 aggregate_init=ToCpp(GetModelStruct(example), indent=4), 173 ), file=example_fd) 174 175 176if __name__ == '__main__': 177 ParseCmdLine() 178 tg.Run(InitializeFiles=InitializeFiles, DumpExample=DumpExample) 179