• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#!/usr/bin/python3

# Copyright 2018, The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
"""Example generator

Compiles spec files and generates the corresponding C++ TestModel definitions.
Invoked by ml/nn/runtime/test/specs/generate_all_tests.sh;
see that script for details on how this script is used.

"""
24
25from __future__ import absolute_import
26from __future__ import division
27from __future__ import print_function
28import os
29import sys
30import traceback
31
32import test_generator as tg
33
MIN_PYTHON_VERSION = (3, 7)
# Explicit check rather than `assert`: assertions are stripped when Python runs
# with -O, which would silently skip the version requirement.
if sys.version_info < MIN_PYTHON_VERSION:
    raise RuntimeError("requires Python %s or newer"
                       % ".".join(str(part) for part in MIN_PYTHON_VERSION))

# Reserved dict key understood by ToCpp(): its value is rendered as a trailing
# "// comment" on the aggregate's opening brace instead of as a member.
COMMENT_KEY = "__COMMENT__"
39
def ParseCmdLine():
    """Parse the command line and initialize tg.FileNames from it.

    Adds an optional ``-e/--example`` argument (output example file or
    directory) to the parser supplied by test_generator, parses argv, and
    hands the spec and example paths to tg.FileNames.InitializeFileLists.
    ``spec`` is not declared here, so presumably tg.ArgumentParser() already
    defines it -- TODO confirm in test_generator.
    """
    cli = tg.ArgumentParser()
    cli.add_argument("-e", "--example",
                     help="the output example file or directory")
    options = tg.ParseArgs(cli)
    tg.FileNames.InitializeFileLists(options.spec, options.example)
46
def InitializeFiles(example_fd):
    """Write the boilerplate header of a generated example file.

    The header depends only on the spec file name: a "Generated from"
    provenance line, a DO NOT EDIT marker, clang-format suppression, the
    TestHarness.h include and a using-directive for test_helper.

    Args:
        example_fd: writable text stream for the generated C++ file, or None
            to write nothing.
    """
    # Resolved before the None check, so a missing spec file still raises
    # regardless of whether an output stream was given.
    spec_basename = os.path.basename(tg.FileNames.specFile)
    header_lines = (
        "// Generated from {spec_file}",
        "// DO NOT EDIT",
        "// clang-format off",
        '#include "TestHarness.h"',
        "using namespace test_helper;  // NOLINT(google-build-using-namespace)",
    )
    if example_fd is None:
        return
    header = "\n".join(header_lines) + "\n"
    print(header.format(spec_file=spec_basename), file=example_fd)
59
def IndentedStr(s, indent):
    """Return ``s`` with ``indent`` spaces inserted after every newline.

    The first line is intentionally left as-is: ToCpp places the result
    inline after existing text, so only continuation lines are shifted.
    """
    return s.replace("\n", "\n" + " " * indent)
62
def ToCpp(var, indent=0):
    """Render a Python value as a C++ initializer expression.

    * dict -> aggregate initialization with designated members, one per line:
          {
              .key0 = value0,
              .key1 = value1,
              ...
          }
      An empty dict becomes ``{}``. A COMMENT_KEY entry is not emitted as a
      member; its value is appended to the opening brace as ``// comment``.
      Members appear in the dict's iteration order.
    * list / tuple -> list initialization on one line: ``{value0, value1, ...}``
    * bool -> ``true`` / ``false``
    * float -> tg.PrettyPrintAsFloat(value)
    * anything else -> ``str(value)``

    Nested values are rendered by calling this function recursively.

    Args:
        var: the value to render.
        indent: spaces added to every line after the first of a rendered
            dict; dict members recurse with ``indent + 4``, list elements with
            the same ``indent``.
    """
    if isinstance(var, dict):
        if not var:
            return "{}"
        note = var.get(COMMENT_KEY)
        brace_suffix = "" if note is None else " // %s" % note
        members = []
        for key, value in var.items():
            if key == COMMENT_KEY:
                continue
            members.append("    .%s = %s" % (key, ToCpp(value, indent + 4)))
        aggregate = "{" + brace_suffix + "\n" + ",\n".join(members) + "\n}"
        return IndentedStr(aggregate, indent)

    if isinstance(var, (list, tuple)):
        return "{" + ", ".join([ToCpp(item, indent) for item in var]) + "}"

    # Exact type checks: bool must be caught before it could be stringified as
    # "True"/"False", and only plain floats get the float pretty-printer.
    if type(var) is bool:
        return "true" if var else "false"
    if type(var) is float:
        return tg.PrettyPrintAsFloat(var)
    return str(var)
97
def GetSymmPerChannelQuantParams(extraParams):
    """Build the dict mirroring test_helper::TestSymmPerChannelQuantParams.

    Returns ``{"scales": ..., "channelDim": ...}`` taken from ``extraParams``,
    or an empty dict when ``extraParams`` is None or its ``hide`` flag is set.
    """
    if extraParams is not None and not extraParams.hide:
        return {"scales": extraParams.scales,
                "channelDim": extraParams.channelDim}
    return {}
104
def GetOperandStruct(operand):
    """Build the dict mirroring test_helper::TestOperand for one operand.

    The operand's name is attached under COMMENT_KEY so ToCpp renders it as
    a comment. ``data`` is a C++ expression string that builds a TestBuffer
    from the operand's values via TestBuffer::createFromVector.
    """
    # ToCpp emits designated initializers in insertion order, and C++ requires
    # them to follow the struct's member declaration order -- keep these
    # assignments in sync with test_helper::TestOperand.
    operand_type = operand.type
    fields = {COMMENT_KEY: operand.name}
    fields["type"] = "TestOperandType::" + operand_type.type
    fields["dimensions"] = operand_type.dimensions
    fields["numberOfConsumers"] = len(operand.outs)
    fields["scale"] = operand_type.scale
    fields["zeroPoint"] = operand_type.zeroPoint
    fields["lifetime"] = "TestOperandLifeTime::" + operand.lifetime
    fields["channelQuant"] = GetSymmPerChannelQuantParams(operand_type.extraParams)
    fields["isIgnored"] = isinstance(operand, tg.IgnoredOutput)
    fields["data"] = "TestBuffer::createFromVector<{cpp_type}>({data})".format(
        cpp_type=operand_type.GetCppTypeString(),
        data=operand.GetListInitialization(),
    )
    return fields
122
def GetOperationStruct(operation):
    """Build the dict mirroring test_helper::TestOperation for one operation.

    Inputs and outputs are expressed as the ``model_index`` of each operand.
    """
    def model_indexes(operands):
        return [operand.model_index for operand in operands]

    op_type = "TestOperationType::" + operation.optype
    return {
        "type": op_type,
        "inputs": model_indexes(operation.ins),
        "outputs": model_indexes(operation.outs),
    }
130
def GetSubgraphStruct(subgraph):
    """Build the dict mirroring test_helper::TestSubgraph for one subgraph.

    Operands and operations are expanded with GetOperandStruct and
    GetOperationStruct; subgraph inputs/outputs are listed by operand
    ``model_index``. The subgraph name is attached under COMMENT_KEY.
    """
    label = subgraph.name
    operand_structs = list(map(GetOperandStruct, subgraph.operands))
    operation_structs = list(map(GetOperationStruct, subgraph.operations))
    input_indexes = [operand.model_index for operand in subgraph.GetInputs()]
    output_indexes = [operand.model_index for operand in subgraph.GetOutputs()]
    return {
        COMMENT_KEY: label,
        "operands": operand_structs,
        "operations": operation_structs,
        "inputIndexes": input_indexes,
        "outputIndexes": output_indexes,
    }
140
def GetModelStruct(example):
    """Build the dict mirroring test_helper::TestModel for one example.

    ``main`` is the example's own model and ``referenced`` lists the
    subgraphs returned by its GetReferencedModels(). ``minSupportedVersion``
    falls back to TestHalVersion::UNKNOWN when the model has no version.
    """
    model = example.model
    main_graph = GetSubgraphStruct(model)
    referenced_graphs = [GetSubgraphStruct(sub)
                         for sub in model.GetReferencedModels()]
    is_relaxed = model.isRelaxed
    tolerance = example.expectedMultinomialDistributionTolerance
    expect_failure = example.expectFailure
    hal_version = model.version
    if hal_version is None:
        hal_version = "UNKNOWN"
    return {
        "main": main_graph,
        "referenced": referenced_graphs,
        "isRelaxed": is_relaxed,
        "expectedMultinomialDistributionTolerance": tolerance,
        "expectFailure": expect_failure,
        "minSupportedVersion": "TestHalVersion::%s" % hal_version,
    }
153
def DumpExample(example, example_fd):
    """Write the generated C++ definition for one example to ``example_fd``.

    Inside namespace ``generated_tests::<spec name>``, this emits a
    ``get_<example name>()`` accessor that returns a function-local static
    TestModel, plus a namespace-scope const that passes the model to
    ``TestModelManager::get().add()`` under the example's test name.
    The example's model must already be compiled (asserted).
    """
    assert example.model.compiled
    spec_name = tg.FileNames.specName
    test_name = str(example.testName)
    example_name = str(example.examplesName)
    # indent=4 because the aggregate is emitted inside the accessor body,
    # whose statements are indented by four spaces in the template below.
    aggregate_init = ToCpp(GetModelStruct(example), indent=4)
    cpp_template = """\
namespace generated_tests::{spec_name} {{

const TestModel& get_{example_name}() {{
    static TestModel model = {aggregate_init};
    return model;
}}

const auto dummy_{example_name} = TestModelManager::get().add("{test_name}", get_{example_name}());

}}  // namespace generated_tests::{spec_name}
"""
    rendered = cpp_template.format(
        spec_name=spec_name,
        test_name=test_name,
        example_name=example_name,
        aggregate_init=aggregate_init,
    )
    print(rendered, file=example_fd)
174
175
if __name__ == '__main__':
    # Parse arguments first: InitializeFiles and DumpExample read
    # tg.FileNames, which ParseCmdLine populates.
    ParseCmdLine()
    tg.Run(DumpExample=DumpExample, InitializeFiles=InitializeFiles)
179