• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3
2
3# Copyright 2017, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16"""VTS testcase generator
17
Implements the VTS test backend. Shares most logic with the CTS test
generator. Invoked by ml/nn/runtime/test/specs/generate_vts_tests.sh;
see that script for details on how this script is used.
21
22"""
23
24from __future__ import absolute_import
25from __future__ import division
26from __future__ import print_function
27import argparse
28from functools import reduce
29import math
30import os
31import struct
32import sys
33import contextlib
34import test_generator
35import pprint
36# Stuff from test generator
37from test_generator import Configuration
38from test_generator import Example
39from test_generator import Float32Scalar
40from test_generator import IgnoredOutput
41from test_generator import Input
42from test_generator import Int32Scalar
43from test_generator import Internal
44from test_generator import Model
45from test_generator import Operand
46from test_generator import Output
47from test_generator import Parameter
48from test_generator import smart_open
49
50# Take a model from command line
def import_source():
  """Parse the command line and execute the spec file, if it exists.

  Command-line arguments:
    spec: path of the spec file.
    -m/--model: the output model file (default "-").
    -e/--example: the output example file (default "-").

  When the spec path exists, its base name is stored in
  test_generator.FileNames.SpecFile and its contents are executed with
  exec(). A missing spec is reported on stderr and otherwise ignored.

  Returns:
    A (model, example) tuple holding the two output arguments.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("spec", help="the spec file")
  parser.add_argument(
      "-m", "--model", help="the output model file", default="-")
  parser.add_argument(
      "-e", "--example", help="the output example file", default="-")
  args = parser.parse_args()

  if os.path.exists(args.spec):
    test_generator.FileNames.SpecFile = os.path.basename(args.spec)
    # Read inside a with-block so the file handle is closed right away
    # instead of being left for the garbage collector.
    with open(args.spec) as spec_file:
      spec_source = spec_file.read()
    # NOTE: exec() runs arbitrary code taken from the spec file; only feed
    # this script trusted specs. Compiling with the spec path makes
    # tracebacks point at the spec file rather than "<string>".
    exec(compile(spec_source, args.spec, "exec"))
  else:
    # Keep going (nothing is executed), but do not hide the problem.
    print("Warning: spec file not found: %s" % args.spec, file=sys.stderr)

  return (args.model, args.example)
65
66# Generate operands in VTS format
def generate_vts_operands():
  """Build the C++ definition of the `operands` vector.

  One initializer is rendered per object in Operand.operands. Weights
  (operands whose is_weight() is true) get their size as length and the
  running byte offset; all other operands get offset 0 and length 0.

  Returns:
    The text of a `const std::vector<Operand> operands = {...};` statement.
  """
  operand_template = """\
        {{
            .type = OperandType::{operand_type},
            .dimensions = {shape},
            .numberOfConsumers = {no_consumers},
            .scale = {scale},
            .zeroPoint = {zero_point},
            .lifetime = OperandLifeTime::{lifetime},
            .location = {{.poolIndex = 0, .offset = {offset}, .length = {length}}},
        }}"""
  rendered = []
  next_offset = 0
  for operand in Operand.operands.objects():
    operand_ty = operand.type
    if operand.traversable():
      consumer_count = len(operand.outs)
    else:
      consumer_count = 0
    life = operand.lifetime()
    if operand.is_weight():
      byte_length = operand_ty.get_size()
    else:
      byte_length = 0
    dims, raw_scale, raw_zero_point = operand_ty.get_parsed_shape()
    scale_value = float(raw_scale)
    zero_point_value = int(raw_zero_point)
    rendered.append(operand_template.format(
        operand_type=operand_ty.get_element_type(),
        shape="{%s}" % dims,
        no_consumers=consumer_count,
        scale=test_generator.pretty_print_as_float(scale_value),
        zero_point=str(zero_point_value),
        lifetime=life,
        offset=next_offset if operand.is_weight() else 0,
        length=byte_length))
    # Only weights have a non-zero length, so only they advance the offset.
    next_offset += byte_length

  vector_template = """\
    const std::vector<Operand> operands = {{
{0}
    }};"""
  return vector_template.format(",\n".join(rendered))
107
108# Generate VTS operand values
def generate_vts_operand_values():
  """Render the values of all weight operands as a C++ byte list.

  Every object in Operand.operands whose is_weight() is true contributes
  bytes, in iteration order: TENSOR_QUANT8_ASYMM initializers are appended
  as-is, while FLOAT32 / INT32 values (scalar or tensor) are packed one by
  one with struct using the native "f" / "i" formats.

  Returns:
    "{}" when there are no bytes, otherwise a brace-enclosed,
    comma-separated list of byte values with surrounding newlines and
    indentation.

  Raises:
    AssertionError: if a weight has an element type that is not handled.
  """
  pack_formats = {
      "TENSOR_FLOAT32": "f",
      "FLOAT32": "f",
      "TENSOR_INT32": "i",
      "INT32": "i",
  }
  weights = [o for o in Operand.operands.objects() if o.is_weight()]
  binit = []
  for w in weights:
    ty = w.type.get_element_type()
    if ty == "TENSOR_QUANT8_ASYMM":
      binit.extend(w.initializer)
    elif ty in pack_formats:
      fmt = pack_formats[ty]
      for value in w.initializer:
        # bytearray() yields integer byte values when iterated.
        binit.extend(bytearray(struct.pack(fmt, value)))
    else:
      # Raise explicitly: the former `assert 0 and "..."` never showed its
      # message and was stripped entirely when running with -O.
      raise AssertionError("Unsupported VTS operand type: %s" % ty)

  init_defs = ", ".join(str(x) for x in binit)
  if init_defs:
    init_defs = "\n      %s\n    " % init_defs
  return "{%s}" % init_defs
128
129# Generate VTS operations
class VTSOps(object):
  """Collects the C++ initializers of the model's operations.

  vts_ops is class-level state shared by all callers: each formatted
  operation is appended to it, and nothing in this class clears it.
  """
  vts_ops = []

  @staticmethod
  def generate_vts_operation(op):
    """Format one operation and append it to VTSOps.vts_ops.

    Declared as a staticmethod because it takes no instance and is called
    through the class (VTSOps.generate_vts_operation(x)).

    Args:
      op: the visited object. Objects without an `optype` attribute (not
        an op, but things like weights) are skipped.

    Returns:
      True when an operation was appended, None when `op` was skipped.
    """
    try:
      op_code = op.optype
    except AttributeError:  # not an op, but things like weights
      return None
    op_fmt = """\
        {{
            .type = OperationType::{op_code},
            .inputs = {{{ins}}},
            .outputs = {{{outs}}},
        }}"""
    VTSOps.vts_ops.append(op_fmt.format(
        op_code=op_code,
        ins=", ".join(str(x.ID()) for x in op.ins),
        outs=", ".join(str(x.ID()) for x in op.outs)))
    return True
151
def generate_vts_operations(model_file):
  """Return the C++ initializers of all operations, joined by ",\\n".

  test_generator.TopologicalSort is handed a callback that forwards each
  visited object to VTSOps.generate_vts_operation; the text accumulated in
  VTSOps.vts_ops is then joined.

  Args:
    model_file: not used by this function.
  """
  def visit(node):
    return VTSOps.generate_vts_operation(node)

  test_generator.TopologicalSort(visit)
  separator = ",\n"
  return separator.join(VTSOps.vts_ops)
155
156
def generate_vts_model(model_file):
  """Print the C++ `createTestModel()` function to model_file.

  The weights are emitted either as a `data` array copied into shared
  memory (when Configuration.useSHM() is true) or directly as the
  `operandValues` vector.

  Args:
    model_file: file object the generated code is printed to.
  """
  if Configuration.useSHM():
    # Boilerplate code for passing weights in shared memory
    values_template = """\
    std::vector<uint8_t> operandValues = {{}};
    const uint8_t data[] = {operand_values};

    // Allocate segment of android shared memory, wrapped in hidl_memory.
    // This object will be automatically freed when sharedMemory is destroyed.
    hidl_memory sharedMemory = allocateSharedMemory(sizeof(data));

    // Mmap ashmem into usable address and hold it within the mappedMemory object.
    // MappedMemory will automatically munmap the memory when it is destroyed.
    sp<IMemory> mappedMemory = mapMemory(sharedMemory);

    if (mappedMemory != nullptr) {{
        // Retrieve the mmapped pointer.
        uint8_t* mappedPointer =
                static_cast<uint8_t*>(static_cast<void*>(mappedMemory->getPointer()));

        if (mappedPointer != nullptr) {{
            // Acquire the write lock for the shared memory segment, upload the data,
            // and release the lock.
            mappedMemory->update();
            std::copy(data, data + sizeof(data), mappedPointer);
            mappedMemory->commit();
        }}
    }}

    const std::vector<hidl_memory> pools = {{sharedMemory}};
"""
  else:
    # Passing weights via operandValues
    values_template = """\
    std::vector<uint8_t> operandValues = {operand_values};
    const std::vector<hidl_memory> pools = {{}};
"""

  values_section = values_template.format(
      operand_values=generate_vts_operand_values())

  model_template = """\
// Generated code. Do not edit
// Create the model
Model createTestModel() {{
{operand_decls}

    const std::vector<Operation> operations = {{
{operations}
    }};

    const std::vector<uint32_t> inputIndexes = {{{input_indices}}};
    const std::vector<uint32_t> outputIndexes = {{{output_indices}}};
{operand_values}
    return {{
        .operands = operands,
        .operations = operations,
        .inputIndexes = inputIndexes,
        .outputIndexes = outputIndexes,
        .operandValues = operandValues,
        .pools = pools,{relaxed_field}
    }};
}}
"""
  # The pieces are computed in this fixed order: operations first, then
  # operand declarations, then the output and input index lists.
  operations_text = generate_vts_operations(sys.stdout)
  operand_decls_text = generate_vts_operands()
  output_ids = ", ".join([str(out.ID()) for out in Output.get_outputs()])
  input_ids = ", ".join([str(inp.ID()) for inp in Input.get_inputs(True)])
  if Model.isRelaxed():
    relaxed_text = "\n        .relaxComputationFloat32toFloat16 = true,"
  else:
    relaxed_text = ""

  generated = model_template.format(
      operations=operations_text,
      operand_decls=operand_decls_text,
      operand_values=values_section,
      output_indices=output_ids,
      input_indices=input_ids,
      relaxed_field=relaxed_text)
  print(generated, file=model_file)
234
def generate_vts(model_file):
  """Write the generated model code to model_file.

  Prints the output of generate_vts_model() first, followed by the text
  returned by IgnoredOutput.gen_ignored().

  Args:
    model_file: file object that receives the generated code.
  """
  generate_vts_model(model_file)
  ignored_text = IgnoredOutput.gen_ignored()
  print(ignored_text, file=model_file)
238
if __name__ == "__main__":
  # Parse the command line (and run the spec file) to learn where the
  # model and the example should be written.
  model, example = import_source()

  # Report both destinations on stderr.
  print("Output VTS model: %s" % model, file=sys.stderr)
  print("Output example:" + example, file=sys.stderr)

  # The model file is written and closed before the example file is opened.
  with smart_open(model) as model_file:
    generate_vts(model_file)

  with smart_open(example) as example_file:
    Example.dump(example_file)
248