• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3
2
3# Copyright 2017, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16"""VTS testcase generator
17
18Implements VTS test backend. Shares most logic with the CTS test
19generator. Invoked by ml/nn/runtime/test/specs/generate_vts_tests.sh;
20See that script for details on how this script is used.
21
22"""
23
24from __future__ import absolute_import
25from __future__ import division
26from __future__ import print_function
27import argparse
28from functools import reduce
29import math
30import numpy as np
31import os
32import re
33import struct
34import sys
35import contextlib
36import pprint
37
38# Stuff from test generator
39import test_generator as tg
40from test_generator import ActivationConverter
41from test_generator import BoolScalar
42from test_generator import Configuration
43from test_generator import DataTypeConverter
44from test_generator import DataLayoutConverter
45from test_generator import Example
46from test_generator import Float16Scalar
47from test_generator import Float32Scalar
48from test_generator import Float32Vector
49from test_generator import IgnoredOutput
50from test_generator import Input
51from test_generator import Int32Scalar
52from test_generator import Int32Vector
53from test_generator import Internal
54from test_generator import Model
55from test_generator import Operand
56from test_generator import Output
57from test_generator import Parameter
58from test_generator import ParameterAsInputConverter
59from test_generator import RelaxedModeConverter
60from test_generator import SmartOpen
61from test_generator import SymmPerChannelQuantParams
62
63# Dumping methods that shared with CTS generator
64from cts_generator import DumpCtsExample
65from cts_generator import DumpCtsIsIgnored
66
67# Take a model from command line
def ParseCmdLine():
    """Parse the command line and hand the file names to tg.FileNames.

    Takes one positional argument, the spec file. The optional
    -m/--model, -e/--example and -t/--test flags name the output model,
    example and test files; each of them defaults to "-". The parsed values
    are passed to tg.FileNames.InitializeFileLists in the order
    (spec, model, example, test).
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("spec", help="the spec file")
    # All three output flags share the same shape and the "-" default.
    output_flags = (
        ("-m", "--model", "the output model file"),
        ("-e", "--example", "the output example file"),
        ("-t", "--test", "the output test file"),
    )
    for short_flag, long_flag, description in output_flags:
        arg_parser.add_argument(
            short_flag, long_flag, help=description, default="-")
    parsed = arg_parser.parse_args()
    tg.FileNames.InitializeFileLists(
        parsed.spec, parsed.model, parsed.example, parsed.test)
80
81# Generate operands in VTS format
def generate_vts_operands(model):
    """Return C++ code that declares the `operands` vector for `model`.

    One `Operand` initializer is emitted per entry of model.operands, in
    order. A Parameter operand gets a location whose offset is the running
    sum of GetByteSize() over the Parameters before it, and whose length is
    its own GetByteSize(). Every other operand gets offset 0 and length 0.

    When an operand's type has extraParams that are not hidden, a local
    `Operand::ExtraParams extraParams<index>` is declared and initialized
    ahead of the vector. The operand's initializer then moves it into
    `.extraParams`.

    Args:
        model: the model whose operands are dumped.

    Returns:
        The C++ statements, ending with the `operands` vector definition.
    """
    op_def = """\
        {{
            .type = OperandType::{operand_type},
            .dimensions = {shape},
            .numberOfConsumers = {no_consumers},
            .scale = {scale},
            .zeroPoint = {zero_point},
            .lifetime = OperandLifeTime::{lifetime},
            .location = {{.poolIndex = 0, .offset = {offset}, .length = {length}}},{extraParams}
        }}"""
    # Two complete, newline-terminated C++ statements. This template does not
    # depend on the operand, so it is built once outside the loop.
    extra_params_def = """\
    Operand::ExtraParams extraParams{index};
    extraParams{index}.{setMethodName}({param});
"""
    offset = 0
    op_definitions = []
    extra_params_definitions = []
    for index, o in enumerate(model.operands):
        is_parameter = isinstance(o, Parameter)
        length = o.type.GetByteSize() if is_parameter else 0
        add_extra_params = (o.type.extraParams is not None
                            and not o.type.extraParams.hide)
        extra_params_field = (
            "\n            .extraParams = std::move(extraParams%d)," % (index,)
            if add_extra_params else "")
        op_definitions.append(op_def.format(
            operand_type=o.type.type,
            shape=o.type.GetDimensionsString(),
            no_consumers=len(o.outs),
            scale=tg.PrettyPrintAsFloat(o.type.scale),
            zero_point=str(int(o.type.zeroPoint)),
            lifetime=o.lifetime,
            offset=offset if is_parameter else 0,
            length=length,
            extraParams=extra_params_field))
        offset += length

        if add_extra_params:
            ep = o.type.extraParams
            extra_params_definitions.append(extra_params_def.format(
                index=index,
                setMethodName=ep.GetVtsSetter(),
                param=ep.GetVtsConstructor()))

    # The ExtraParams declarations are statements, so they are concatenated.
    # Joining them with ",\n" (as the operand initializers are joined) would
    # leave a stray comma between statements whenever two or more operands
    # carry extraParams, producing invalid C++.
    return """{0}\
    const std::vector<Operand> operands = {{
{1}
    }};""".format("".join(extra_params_definitions), ",\n".join(op_definitions))
133
134# Generate VTS operand values
def generate_vts_operand_values(operands):
    """Return the values of all Parameter operands as a C++ byte-list literal.

    Parameters are visited in the given order, and their values are
    flattened into unsigned byte values:
      * TENSOR_QUANT8_ASYMM values are emitted unchanged;
      * TENSOR_QUANT8_SYMM and TENSOR_QUANT8_SYMM_PER_CHANNEL int8 values
        are reinterpreted as unsigned bytes;
      * BOOL and TENSOR_BOOL8 values become 1 or 0;
      * float16 values are packed with NumPy, and FLOAT32/INT32/QUANT16_ASYMM
        values with struct, both in the host's native byte order.

    Args:
        operands: the model's operands; entries that are not Parameters are
            skipped.

    Returns:
        "{}" when there are no bytes. Otherwise, the comma-separated byte
        values on their own indented line, wrapped in braces.

    Raises:
        ValueError: if a Parameter has a type this generator cannot serialize.
    """
    # struct format code for each fixed-width type packed with struct.
    struct_formats = {
        "TENSOR_FLOAT32": "f",
        "FLOAT32": "f",
        "TENSOR_INT32": "i",
        "INT32": "i",
        "TENSOR_QUANT16_ASYMM": "H",
    }
    binit = []
    for w in operands:
        if not isinstance(w, Parameter):
            continue
        ty = w.type.type
        if ty == "TENSOR_QUANT8_ASYMM":
            binit += w.value
        elif ty in ("TENSOR_QUANT8_SYMM_PER_CHANNEL", "TENSOR_QUANT8_SYMM"):
            # Packing as a signed char and reading the byte back yields the
            # two's-complement value as an unsigned byte (e.g. -1 -> 255).
            binit += [struct.pack("b", value)[0] for value in w.value]
        elif ty in ("BOOL", "TENSOR_BOOL8"):
            binit += [1 if x else 0 for x in w.value]
        elif ty in ("TENSOR_FLOAT16", "FLOAT16"):
            for f in w.value:
                # struct has no float16 format before Python 3.6, so NumPy
                # does the packing. tobytes() replaces tostring(), which is
                # deprecated and was removed in NumPy 2.0.
                binit.extend(bytearray(np.float16(f).tobytes()))
        elif ty in struct_formats:
            fmt = struct_formats[ty]
            for f in w.value:
                binit.extend(bytearray(struct.pack(fmt, f)))
        else:
            # Raise explicitly: an assert is stripped under -O (which would
            # silently drop this operand's bytes) and could not report the type.
            raise ValueError("Unsupported VTS operand type: %s" % ty)

    init_defs = ", ".join(str(x) for x in binit)
    if init_defs:
        init_defs = "\n      %s\n    " % init_defs
    return "{%s}" % init_defs
167
168# Generate VTS operations
def generate_vts_operation(op, model):
    """Render one operation as a C++ `Operation` aggregate initializer.

    Args:
        op: the operation to render; its `optype`, `ins` and `outs` are read.
        model: supplies the operand indices through GetIndexOfOperands().

    Returns:
        The brace-enclosed initializer text, with no trailing comma or
        newline.
    """
    input_list = tg.GetJointStr(model.GetIndexOfOperands(op.ins))
    output_list = tg.GetJointStr(model.GetIndexOfOperands(op.outs))
    return """\
        {{
            .type = OperationType::{op_code},
            .inputs = {{{ins}}},
            .outputs = {{{outs}}},
        }}""".format(op_code=op.optype, ins=input_list, outs=output_list)
182
def generate_vts_operations(model):
    """Render every operation of `model`, in order.

    Each operation is rendered by generate_vts_operation. The results are
    separated by a comma followed by a newline; a model with no operations
    yields "".
    """
    rendered = (generate_vts_operation(op, model) for op in model.operations)
    return ",\n".join(rendered)
186
def generate_vts_model(model, model_file):
    """Print the C++ function that builds `model` to `model_file`.

    The emitted function is named model.createTestFunctionName. It returns a
    `Model` assembled from:
      * the operand declarations;
      * the operations;
      * the input/output index lists;
      * the constant operand data.

    How the constant data is delivered depends on Configuration.useSHM():
      * True:  the bytes are copied (when mapping succeeds) into a newly
               allocated hidl_memory, which becomes the only entry of
               `pools`, and `operandValues` is left empty;
      * False: the bytes initialize `operandValues` directly and `pools` is
               empty.

    `.relaxComputationFloat32toFloat16 = true` is added when model.isRelaxed.
    """
    if Configuration.useSHM():
        # Constant data is handed over in a single shared-memory pool.
        values_template = """\
    std::vector<uint8_t> operandValues = {{}};
    const uint8_t data[] = {operand_values};

    // Allocate segment of android shared memory, wrapped in hidl_memory.
    // This object will be automatically freed when sharedMemory is destroyed.
    hidl_memory sharedMemory = allocateSharedMemory(sizeof(data));

    // Mmap ashmem into usable address and hold it within the mappedMemory object.
    // MappedMemory will automatically munmap the memory when it is destroyed.
    sp<IMemory> mappedMemory = mapMemory(sharedMemory);

    if (mappedMemory != nullptr) {{
        // Retrieve the mmapped pointer.
        uint8_t* mappedPointer =
                static_cast<uint8_t*>(static_cast<void*>(mappedMemory->getPointer()));

        if (mappedPointer != nullptr) {{
            // Acquire the write lock for the shared memory segment, upload the data,
            // and release the lock.
            mappedMemory->update();
            std::copy(data, data + sizeof(data), mappedPointer);
            mappedMemory->commit();
        }}
    }}

    const std::vector<hidl_memory> pools = {{sharedMemory}};
"""
    else:
        # Constant data is embedded in operandValues; no memory pools.
        values_template = """\
    std::vector<uint8_t> operandValues = {operand_values};
    const std::vector<hidl_memory> pools = {{}};
"""
    operand_values_code = values_template.format(
        operand_values=generate_vts_operand_values(model.operands))

    model_template = """\
// Create the model
Model {create_test_model_name}() {{
{operand_decls}

    const std::vector<Operation> operations = {{
{operations}
    }};

    const std::vector<uint32_t> inputIndexes = {{{input_indices}}};
    const std::vector<uint32_t> outputIndexes = {{{output_indices}}};
{operand_values}
    return {{
        .operands = operands,
        .operations = operations,
        .inputIndexes = inputIndexes,
        .outputIndexes = outputIndexes,
        .operandValues = operandValues,
        .pools = pools,{relaxed_field}
    }};
}}
"""
    print(model_template.format(
        create_test_model_name=str(model.createTestFunctionName),
        operations=generate_vts_operations(model),
        operand_decls=generate_vts_operands(model),
        operand_values=operand_values_code,
        output_indices=tg.GetJointStr(model.GetOutputsIndex()),
        input_indices=tg.GetJointStr(model.GetInputsIndex()),
        relaxed_field=("\n        .relaxComputationFloat32toFloat16 = true,"
                       if model.isRelaxed else "")),
          file=model_file)
264
def generate_vts(model, model_file):
    """Dump a compiled model to `model_file`.

    First the model constructor is written by generate_vts_model, then the
    output of DumpCtsIsIgnored, which is shared with the CTS generator.
    Asserts model.compiled before anything is written.
    """
    assert model.compiled
    for dump in (generate_vts_model, DumpCtsIsIgnored):
        dump(model, model_file)
269
def generate_vts_test(example, test_file):
    """Print the gtest cases for one example to `test_file`.

    Two tests named example.testName are emitted, referring to functions in
    the spec's namespace (tg.FileNames.specName):
      * an execution test in the NeuralnetworksHidlTest fixture. When the
        example's model has a dynamic output shape, it uses the
        DynamicOutputShapeTest fixture instead and passes an extra `true`
        argument to Execute;
      * a ValidationTest that calls validateEverything() on the model and on
        the requests created from the example's data.
    Dynamic-output-shape tests are wrapped in
    `#ifdef NN_TEST_DYNAMIC_OUTPUT_SHAPE` ... `#endif`.

    Args:
        example: the Example being dumped.
        test_file: writable text file the tests are printed to.
    """
    testTemplate = """\
TEST_F({test_case_name}, {test_name}) {{
  generated_tests::Execute(device,
                           {namespace}::{create_model_name},
                           {namespace}::{is_ignored_name},
                           {namespace}::get_{examples_name}(){test_dynamic_output_shape});\n}}

TEST_F(ValidationTest, {test_name}) {{
  const Model model = {namespace}::{create_model_name}();
  const std::vector<Request> requests = createRequests({namespace}::get_{examples_name}());
  validateEverything(model, requests);
}}\n
"""
    dynamic_output = example.model.hasDynamicOutputShape
    # Write to the file we were given. A module-level `test_fd` only exists
    # when this file runs as a script, so relying on it breaks any other caller.
    if dynamic_output:
        print("#ifdef NN_TEST_DYNAMIC_OUTPUT_SHAPE", file=test_file)
    print(testTemplate.format(
        test_case_name=("DynamicOutputShapeTest" if dynamic_output
                        else "NeuralnetworksHidlTest"),
        test_name=str(example.testName),
        namespace=tg.FileNames.specName,
        create_model_name=str(example.model.createTestFunctionName),
        is_ignored_name=str(example.model.isIgnoredFunctionName),
        examples_name=str(example.examplesName),
        test_dynamic_output_shape=", true" if dynamic_output else ""
    ), file=test_file)
    if dynamic_output:
        print("#endif", file=test_file)
298
def InitializeFiles(model_fd, example_fd, test_fd):
    """Write the leading header text to the model, example and test outputs.

    The model and example files each get the same "generated file" banner
    naming the spec's basename. The test file gets a namespace named after
    tg.FileNames.specName that #includes the example and model files. Their
    paths are shortened by removing everything up to and including
    "frameworks/ml/nn/" plus an optional "runtime/test/generated/".
    """
    spec_base = os.path.basename(tg.FileNames.specFile)
    banner = ("// clang-format off\n"
              "// Generated file (from: {spec_file}). Do not edit").format(
                  spec_file=spec_base)
    for header_target in (model_fd, example_fd):
        print(banner, file=header_target)

    def include_path(path):
        # Strip the checkout prefix so the #include path is relative.
        return re.sub(r".*frameworks/ml/nn/(runtime/test/generated/)?", "", path)

    test_header = """\
// Generated from: {spec_file}.
namespace {spec_name} {{
// Generated {spec_name} test
#include "{example_file}"
// Generated model constructor
#include "{model_file}"
}} // namespace {spec_name}\n""".format(
        spec_file=spec_base,
        model_file=include_path(tg.FileNames.modelFile),
        example_file=include_path(tg.FileNames.exampleFile),
        spec_name=tg.FileNames.specName)
    print(test_header, file=test_fd)
319
if __name__ == "__main__":
    ParseCmdLine()
    while tg.FileNames.NextFile():
        print("Generating test(s) from spec: %s" % tg.FileNames.specFile, file=sys.stderr)
        # Read the spec through a context manager so the file handle is closed
        # promptly, and compile it under its own path so tracebacks name the
        # spec file. The spec is trusted build input. exec() at module level
        # runs it in this module's namespace, presumably so it can use the
        # test_generator names imported at the top of the file.
        with open(tg.FileNames.specFile, "r") as spec_fd:
            spec_source = spec_fd.read()
        exec(compile(spec_source, tg.FileNames.specFile, "exec"))
        print("Output VTS model: %s" % tg.FileNames.modelFile, file=sys.stderr)
        print("Output example:" + tg.FileNames.exampleFile, file=sys.stderr)
        with SmartOpen(tg.FileNames.modelFile) as model_fd, \
             SmartOpen(tg.FileNames.exampleFile) as example_fd, \
             SmartOpen(tg.FileNames.testFile, mode="a") as test_fd:
            InitializeFiles(model_fd, example_fd, test_fd)
            Example.DumpAllExamples(
                DumpModel=generate_vts, model_fd=model_fd,
                DumpExample=DumpCtsExample, example_fd=example_fd,
                DumpTest=generate_vts_test, test_fd=test_fd)
335