• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3
2
3# Copyright 2017, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""NN model compiler
18
Contains class definitions and utility functions for compiling models and
examples into NDK-based CTS and VTS unit tests.
21
22Used by example_generator.py and spec_visualizer.py
23"""
24
25from __future__ import absolute_import
26from __future__ import division
27from __future__ import print_function
28import copy
29from functools import reduce
30import argparse
31import io
32import itertools
33import os
34import re
35import sys
36import traceback
37import numpy as np
38
def GetJointStr(l, sep=", ", method=str):
    """Join the items of `l` into a single string.

    Each item is converted with `method` (str by default) and the converted
    pieces are concatenated with `sep` between them.
    """
    return sep.join(map(method, l))
41
42# Print in C float literal format
def PrettyPrintAsFloat(x):
    """Return `x` formatted as a C float literal string, e.g. 1 -> "1.0f"."""
    text = str(float(x))
    # A literal that already has a decimal point or an exponent only needs
    # the "f" suffix; anything else also gets ".0" appended first.
    hasPointOrExponent = "." in text or "e" in text
    suffix = "f" if hasPointOrExponent else ".0f"
    return text + suffix
49
50# Transform from original type to float32
def Dequantize(v, ty):
    """Convert values stored as type `ty` back to real values.

    Subtracts the zero point, multiplies by the scale (skipped when the scale
    is 0) and, for per-channel quantized types, by the per-channel scales.

    Every step uses an in-place operator, so a numpy array passed as `v` is
    modified by this call. The resulting value is returned.
    """
    v -= ty.zeroPoint
    if ty.scale != 0:
        v *= ty.scale
    channelParams = ty.extraParams
    if isinstance(channelParams, SymmPerChannelQuantParams):
        v *= channelParams.GetScalesBroadcastArray(ty.dimensions)
    return v
58
59# Transform float32 to target data type
def Quantize(v, ty):
    """Convert real values `v` to the representation of type `ty`.

    Divides by the scale (skipped when the scale is 0) and by the per-channel
    scales when present, adds the zero point, rounds to integers for
    non-float types, and finally saturates to the value range of the few
    types listed in the bounds table below.

    The scale division and the zero-point addition use in-place operators,
    so a numpy array passed as `v` may be modified by this call.
    """
    if ty.scale != 0:
        v /= ty.scale
    channelParams = ty.extraParams
    if isinstance(channelParams, SymmPerChannelQuantParams):
        v = v / channelParams.GetScalesBroadcastArray(ty.dimensions)
    v += ty.zeroPoint
    if not ty.IsFloat():
        v = np.round(v).astype(int)

    # (lower, upper) saturation bounds per type; None means "no bound".
    # Types that are not listed are returned without clamping.
    bounds = {
        "TENSOR_QUANT8_ASYMM": (0, 255),
        "TENSOR_QUANT16_ASYMM": (0, 65535),
        "TENSOR_QUANT8_SYMM_PER_CHANNEL": (-127, 127),
        "UINT32": (0, None),
        "TENSOR_QUANT8_ASYMM_SIGNED": (-128, 127),
    }
    lower, upper = bounds.get(ty.type, (None, None))
    if lower is not None:
        v = np.maximum(v, lower)
    if upper is not None:
        v = np.minimum(v, upper)
    return v
81
82# Tracking objects inside a model with a unique name
class NamedObject:
    """Base class for objects identified by a unique name.

    `existingNames` records the names already handed out. It is looked up
    through the instance's class, so a subclass that defines its own
    `existingNames` set gets a separate namespace.
    """
    existingNames = set()

    def __init__(self, *args, sep="_", showZero=False, startsFrom=0, skipRenaming=False):
        """Build the name from `args` and make it unique.

        The non-empty, non-None `args` are joined with `sep`. Unless
        `skipRenaming` is set, a numeric suffix (counting up from
        `startsFrom`) is appended until the name is unused; with `showZero`
        the suffix is present even for the first object of that name.
        """
        parts = [i for i in args if i is not None and i != ""]
        baseName = GetJointStr(parts, sep=sep)
        if skipRenaming:
            # Taken verbatim and NOT recorded in existingNames
            self.name = baseName
            return
        taken = self.__class__.existingNames
        if showZero is False:
            candidate = baseName
        else:
            candidate = baseName + sep + str(startsFrom)
        # Bump the suffix until the candidate is free
        while candidate in taken:
            startsFrom += 1
            candidate = baseName + sep + str(startsFrom)
        taken.add(candidate)
        self.name = candidate

    def __str__(self):
        return self.name
    __repr__ = __str__

    # Since names are unique, objects with the same name are considered equal
    def __eq__(self, other):
        if not isinstance(other, NamedObject):
            return False
        return self.name == other.name

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self.name)

    # Ordering by name gives a stable order when sorting named objects
    def __lt__(self, other):
        return self.name < other.name
115
116# Types, operands should all have a unique name since they share the same namespace
class NamedVariable(NamedObject):
    """Named object living in the variable namespace.

    Defines its own `existingNames` set, so variable names are tracked
    separately from the names registered directly on NamedObject.
    """
    existingNames = set()

    def __init__(self, *args, sep="_", showZero=False, startsFrom=0, skipRenaming=False):
        super().__init__(*args, sep=sep, showZero=showZero,
                         startsFrom=startsFrom, skipRenaming=skipRenaming)
122
123# Global variables in the spec namespace such as CreateModel, is_ignored, and examples
class GlobalVariable(NamedVariable):
    """Variable whose duplicate names get suffixes counting up from 1.

    It does not define its own `existingNames`, so it shares the name set of
    NamedVariable.
    """

    def __init__(self, *args, skipRenaming=False):
        super().__init__(*args, startsFrom=1, skipRenaming=skipRenaming)
127
128# Each test should have a unique name, but will not conflict with variables
class NamedTest(NamedObject):
    """Named object for tests, with a name set separate from variables."""
    existingNames = set()

    def __init__(self, *args, startsFrom=0, skipRenaming=False):
        # NOTE: the `startsFrom` argument is accepted but not forwarded;
        # duplicate suffixes always count up from 1.
        super().__init__(*args, startsFrom=1, skipRenaming=skipRenaming)
133
class Type(NamedVariable):
    """Operand type: data type name, dimensions and quantization parameters.

    Instances are normally obtained through the GetType() factory, which
    caches them in `typesMap` so that identical requests share one object.
    """
    # Cache of created types, keyed by a string built from the type signature
    typesMap = dict()
    # Maps operand type names to the C++ element type name
    typeLookup = {
        "INT32": "int32_t",
        "UINT32": "uint32_t",
        "FLOAT32": "float",
        "FLOAT16": "_Float16",
        "TENSOR_INT32": "int32_t",
        "TENSOR_FLOAT16": "_Float16",
        "TENSOR_FLOAT32": "float",
        "TENSOR_QUANT8_ASYMM": "uint8_t",
        "TENSOR_QUANT8_SYMM": "int8_t",
        "BOOL": "bool8",
        "TENSOR_QUANT16_ASYMM": "uint16_t",
        "TENSOR_QUANT16_SYMM": "int16_t",
        "TENSOR_BOOL8": "bool8",
        "TENSOR_QUANT8_SYMM_PER_CHANNEL": "int8_t",
        "TENSOR_QUANT8_ASYMM_SIGNED": "int8_t",
    #     "OEM_SCALAR": this is service-defined.
        "TENSOR_OEM_BYTE": "uint8_t",
        "SUBGRAPH": "uint32_t",  # Index into TestModel::referenced.
    }

    # types are named as "type0", "type1", ...
    def __init__(self, vt, dimensions, scale, zeroPoint, name="type", skipRenaming=False,
                 extraParams=None):
        NamedVariable.__init__(self, name, sep="", showZero=True, skipRenaming=skipRenaming)
        self.type = vt
        self.dimensions = dimensions
        self.scale = float(scale)
        self.zeroPoint = int(zeroPoint)
        self.extraParams = extraParams

    @staticmethod
    def GetType(vt, dimensions, scale=0, zeroPoint=0, extraParams=None):
        """Return the cached Type matching the arguments, creating it if needed.

        Two requests match when the string forms of all five arguments are
        equal.
        """
        assert isinstance(dimensions, (list, tuple)), \
            'dimensions must be a list or tuple, got {}'.format(type(dimensions))
        key = ",".join([vt, str(dimensions), str(scale), str(zeroPoint), str(extraParams)])
        if key not in Type.typesMap:
            Type.typesMap[key] = Type(vt, dimensions, scale, zeroPoint, extraParams=extraParams)
        return Type.typesMap[key]

    @staticmethod
    def GetAllTypes():
        """Return every cached type, sorted for a stable dump order."""
        return sorted(Type.typesMap.values())

    # For backward-compatibility
    @staticmethod
    def GetTypeFromString(vt, shape, extraParams=None):
        """Return the Type described by a legacy shape string (see GetParsedShape)."""
        dimensions, scale, zeroPoint = Type.GetParsedShape(shape)
        return Type.GetType(vt, dimensions, float(scale), int(zeroPoint), extraParams)

    # For backward-compatibility
    @staticmethod
    def GetParsedShape(shape):
        """Parse a legacy shape string into (dimensions, scale, zeroPoint).

        Accepted forms are "" / "{}" (no dimensions), "{d0, d1, ...}",
        "{...}, <scale>f" and "{...}, <scale>f, <zeroPoint>". Dimensions are
        returned as a list of ints; scale and zeroPoint are returned as
        strings ("0" when absent) with the "f" suffix removed from the scale.
        """
        if shape == "" or shape == "{}":
            return [], "0", "0"
        _, _, rest = shape.partition('{')
        dimsText, _, rest = rest.partition('}')
        dims = [int(x) for x in dimsText.split(",")]
        # rest now looks like ", 0.5f, 127", ", 0.5f" or ""
        head, _, tail = rest.rpartition(',')
        if head == "":
            if tail == "":
                return dims, "0", "0"
            # Only a scale was given. Strip the C float suffix here as well,
            # otherwise float() in GetTypeFromString rejects it.
            return dims, tail.replace("f", ""), "0"
        _, _, scale = head.partition(',')
        return dims, scale.replace("f", ""), tail

    def GetNumberOfElements(self):
        """Return the product of all dimensions (1 when there are none)."""
        return reduce(lambda x, y: x * y, self.dimensions, 1)

    def GetCppTypeString(self):
        """Return the C++ element type name; raises KeyError for unknown types."""
        return Type.typeLookup[self.type]

    def IsFloat(self):
        return self.GetCppTypeString() in ["float", "_Float16"]

    def IsBool(self):
        return self.GetCppTypeString() == "bool8"

    def IsScalar(self):
        return not self.type.startswith("TENSOR_")

    def GetElementByteSize(self):
        """Return the size in bytes of one element (1, 2 or 4)."""
        cppTypeString = self.GetCppTypeString()
        if cppTypeString in ["uint8_t", "int8_t", "bool8"]:
            return 1
        if cppTypeString in ["int16_t", "uint16_t", "_Float16"]:
            return 2
        return 4

    def GetByteSize(self):
        return self.GetElementByteSize() * self.GetNumberOfElements()

    def GetDimensionsString(self):
        """Return the dimensions as a brace-enclosed list, e.g. "{1, 2}"."""
        return "{" + GetJointStr(self.dimensions) + "}"

    def GetSignatureTuple(self):
        """Return (type, dimensions, scale, zeroPoint); extraParams is not included."""
        return (self.type, self.dimensions, self.scale, self.zeroPoint)

    def ToUnspecifiedDim(self):
        """Return the type with the same rank but every dimension set to 0."""
        # NOTE(review): extraParams is not forwarded to the new type - confirm
        # that this is intended for per-channel quantized types.
        return Type.GetType(self.type, [0] * len(self.dimensions), self.scale, self.zeroPoint)
245
246# To track implicitly convertible parameter types
class ImplicitParameter():
    """Mixin marking parameter classes that plain Python values convert to.

    Every direct subclass is expected to provide a static
    `IsCompatible(value)` and to be constructible as `cls(name, value)`.
    """

    @staticmethod
    def ImplicitConvertion(value):
        """Return `value` as an operand.

        Operand instances are returned unchanged. Any other value is wrapped
        in the first subclass (in definition order) whose IsCompatible()
        accepts it, under the name "param".

        Raises:
            AssertionError: if no subclass accepts `value`.
        """
        if isinstance(value, Operand):
            return value
        for implicitType in ImplicitParameter.__subclasses__():
            if implicitType.IsCompatible(value):
                return implicitType("param", value)
        # Raised explicitly so the check survives `python -O`. The value is
        # wrapped in a 1-tuple because `%` would otherwise unpack a tuple
        # value and fail with a TypeError instead of this message.
        raise AssertionError("%s not supported for implicit parameter" % (value,))
256
257
258# ExtraParams with per-channel quantization.
class SymmPerChannelQuantParams():
    """Extra operand parameters for symmetric per-channel quantization.

    Attributes:
        channelDim: index of the dimension the scales apply to.
        scales: sequence with one scale per channel.
        hide: flag stored as given; it is not used inside this class.
    """

    def __init__(self, channelDim, scales, hide = False):
        self.channelDim = channelDim
        self.scales = scales
        self.hide = hide

    @staticmethod
    def _ScaleLiteral(x):
        """Return `x` as a C float literal.

        The text of str(x) is kept as is; ".0" is inserted only when it has
        neither a decimal point nor an exponent, so an integer scale such as
        1 becomes "1.0f" instead of the invalid literal "1f".
        """
        text = str(x)
        if "." not in text and "e" not in text:
            text += ".0"
        return text + "f"

    def _ScalesInitializer(self):
        """Return the scales as a comma-separated list of C float literals."""
        return ", ".join(self._ScaleLiteral(x) for x in self.scales)

    def GetScalesBroadcastArray(self, dimensions):
        """Return the scales as an array broadcastable to `dimensions`.

        The result has one axis per entry of `dimensions`; every axis has
        length 1 except `channelDim`, which has length len(scales).
        """
        bshape = [1] * len(dimensions)
        bshape[self.channelDim] = len(self.scales)
        return np.array(self.scales).reshape(bshape)

    def GetConstructor(self):
        """Return the constructor-call expression text for these parameters."""
        return "SymmPerChannelQuantParams({%s},%d)" % (
            self._ScalesInitializer(), self.channelDim)

    def GetVtsSetter(self):
        return "channelQuant"

    def GetVtsConstructor(self):
        """Return the designated-initializer expression text for these parameters."""
        return "SymmPerChannelQuantParams{.scales={%s}, .channelDim=%d}" % (
            self._ScalesInitializer(), self.channelDim)
280
281
282# An operand that can be fed into operations. Also, an operand is always
283# declared before operations.
class Operand(NamedVariable):
    """A typed, named value that operations consume or produce."""

    def __init__(self, name, opType, value, backward=None, skipRenaming=False, extraParams=None):
        """Create an operand.

        `opType` is either a tuple that is unpacked into Type.GetType(), or
        (legacy form) a type-name string; in the legacy form `value` holds
        the shape string and the actual data is passed in `backward`.
        """
        super().__init__(name, sep="", skipRenaming=skipRenaming)
        if type(opType) is str:
            self.type = Type.GetTypeFromString(opType, value, extraParams)
            data = backward
        else:
            self.type = Type.GetType(*opType, extraParams=extraParams)
            data = value
        self.SetValue(data)
        self.lifetime = "TEMPORARY_VARIABLE"
        self.model_index = None
        # Operations producing (ins) and consuming (outs) this operand
        self.ins = []
        self.outs = []
        self.mayBeInternal = True

    def SetValue(self, value):
        """Store `value`; anything but a list, tuple or None is wrapped in a list."""
        if value is None or type(value) in (list, tuple):
            self.value = value
        else:
            self.value = [value]
        return self

    def SetValueFromNumpy(self, value):
        """Store the flattened contents of the numpy array `value` as a list."""
        self.value = value.flatten().tolist()
        return self

    def GetValueAsNumpy(self):
        """Return the value as a numpy array shaped by the type's dimensions."""
        return np.array(self.value).reshape(self.type.dimensions)

    # Print value as cpp-style list initialization
    def GetListInitialization(self):
        if self.value is None:
            return "{}"
        # Pick the per-element formatter from the element type
        if self.type.IsFloat():
            formatter = PrettyPrintAsFloat
        elif self.type.IsBool():
            formatter = lambda v: "true" if v else "false"
        else:
            formatter = lambda x: str(int(x))
        return "{%s}" % (GetJointStr(self.value, method=formatter))

    def ToUnspecifiedDim(self):
        """Keep the current dimensions on the operand and zero them in its type."""
        self.dimensions = self.type.dimensions
        self.type = self.type.ToUnspecifiedDim()

    def ConvertTo(self, DerivedClass, name=None):
        """Return a new operand of class `DerivedClass` with the same type.

        The value is carried over unless the target class is Internal, and
        the "never internal" mark is propagated to the new operand.
        """
        assert issubclass(DerivedClass, Operand)
        if name is None:
            name = self.name
        toInternal = issubclass(DerivedClass, Internal)
        converted = DerivedClass(name, self.type.GetSignatureTuple(), skipRenaming=True,
                                 extraParams=self.type.extraParams)
        if not toInternal:
            converted.SetValue(self.value)
        if not self.mayBeInternal:
            assert not issubclass(DerivedClass, Internal)
            converted.ShouldNeverBeInternal()
        return converted

    def ShouldNeverBeInternal(self):
        """Mark this operand as one that must not be converted to Internal."""
        self.mayBeInternal = False
        return self
342
343# Base class of user-defined input/output operand
class InOut(Operand):
    """Common base of user-declared model inputs and outputs."""

    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        # `backward` travels in the `value` slot of Operand: it is the shape
        # string in the legacy string form of opType.
        super().__init__(name, opType, value=backward, backward=None,
                         skipRenaming=skipRenaming, extraParams=extraParams)
        self.lifetime = "SUBGRAPH_INPUT"
        self.index = 0

    def Feed(self, value):
        """Set the value; a dict is looked up with this operand as the key."""
        if type(value) is dict:
            value = value[self]
        self.SetValue(value)
        return self
354
355# A user-declared input operand
class Input(InOut):
    """A user-declared input operand (lifetime SUBGRAPH_INPUT)."""

    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        super().__init__(name, opType, backward,
                         skipRenaming=skipRenaming, extraParams=extraParams)
        self.lifetime = "SUBGRAPH_INPUT"
360
361# A user-declared output operand
class Output(InOut):
    """A user-declared output operand (lifetime SUBGRAPH_OUTPUT)."""

    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        super().__init__(name, opType, backward,
                         skipRenaming=skipRenaming, extraParams=extraParams)
        self.lifetime = "SUBGRAPH_OUTPUT"
366
367# An output that we don't want to compare the results
class IgnoredOutput(Output):
    """An output whose fed value is replaced by zeros."""

    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        super().__init__(name, opType, backward,
                         skipRenaming=skipRenaming, extraParams=extraParams)
        self.lifetime = "SUBGRAPH_OUTPUT"

    def Feed(self, value):
        """Ignore `value` and store one zero per element of the operand."""
        elementCount = 1
        for dim in self.type.dimensions:
            elementCount *= dim
        self.value = [0] * elementCount
        return self
376
377# An explicitly declared parameter
class Parameter(Operand):
    """An explicitly declared constant operand.

    The lifetime is NO_VALUE when the `value` argument is None; otherwise it
    is CONSTANT_REFERENCE or CONSTANT_COPY depending on
    Configuration.useSHM().
    """

    def __init__(self, name, opType, value, backward=None, skipRenaming=False, extraParams=None):
        super().__init__(name, opType, value, backward, skipRenaming=skipRenaming,
                         extraParams=extraParams)
        # Separate uniquely named variable derived from the operand name
        self.initializer = NamedVariable(str(self) + "_init")
        if value is None:
            self.lifetime = "NO_VALUE"
        else:
            self.lifetime = "CONSTANT_REFERENCE" if Configuration.useSHM() else "CONSTANT_COPY"
389
390# A shortcut for parameters of INT32
class Int32Scalar(Parameter, ImplicitParameter):
    """INT32 scalar parameter; a plain Python int converts to it implicitly."""

    def __init__(self, name, value):
        super().__init__(name, ("INT32", []), int(value))

    @staticmethod
    def IsCompatible(value):
        # Exact type match: bool is a subclass of int and must not match here
        return type(value) is int
397
398# A shortcut for parameters of FLOAT16
class Float16Scalar(Parameter, ImplicitParameter):
    """FLOAT16 scalar parameter; never selected by implicit conversion."""

    def __init__(self, name, value):
        super().__init__(name, ("FLOAT16", []), float(value))

    @staticmethod
    def IsCompatible(value):
        # No Python value maps to FLOAT16 implicitly
        return False
405
406# A shortcut for parameters of FLOAT32
class Float32Scalar(Parameter, ImplicitParameter):
    """FLOAT32 scalar parameter; a plain Python float converts to it implicitly."""

    def __init__(self, name, value):
        super().__init__(name, ("FLOAT32", []), float(value))

    @staticmethod
    def IsCompatible(value):
        return type(value) is float
413
414# A shortcut for parameters of BOOL
class BoolScalar(Parameter, ImplicitParameter):
    """BOOL scalar parameter; a plain Python bool converts to it implicitly."""

    def __init__(self, name, value):
        super().__init__(name, ("BOOL", []), bool(value))

    @staticmethod
    def IsCompatible(value):
        return type(value) is bool
421
422# A shortcut for parameter of 1-D TENSOR_INT32
class Int32Vector(Parameter, ImplicitParameter):
    """1-D TENSOR_INT32 parameter built from a list or tuple of ints."""

    def __init__(self, name, value):
        elements = [int(v) for v in value]
        super().__init__(name, ("TENSOR_INT32", [len(value)]), elements)

    @staticmethod
    def IsCompatible(value):
        """Accept a list or tuple whose items are all exactly of type int."""
        if type(value) not in (list, tuple):
            return False
        return all(type(i) is int for i in value)
431
432# A shortcut for parameter of 1-D TENSOR_FLOAT32
class Float32Vector(Parameter, ImplicitParameter):
    """1-D TENSOR_FLOAT32 parameter built from a list or tuple of floats."""

    def __init__(self, name, value):
        elements = [float(v) for v in value]
        super().__init__(name, ("TENSOR_FLOAT32", [len(value)]), elements)

    @staticmethod
    def IsCompatible(value):
        """Accept a list or tuple whose items are all exactly of type float."""
        if type(value) not in (list, tuple):
            return False
        return all(type(i) is float for i in value)
441
442# A shortcut for a SUBGRAPH parameter
class SubgraphReference(Parameter, ImplicitParameter):
    """SUBGRAPH parameter whose value is a Model."""

    def __init__(self, name, model):
        super().__init__(name, ("SUBGRAPH", []), model)
        self.lifetime = "SUBGRAPH"
        # An unnamed model takes over the name of its first reference
        if model.name is None:
            model.name = name

    @staticmethod
    def IsCompatible(value):
        return type(value) is Model
452
453# An explicitly declared intermediate result
class Internal(Operand):
    """An explicitly declared intermediate result (lifetime TEMPORARY_VARIABLE)."""

    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        # `backward` travels in the `value` slot of Operand, as in InOut
        super().__init__(name, opType, value=backward, backward=None,
                         skipRenaming=skipRenaming, extraParams=extraParams)
        self.lifetime = "TEMPORARY_VARIABLE"
459
460# An operation in a model, does not need a name
class Operation:
    """One operation of a model: an operation type plus its operands."""

    def __init__(self, optype, ins, outs):
        self.optype = optype
        self.SetInputs(ins)
        self.SetOutputs(outs)

    # for the ease of debugging
    def __str__(self):
        return "Operation %s: [%s] -> [%s]" % (
            self.optype, GetJointStr(self.ins), GetJointStr(self.outs))
    __repr__ = __str__

    def SetInputs(self, ins):
        """Store the inputs, converting plain Python values to parameters."""
        converted = []
        for candidate in ins:
            converted.append(ImplicitParameter.ImplicitConvertion(candidate))
        self.ins = converted
        return self

    def SetOutputs(self, outs):
        """Store a list copy of the output operands."""
        self.outs = [*outs]
        return self
482
483# Main interface
class Model:
    """Main interface: a graph of operations and the operands they use.

    Every constructed model is also appended to the class-level `models`
    list.
    """
    models = list()

    def __init__(self, name=None):
        self.name = name
        self.operations = []
        self.operands = []
        self.isRelaxed = False
        self.compiled = False
        self.dumped = False
        self.version = FileNames.version
        # Filled in by Compile() when this is the main model
        self.referenced_models = None
        Model.models.append(self)

    def WithSuffix(self, *args):
        """Create the global names derived from the model name and `args`."""
        self.createFunctionName = GlobalVariable("CreateModel", self.name, *args)
        self.createTestFunctionName = GlobalVariable("createTestModel", self.name, *args)
        self.isIgnoredFunctionName = GlobalVariable("is_ignored", self.name, *args)
        return self

    def AddOperand(self, operand):
        """Append `operand` unless an equal one is already present."""
        if operand in self.operands:
            return self
        self.operands.append(operand)
        return self

    # Makes sure the model contains all (and only) the given inputs in the
    # specified order.
    def IdentifyInputs(self, *args):
        for operand in args:
            self.AddOperand(operand)
        actual = tuple(self.GetInputs())
        assert actual == args, '{} vs {}'.format(actual, args)
        return self

    # Makes sure the model contains all (and only) the given outputs in the
    # specified order.
    def IdentifyOutputs(self, *args):
        for operand in args:
            self.AddOperand(operand)
        actual = tuple(self.GetOutputs())
        assert actual == args, '{} vs {}'.format(actual, args)
        return self

    def AddOperation(self, operation):
        """Append `operation` and register its inputs, then its outputs."""
        self.operations.append(operation)
        for source in operation.ins:
            self.AddOperand(source)
        for target in operation.outs:
            self.AddOperand(target)
        return self

    def Operation(self, op_name, *args):
        """Add an operation of type `op_name` with inputs `args` and no outputs yet."""
        newOperation = Operation(op_name, args, [])
        return self.AddOperation(newOperation)

    def To(self, *args):
        """Set the outputs of the most recently added operation.

        The outputs are given either as separate arguments or as a single
        list/tuple in the first argument.
        """
        assert len(self.operations) > 0
        first = args[0]
        outs = first if type(first) in (tuple, list) else args
        self.operations[-1].SetOutputs(outs)
        for target in outs:
            self.AddOperand(target)
        return self

    def RelaxedExecution(self, isRelaxed):
        self.isRelaxed = isRelaxed
        return self

    # Sets the version of the model in compliance tests. Set to None to disable the test.
    def IntroducedIn(self, ver):
        self.version = ver
        return self

    def GetTypes(self):
        """Return the distinct operand types in sorted order."""
        return sorted({operand.type for operand in self.operands})

    def GetInputs(self):
        return [operand for operand in self.operands if isinstance(operand, Input)]

    def GetOutputs(self):
        return [operand for operand in self.operands if isinstance(operand, Output)]

    def GetInputsIndex(self):
        """Return the positions of the Input operands within `operands`."""
        return [pos for pos, operand in enumerate(self.operands) if isinstance(operand, Input)]

    def GetOutputsIndex(self):
        """Return the positions of the Output operands within `operands`."""
        return [pos for pos, operand in enumerate(self.operands) if isinstance(operand, Output)]

    def GetIndexOfOperands(self, operands):
        return [self.operands.index(operand) for operand in operands]

    def GetIgnoredOutputs(self):
        return [operand for operand in self.operands if isinstance(operand, IgnoredOutput)]

    def GetParameters(self):
        return [operand for operand in self.operands if isinstance(operand, Parameter)]

    def GetReferencedModels(self):
        assert self.compiled
        return self.referenced_models

    def GetEquivalentOperands(self, targets):
        """Return the model's own operands that compare equal to `targets`."""
        return [self.operands[self.operands.index(target)] for target in targets]

    def UpdateEquivalentOperands(self, targets):
        """Replace each operand that compares equal to a target by that target."""
        for target in targets:
            position = self.operands.index(target)
            self.operands[position] = target
        return self

    def SetOperandIndex(self):
        """Number inputs and outputs separately, and all operands together."""
        for position, operand in enumerate(self.GetInputs()):
            operand.index = position
        for position, operand in enumerate(self.GetOutputs()):
            operand.index = position
        for position, operand in enumerate(self.operands):
            operand.model_index = position
        return self

    def SetOperandInsAndOuts(self):
        """Rebuild the producer (ins) / consumer (outs) lists of every operand."""
        for operand in self.operands:
            operand.ins = list()
            operand.outs = list()
        for operation in self.operations:
            # Point the operation at the model's own operand objects first
            operation.ins = self.GetEquivalentOperands(operation.ins)
            operation.outs = self.GetEquivalentOperands(operation.outs)
            for consumed in operation.ins:
                consumed.outs.append(operation)
            for produced in operation.outs:
                produced.ins.append(operation)
        return self

    def TopologicalSortHelper(self, op, deps, visited):
        """Append `op` to self.operations after everything it depends on.

        An operation stays in `deps` until it has been emitted, so meeting a
        visited operation that is still in `deps` means a cycle.
        """
        if op in visited:
            assert op not in deps, "Cycle detected in the graph"
            return
        visited.add(op)
        for dependency in deps[op]:
            self.TopologicalSortHelper(dependency, deps, visited)
        self.operations.append(op)
        deps.pop(op)

    # Topological sort of the operations, and detect if there is a cycle in the graph
    def TopologicalSort(self):
        deps = {operation: list() for operation in self.operations}
        # Each consumer of an operand depends on each producer of that operand
        for operand in self.operands:
            for consumer in operand.outs:
                for producer in operand.ins:
                    deps[consumer].append(producer)
        pending = self.operations.copy()
        self.operations = []
        visited = set()
        for operation in pending:
            self.TopologicalSortHelper(operation, deps, visited)

    def CompileReferencedModels(self, referenced_models, referenced_model_to_index):
        """Compile the models held by SUBGRAPH operands.

        Each such operand's value is replaced by the index assigned to its
        model; a model not seen before (by identity) is appended to
        `referenced_models` and compiled.
        """
        for operand in self.operands:
            if operand.lifetime == "SUBGRAPH":
                subgraph = operand.value[0]
                key = id(subgraph)
                if key not in referenced_model_to_index:
                    referenced_model_to_index[key] = len(referenced_model_to_index)
                    referenced_models.append(subgraph)
                    subgraph.Compile(referenced_models, referenced_model_to_index)
                operand.value = [referenced_model_to_index[key]]

    def Compile(self, referenced_models=None, referenced_model_to_index=None):
        """Index operands, sort operations and compile referenced models.

        Does nothing when the model is already compiled. Called without
        arguments, the model is treated as the main model and collects the
        referenced models in `self.referenced_models`.
        """
        if self.compiled:
            return self
        if referenced_models is None:
            # This is the main model.
            referenced_models = []
            referenced_model_to_index = {}
            self.referenced_models = referenced_models
        self.SetOperandIndex()
        self.SetOperandInsAndOuts()
        self.TopologicalSort()
        self.CompileReferencedModels(referenced_models, referenced_model_to_index)
        # Do not check compliance for relaxed mode tests.
        if self.isRelaxed:
            self.IntroducedIn(None)
        self.compiled = True
        return self

    def Feed(self, feedDict):
        """Feed inputs from feedDict[0] and outputs from feedDict[1]."""
        for operand in self.GetInputs():
            operand.Feed(feedDict[0])
        for operand in self.GetOutputs():
            operand.Feed(feedDict[1])
        return self
672
673# To track implicitly convertible variation types
class ImplicitVariation:
    """Mixin marking variation classes that plain values convert to.

    Every direct subclass is expected to provide a static
    `IsCompatible(value)` and to be constructible from that single value.
    """

    @staticmethod
    def ImplicitConvertion(value):
        """Return `value` as a ModelVariation.

        ModelVariation instances are returned unchanged. Otherwise `value`
        is either the variation spec itself or a list/tuple whose first item
        is the spec and whose remaining items are passed to Identify(). The
        first subclass (in definition order) accepting the spec is used.

        Raises:
            AssertionError: if no subclass accepts the spec.
        """
        if isinstance(value, ModelVariation):
            return value
        # Normalise once, up front, so the error path below also sees the
        # wrapped form even when there are no subclasses to try.
        items = value if type(value) is tuple or type(value) is list else [value]
        spec = items[0]
        for implicitType in ImplicitVariation.__subclasses__():
            if implicitType.IsCompatible(spec):
                var = implicitType(spec)
                if len(items) > 1:
                    var.Identify(*items[1:])
                return var
        # Raised explicitly so the check survives `python -O`; the 1-tuple
        # keeps `%` from unpacking a tuple-valued spec.
        raise AssertionError("%s not supported for implicit variation" % (spec,))
687
688# An exception indicating that the current variation list should be skipped.
class SkipVariation(Exception):
    """Raised to signal that the current variation list should be skipped."""
691
692# The base class for model variations
class ModelVariation:
    """Base class of model variations.

    A variation transforms selected operands of a model and then the model
    itself. The hooks defined here leave everything unchanged; subclasses
    override the ones they need.
    """
    # Whether the variation may be applied to models with SUBGRAPH operands
    supportsSubgraphs = False

    def __init__(self, name=None):
        # Maps each target operand to the argument given to TransformOperand
        self.targetOperands = {}
        self.name = name

    # Apply the model variation.
    def ApplyTo(self, model):
        assert not model.compiled
        assert not model.dumped

        if not self.supportsSubgraphs:
            containsSubgraphs = False
            for operand in model.operands:
                if operand.lifetime == "SUBGRAPH":
                    containsSubgraphs = True
                    break
            assert not containsSubgraphs, "Variation {} does not support subgraphs".format(
                self.__class__.__name__)

        # Let the variation choose its targets when none were identified
        if not self.targetOperands:
            self.AutoIdentify(model)

        # Transform operands and model.
        targets = model.GetEquivalentOperands(sorted(self.targetOperands))
        transformed = [self.TransformOperand(op, self.targetOperands[op]) for op in targets]
        model.UpdateEquivalentOperands(transformed)
        return self.TransformModel(model)

    def IdentifyOperands(self, args=None):
        """Set the target operands.

        A dict is stored as given; any other iterable becomes a dict mapping
        each item to None. None leaves the current targets untouched.
        """
        if args is not None:
            if type(args) is dict:
                self.targetOperands = args
            else:
                self.targetOperands = {i: None for i in args}
        return self

    def Identify(self, operandArgs=None, paramArgs=None):
        """Identify the targets; `paramArgs` is not used by this base class."""
        self.IdentifyOperands(operandArgs)
        return self

    # Set variation to its default name
    def SetToDefaultName(self):
        self.name = ""
        return self

    # Automatically select the target operand list
    def AutoIdentify(self, model):
        return self

    # Transform operands that are marked by IdentifyOperands()
    def TransformOperand(self, op, arg=None):
        return op

    # Transform the model
    def TransformModel(self, model):
        return model
746
747# Default variation that does nothing
class DefaultVariation(ModelVariation):
    """Variation that leaves the model unchanged; allowed on subgraph models."""
    supportsSubgraphs = True

    def __init__(self, name=None):
        super().__init__(name=name)
753
754# Convert operand data type
class DataTypeConverter(ModelVariation, ImplicitVariation):
    """Variation that converts operands to another data type.

    With a `targetType` ("float16", "int32", "quant8" or "quant8_signed",
    case-insensitive) the targets are picked automatically by
    AutoIdentify(). Otherwise they must be given through Identify() as a
    dict mapping each operand to a list of the form
    [typeName, <further Type.GetType arguments>...].
    """
    supportsSubgraphs = True

    def __init__(self, targetType=None, name=None, scale=None, zeroPoint=None):
        ModelVariation.__init__(self, name=name)
        if targetType is not None:
            assert DataTypeConverter.IsCompatible(targetType)
        self.targetType = targetType
        self.scale = scale
        self.zeroPoint = zeroPoint

    @staticmethod
    def IsCompatible(value):
        """Return whether `value` names a supported target type (any case)."""
        return value.lower() in ["float16", "int32", "quant8", "quant8_signed"]

    def SetToDefaultName(self):
        """Name the variation after its target type.

        Without an explicit target type the name is derived from the type
        names in the identified targets, and is "float32" when none of the
        known names is present (including when there are no plain targets).
        """
        if self.targetType is not None:
            self.name = self.targetType.lower()
            return self
        # First element of every plain target; nested converters are skipped.
        # A comprehension, unlike zip(*...)[0], also copes with no targets.
        targetTypes = [arg[0] for arg in self.targetOperands.values()
                       if type(arg) is not DataTypeConverter]
        if "TENSOR_QUANT8_SYMM_PER_CHANNEL" in targetTypes:
            self.name = "channelQuant8"
        elif "TENSOR_QUANT8_ASYMM" in targetTypes:
            self.name = "quant8"
        elif "TENSOR_QUANT8_ASYMM_SIGNED" in targetTypes:
            self.name = "quant8_signed"
        elif "TENSOR_INT32" in targetTypes:
            self.name = "int32"
        elif "TENSOR_FLOAT16" in targetTypes:
            self.name = "float16"
        else:
            self.name = "float32"
        return self

    def AutoIdentify(self, model):
        """Select all float32 tensors/scalars and SUBGRAPH operands of `model`.

        Does nothing when no target type was given. For the quant8 targets
        `scale` and `zeroPoint` must be set, and FLOAT32 scalars are left
        alone.
        """
        if self.targetType is None:
            return self
        # IsCompatible() accepts any letter case, so compare case-insensitively
        target = self.targetType.lower()
        quantTensorTypes = {
            "quant8": "TENSOR_QUANT8_ASYMM",
            "quant8_signed": "TENSOR_QUANT8_ASYMM_SIGNED",
        }
        if target in quantTensorTypes:
            assert self.scale is not None
            assert self.zeroPoint is not None
            tensorType = [quantTensorTypes[target], self.scale, self.zeroPoint]
            scalarType = None  # Not supported.
        else:
            tensorType = ["TENSOR_" + target.upper()]
            scalarType = [target.upper()]
        # By default, select all the float32 tensors/scalars
        targets = dict()
        # SUBGRAPH operands get a nested converter with the same settings
        targets.update({op: DataTypeConverter(self.targetType, self.name,
                                              self.scale, self.zeroPoint)
                        for op in model.operands if op.type.type == "SUBGRAPH"})
        targets.update({op: tensorType
                        for op in model.operands if op.type.type == "TENSOR_FLOAT32"})
        if scalarType is not None:
            targets.update({op: scalarType
                            for op in model.operands if op.type.type == "FLOAT32"})
        self.Identify(targets)
        return self

    def TransformOperand(self, op, arg=None):
        """Convert `op` as described by `arg` and return it.

        `arg` is either a nested DataTypeConverter (applied to the model held
        by a SUBGRAPH operand) or a list [typeName, ...] whose extra items
        follow the dimensions in the Type.GetType() call.
        """
        if type(arg) is DataTypeConverter:
            # Handle nested SUBGRAPHs
            assert len(op.value) == 1
            assert type(op.value[0]) is Model
            op.value[0] = arg.ApplyTo(op.value[0])
            return op
        typeTuple = (arg[0], op.type.dimensions, *arg[1:])
        # Internal operands carry no data, so only their type changes
        hasData = op.value is not None and op.type.GetNumberOfElements() != 0
        if hasData:
            # Must happen before op.type is replaced: uses the old type
            realValues = Dequantize(op.GetValueAsNumpy().astype(np.float32), op.type)
        op.type = Type.GetType(*typeTuple)
        if hasData:
            op.SetValueFromNumpy(Quantize(realValues, op.type))
        return op
837
838# Convert model to turn on/off relaxed computation
class RelaxedModeConverter(ModelVariation, ImplicitVariation):
    """Model variation that switches relaxed computation on or off."""
    supportsSubgraphs = True

    def __init__(self, isRelaxed=True, name=None):
        """Create the variation.

        Args:
            isRelaxed: a bool, or a string accepted by IsCompatible (which
                selects relaxed mode).
            name: optional variation name, forwarded to ModelVariation.
        """
        ModelVariation.__init__(self, name=name)
        if not isinstance(isRelaxed, bool):
            # A non-bool argument must be the "relaxed" keyword (any case).
            assert RelaxedModeConverter.IsCompatible(isRelaxed.lower())
            isRelaxed = True
        self.isRelaxed = isRelaxed

    @staticmethod
    def IsCompatible(value):
        """Return True if `value` is the string "relaxed", ignoring case."""
        return value.lower() == "relaxed"

    def SetToDefaultName(self):
        """Name the variation "relaxed" or "float" and return self."""
        if self.isRelaxed:
            self.name = "relaxed"
        else:
            self.name = "float"
        return self

    def TransformModel(self, model):
        """Apply the relaxed-execution setting to `model` and return it."""
        model.RelaxedExecution(self.isRelaxed)
        return model
861
# Convert data layout between "NHWC" and "NCHW"
class DataLayoutConverter(ModelVariation, ImplicitVariation):
    """Model variation that converts operands between "NHWC" and "NCHW"."""

    def __init__(self, targetLayout="nchw", name=None):
        """Create the variation.

        Args:
            targetLayout: "nchw" or "nhwc", case-insensitive.
            name: optional variation name, forwarded to ModelVariation.
        """
        ModelVariation.__init__(self, name=name)
        self.targetLayout = targetLayout.lower()
        assert DataLayoutConverter.IsCompatible(self.targetLayout)
        toNchw = self.targetLayout == "nchw"
        # Axis permutation applied to 4-D operands.
        self.perm = (0, 3, 1, 2) if toNchw else (0, 2, 3, 1)
        # Value written to BOOL operands: True selects NCHW.
        self.param = toNchw

    @staticmethod
    def IsCompatible(value):
        """Return True if `value` names a supported layout, ignoring case."""
        return value.lower() in ["nhwc", "nchw"]

    def SetToDefaultName(self):
        """Name the variation after the target layout and return self."""
        self.name = self.targetLayout
        return self

    def TransformOperand(self, op, arg=None):
        """Convert one operand to the target layout.

        * A 4-D operand has its dimensions permuted, and its value transposed
          when it has a non-empty value.
        * A 1-D operand holding four values has those values reordered.
        * A BOOL operand is set to self.param.

        Any other operand is rejected with an assertion.

        Returns:
            The converted operand `op`.
        """
        rank = len(op.type.dimensions)
        if rank == 4:
            # To handle Internal operands
            if op.value is not None and op.type.GetNumberOfElements() != 0:
                op.SetValueFromNumpy(op.GetValueAsNumpy().transpose(self.perm))
            newDim = [op.type.dimensions[i] for i in self.perm]
            op.type = Type.GetType(op.type.type, newDim, op.type.scale, op.type.zeroPoint)
        elif rank == 1 and op.value is not None and len(op.value) == 4:
            # The None check lets a value-less 1-D operand fall through to the
            # descriptive assertion below instead of failing inside len().
            op.SetValueFromNumpy(op.GetValueAsNumpy()[list(self.perm)])
        elif op.type.type == "BOOL":
            op.SetValue(self.param)
        else:
            assert False, "%s not supported by DataLayoutConverter"%op
        return op
894
# Convert data by transposing and removing axes
class AxisConverter(ModelVariation):
    """Model variation that moves one axis and optionally removes others."""

    def __init__(self, origin, target, dim, drop=(), name=None):
        """Create the variation.

        Args:
            origin: index of the axis to move; must be in [-dim, dim).
            target: index the axis is moved to; must be in [-dim, dim).
            dim: rank of the tensors this variation applies to.
            drop: an int or an iterable of ints in [-dim, dim) naming the
                axes to remove after the move. The target axis may not be
                dropped. The default is an immutable empty tuple so that no
                mutable object is shared between calls.
            name: optional variation name, forwarded to ModelVariation.
        """
        ModelVariation.__init__(self, name=name)
        self.origin = origin
        self.target = target
        assert all(i >= -dim and i < dim for i in [self.origin, self.target])
        self.dim = dim
        # Permutation that takes the origin axis out and reinserts it at the
        # (non-negative) target position.
        self.perm = list(range(dim))
        self.perm.insert(target if target >= 0 else target + dim, self.perm.pop(origin))
        self.drop = [drop] if type(drop) is int else list(drop)
        assert all(i >= -dim and i < dim for i in self.drop)
        # Store dropped axes as non-negative indices.
        self.drop = [i if i >= 0 else i + dim for i in self.drop]
        assert target not in self.drop and target + dim not in self.drop

    def SetToDefaultName(self):
        """Name the variation "dim<rank>_axis<axis>[_neg]" and return self.

        <rank> and <axis> are counted after removing the dropped axes; the
        "_neg" suffix marks a negative target.
        """
        axis = self.target if self.target >= 0 else self.target + self.dim
        axis -= sum(i < axis for i in self.drop)
        neg = "" if self.target >= 0 else "_neg"
        self.name = "dim%d_axis%d%s"%(self.dim - len(self.drop), axis, neg)
        return self

    def TransposeAxis(self, op):
        """Move the axis: set an INT32 operand to the target, permute a tensor.

        Operands that are neither INT32 nor of rank self.dim are rejected
        with an assertion. Returns `op`.
        """
        if op.type.type == "INT32":
            op.SetValue(self.target)
        elif len(op.type.dimensions) == self.dim:
            # To handle Internal operands
            if op.value is not None:
                op.SetValueFromNumpy(op.GetValueAsNumpy().transpose(self.perm))
            newDim = [op.type.dimensions[i] for i in self.perm]
            op.type = Type.GetType(op.type.type, newDim, op.type.scale, op.type.zeroPoint)
        else:
            assert False, "%s not supported by AxisConverter"%op
        return op

    def RemoveAxis(self, op):
        """Remove the dropped axes from a tensor, or renumber an INT32 axis.

        Operands that are neither INT32 nor of rank self.dim are rejected
        with an assertion. Returns `op`.
        """
        if op.type.type == "INT32":
            if op.value[0] >= 0:
                # A non-negative index shifts down by the dropped axes before it.
                op.SetValue(op.value[0] - sum(i < op.value[0] for i in self.drop))
            else:
                # A negative index shifts up by the dropped axes after it.
                op.SetValue(op.value[0] + sum(i > (op.value[0] + self.dim) for i in self.drop))
        elif len(op.type.dimensions) == self.dim:
            if op.value is not None:
                val = op.GetValueAsNumpy()
                # Drop from the highest axis down so lower indices stay valid;
                # only index 0 along each dropped axis is kept.
                for i in sorted(self.drop, reverse=True):
                    val = np.take(val, 0, axis=i)
                op.SetValueFromNumpy(val)
            newDim = [op.type.dimensions[i] for i in range(self.dim) if i not in self.drop]
            op.type = Type.GetType(op.type.type, newDim, op.type.scale, op.type.zeroPoint)
        else:
            assert False, "%s not supported by AxisConverter"%op
        return op

    def TransformOperand(self, op, arg=None):
        """Move the axis of `op`, then remove the dropped axes; return `op`."""
        op = self.TransposeAxis(op)
        op = self.RemoveAxis(op)
        return op
953
954# Convert Output based on activation
class ActivationConverter(ModelVariation, ImplicitVariation):
    """Model variation that applies an activation to the expected outputs."""
    # Activation name -> (Enum, low, high); None means no bound on that side.
    actMap = {
        "none": (0, None, None),
        "relu": (1, 0.0, None),
        "relu1": (2, -1.0, 1.0),
        "relu6": (3, 0.0, 6.0),
    }

    def __init__(self, act="relu", name=None):
        """Create the variation for the activation `act` (a key of actMap,
        case-insensitive)."""
        ModelVariation.__init__(self, name=name)
        self.act = act.lower()
        assert ActivationConverter.IsCompatible(self.act)
        self.enum, self.low, self.high = ActivationConverter.actMap[self.act]

    @staticmethod
    def IsCompatible(value):
        """Return True if `value` names a known activation, ignoring case."""
        return value.lower() in ActivationConverter.actMap

    def SetToDefaultName(self):
        """Name the variation after the activation and return self."""
        self.name = self.act
        return self

    def TransformOperand(self, op, arg=None):
        """Set an INT32 operand to the activation enum, or clamp an Output.

        The bounds are quantized to the operand's type before clamping.
        Returns whatever the operand's setter returns.
        """
        if op.type.type == "INT32": # activation enum
            return op.SetValue(self.enum)

        assert isinstance(op, Output)
        clamped = op.GetValueAsNumpy()
        if self.low is not None:
            clamped = np.maximum(clamped, Quantize(self.low, op.type))
        if self.high is not None:
            clamped = np.minimum(clamped, Quantize(self.high, op.type))
        return op.SetValueFromNumpy(clamped)
992
993# Convert all constant tensors as model inputs.
class AllTensorsAsInputsConverter(ModelVariation):
    """Model variation that turns every constant tensor into a model input."""
    supportsSubgraphs = True

    def __init__(self, name=None):
        ModelVariation.__init__(self, name=name)

    def SetToDefaultName(self):
        """Name the variation "all_tensors_as_inputs" and return self."""
        self.name = "all_tensors_as_inputs"
        return self

    def TransformModel(self, model):
        """Convert the constant tensors of a single-operation model to inputs.

        Raises:
            SkipVariation: if the model does not have exactly one operation,
                or has no constant tensor to convert.
        """
        if len(model.operations) != 1:
            raise SkipVariation

        # Find all constant tensors.
        def isConstantTensor(operand):
            if type(operand) is not Parameter:
                return False
            if operand.type.IsScalar():
                return False
            return operand.value is not None

        constants = list(filter(isConstantTensor, model.operands))
        if len(constants) == 0:
            raise SkipVariation

        # Convert to model inputs.
        asInputs = [constant.ConvertTo(Input) for constant in constants]
        model.UpdateEquivalentOperands(asInputs)
        return model
1019
def CompatibleWithADD(op):
    """Return True if `op` can be an operand of the dummy ADD operation.

    The operand must have at most four dimensions, a non-empty value, and one
    of the float or asymmetric quant8 tensor types listed below.
    """
    addTensorTypes = ("TENSOR_FLOAT32", "TENSOR_QUANT8_ASYMM",
                      "TENSOR_FLOAT16", "TENSOR_QUANT8_ASYMM_SIGNED")
    if len(op.type.dimensions) > 4:
        return False
    if not len(op.value) > 0:
        return False
    return op.type.type in addTensorTypes
1025
1026# Add a dummy ADD operation before each model input to make it as an internal operand.
class AllInputsAsInternalCoverter(ModelVariation):
    """Model variation that feeds each model input through a dummy ADD."""
    supportsSubgraphs = True

    def __init__(self, name=None):
        ModelVariation.__init__(self, name=name)

    def SetToDefaultName(self):
        """Name the variation "all_inputs_as_internal" and return self."""
        self.name = "all_inputs_as_internal"
        return self

    def TransformModel(self, model):
        """Turn the eligible inputs of a single-operation model into internals.

        Raises:
            SkipVariation: if the model does not have exactly one operation,
                or has no eligible input.
        """
        if len(model.operations) != 1:
            raise SkipVariation

        # Find all input tensors that can be an output of the ADD operation.
        eligible = []
        for candidate in model.GetInputs():
            if CompatibleWithADD(candidate) and candidate.mayBeInternal:
                eligible.append(candidate)
        if len(eligible) == 0:
            raise SkipVariation

        # Make every input an output of a dummy operation: input_new ADD dummy = input.
        for original in eligible:
            opType = original.type
            replacement = original.ConvertTo(Input, name=original.name + "_new")
            # The dummy holds the zero point of the operand's type.
            dummyType = (opType.type, [1], opType.scale, opType.zeroPoint)
            dummy = Parameter("dummy", dummyType, [opType.zeroPoint])
            model.Operation("ADD", replacement, dummy, 0).To(original)

        # Convert to internal operands.
        internals = [original.ConvertTo(Internal) for original in eligible]
        model.UpdateEquivalentOperands(internals)
        return model
1056
1057# Add a dummy ADD operation after each model output to make it as an internal operand.
class AllOutputsAsInternalCoverter(ModelVariation):
    """Model variation that routes each model output through a dummy ADD."""
    supportsSubgraphs = True

    def __init__(self, name=None):
        ModelVariation.__init__(self, name=name)

    def SetToDefaultName(self):
        """Name the variation "all_outputs_as_internal" and return self."""
        self.name = "all_outputs_as_internal"
        return self

    def TransformModel(self, model):
        """Turn the eligible outputs of a single-operation model into internals.

        Raises:
            SkipVariation: if the model does not have exactly one operation,
                or has no eligible output.
        """
        if len(model.operations) != 1:
            raise SkipVariation

        # Find all output tensors that can be an input to an ADD operation.
        eligible = list(filter(CompatibleWithADD, model.GetOutputs()))
        if len(eligible) == 0:
            raise SkipVariation

        # Make every output an input of a dummy operation: output ADD dummy = output_new.
        for original in eligible:
            opType = original.type
            replacement = original.ConvertTo(Output, name=original.name + "_new")
            # The dummy holds the zero point of the operand's type.
            dummyType = (opType.type, [1], opType.scale, opType.zeroPoint)
            dummy = Parameter("dummy", dummyType, [opType.zeroPoint])
            model.Operation("ADD", original, dummy, 0).To(replacement)

        # Convert to internal operands.
        internals = [original.ConvertTo(Internal) for original in eligible]
        model.UpdateEquivalentOperands(internals)
        return model
1087
1088# An example is always attached to a model, and could have multiple variations
class Example:
    """Input/output feeds attached to a model, plus the variations to test.

    Each instance registers itself in Example.examples on construction.
    """
    # Every example created so far; replaced by CombineAllExamples.
    examples = []
    # Test name -> version, as recorded by SetVersion.
    versionOverrides = {}

    def __init__(self, *args, model=None, name=None):
        """Create an example.

        Args:
            *args: feeds. A tuple or list is stored as given. A dict is split
                into an (inputs, outputs) pair keyed by the model's inputs
                and outputs. Any other type is rejected with an assertion.
            model: the model the example belongs to; defaults to the most
                recently created one (Model.models[-1]).
            name: optional example name.
        """
        self.model = Model.models[-1] if model is None else model
        self.name = name
        self.expectedMultinomialDistributionTolerance = 0
        self.expectFailure = False
        self.testDynamicOutputShape = True
        self.testLifeTimeVariation = True
        self.feedDicts = []
        for feedDict in args:
            if type(feedDict) is tuple or type(feedDict) is list:
                self.feedDicts.append(feedDict)
            elif type(feedDict) is dict:
                self.feedDicts.append((
                    {i: feedDict[i] for i in self.model.GetInputs()},
                    {o: feedDict[o] for o in self.model.GetOutputs()}
                ))
            else:
                assert False, "Unsupported feed of type %s, expected tuple, list or dict" \
                    % type(feedDict).__name__
        self.variations = []
        Example.examples.append(self)

    @staticmethod
    def SetVersion(ver, *args):
        """Record version `ver` as the override for each test name in `args`."""
        for name in args:
            Example.versionOverrides[name] = ver

    # Main entrance of test generator
    @staticmethod
    def DumpAllExamples(DumpModel=None, model_fd=None,
                        DumpExample=None, example_fd=None,
                        DumpTest=None, test_fd=None):
        """Combine the registered examples, then dump each of them."""
        Example.CombineAllExamples()
        for example in Example.examples:
            example.Dump(DumpModel, model_fd, DumpExample, example_fd, DumpTest, test_fd)

    # Combine examples with the same model, same name, and same set of variations
    @staticmethod
    def CombineAllExamples():
        """Merge examples sharing model, name and variations into the first one."""
        modelMap = {}
        newExamples = []
        for example in Example.examples:
            key = (example.model, example.name, tuple(tuple(e) for e in example.variations))
            if key in modelMap:
                modelMap[key].Combine(example)
            else:
                modelMap[key] = example
                newExamples.append(example)
        Example.examples = newExamples

    def AddVariations(self, *args, includeDefault=True, defaultName=None):
        """Append one group of alternative variations and return self.

        The group starts with a DefaultVariation named `defaultName` when
        `includeDefault` is true. Each argument goes through
        ImplicitVariation.ImplicitConvertion before it is added.
        """
        group = [DefaultVariation(defaultName)] if includeDefault else []
        group.extend(ImplicitVariation.ImplicitConvertion(i) for i in args)
        self.variations.append(group)
        return self

    def AddNchw(self, *args, includeDefault=True, defaultName="nhwc"):
        """Add an "nchw" layout variation over the operands in `args`."""
        var = DataLayoutConverter("nchw").Identify(args)
        self.AddVariations(var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddRelaxed(self, isRelaxed=True, includeDefault=True, defaultName=None):
        """Add a relaxed-computation variation."""
        var = RelaxedModeConverter(isRelaxed)
        self.AddVariations(var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddRelu(self, *args, includeDefault=True, defaultName=None):
        """Add a "relu" activation variation over the operands in `args`."""
        var = ActivationConverter("relu").Identify(args)
        self.AddVariations(var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddAllActivations(self, *args):
        """Add one variation per known activation, in sorted name order."""
        var = [ActivationConverter(i).Identify(args)
            for i in sorted(ActivationConverter.actMap.keys())]
        self.AddVariations(*var, includeDefault=False)
        return self

    def GuessOriginalAxisAndDim(self, *args):
        """Derive the current axis and the tensor rank from the operands.

        An INT32 operand supplies the axis; all other operands must share one
        rank. Without an INT32 operand the last axis is assumed.

        Returns:
            (origin, dim) where origin is a non-negative axis index.
        """
        origin = None
        dim = None
        for arg in args:
            if arg.type.type == "INT32":
                origin = arg.value[0]
            else:
                if dim is None:
                    dim = len(arg.type.dimensions)
                else:
                    assert dim == len(arg.type.dimensions)
        assert dim is not None
        origin = dim - 1 if origin is None else origin
        origin = origin + dim if origin < 0 else origin
        return origin, dim

    def AddAxis(self, axis, *args, includeDefault=True, defaultName=None):
        """Add a variation for each target axis in `axis` (an int or iterable)."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        axis = [axis] if type(axis) is int else list(axis)
        var = [AxisConverter(origin, a, dim).Identify(args) for a in axis]
        self.AddVariations(*var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddAllPositiveAxis(self, *args):
        """Add a variation for every non-negative target axis."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        var = [AxisConverter(origin, a, dim).Identify(args) for a in range(dim)]
        self.AddVariations(*var, includeDefault=False)
        return self

    def AddAllAxis(self, *args):
        """Add a variation for every target axis in [-dim, dim)."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        var = [AxisConverter(origin, a, dim).Identify(args) for a in range(-dim, dim)]
        self.AddVariations(*var, includeDefault=False)
        return self

    def AddDims(self, dims, *args, includeDefault=True, defaultName=None):
        """Add a variation for each rank in `dims` (an int or iterable).

        Ranks are reduced by dropping axes other than the original axis.
        """
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        dims = [dims] if type(dims) is int else list(dims)
        drop = list(range(dim))
        drop.pop(origin)
        var = [AxisConverter(origin, origin, dim, drop[0:(dim-i)]).Identify(args) for i in dims]
        self.AddVariations(*var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddAllDims(self, *args):
        """Add variations dropping 0 to dim-1 axes other than the original axis."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        drop = list(range(dim))
        drop.pop(origin)
        var = [AxisConverter(origin, origin, dim, drop[0:i]).Identify(args) for i in range(dim)]
        self.AddVariations(*var, includeDefault=False)
        return self

    def AddAllDimsAndPositiveAxis(self, *args):
        """Add variations for every (dropped leading axes, non-negative target) pair."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        var = [AxisConverter(origin, j, dim, range(i)).Identify(args) \
                for i in range(dim) for j in range(i, dim)]
        self.AddVariations(*var, includeDefault=False)
        return self

    def AddAllDimsAndAxis(self, *args):
        """Like AddAllDimsAndPositiveAxis, plus the negative form of each target."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        var = [AxisConverter(origin, k, dim, range(i)).Identify(args) \
                for i in range(dim) for j in range(i, dim) for k in [j, j - dim]]
        self.AddVariations(*var, includeDefault=False)
        return self

    def Combine(self, other):
        """Append the feeds of `other`, which must share model, name and variations."""
        assert self.model is other.model, "Only examples targetting the same model can be combined"
        assert tuple(self.variations) == tuple(other.variations), \
            "Only examples with the same set of variations can be combined"
        assert self.name == other.name, "Only examples with the same name can be combined"
        self.feedDicts.extend(other.feedDicts)
        return self

    def Dump(self, DumpModel, model_fd, DumpExample, example_fd, DumpTest, test_fd):
        """Generate one test per feed and per combination of variations.

        For each combination the model is deep-copied, the variations are
        applied to the copy, the copy is compiled, and the DumpExample and
        DumpTest callbacks are called when both the callback and its file
        object are given. DumpModel and model_fd are accepted but not used
        here. A combination that raises SkipVariation is skipped.

        Returns:
            self
        """
        if self.testLifeTimeVariation and len(self.model.operations) == 1 and \
                self.expectedMultinomialDistributionTolerance == 0:
            self.AddVariations(AllTensorsAsInputsConverter())
            self.AddVariations(AllInputsAsInternalCoverter())
        for variationGroup in self.variations:
            for variation in variationGroup:
                if variation.name is None:
                    variation.SetToDefaultName()

        for feedDict in self.feedDicts:
            self.model.Feed(feedDict)
            for variationList in itertools.product(*self.variations):
                modelOrigin = self.model
                self.model = copy.deepcopy(modelOrigin)
                try:
                    # Apply variations
                    try:
                        for variation in variationList:
                            self.model = variation.ApplyTo(self.model)
                    except SkipVariation:
                        continue

                    # Concat names for test and examples
                    varNames = [v.name for v in variationList]
                    self.testName = NamedTest(FileNames.specName, self.model.name, self.name,
                                              *varNames)
                    self.examplesName = GlobalVariable("test_model", self.model.name, self.name,
                                                       *varNames)
                    if str(self.testName) in Example.versionOverrides:
                        self.model.IntroducedIn(Example.versionOverrides[str(self.testName)])
                    self.model.WithSuffix(*varNames).Compile()

                    # Dump files
                    if DumpExample is not None and example_fd is not None:
                        DumpExample(self, example_fd)
                    if DumpTest is not None and test_fd is not None:
                        DumpTest(self, test_fd)
                finally:
                    # Restore model before variation, even when a variation,
                    # the compilation or a callback raises.
                    self.model = modelOrigin
        return self

    # Specifies the RANDOM_MULTINOMIAL distribution tolerance.
    # If set to greater than zero, the input is compared as log-probabilities
    # to the output and must be within this tolerance to pass.
    def WithMultinomialDistributionTolerance(self, expectedTolerance):
        assert self.expectFailure is False
        self.expectedMultinomialDistributionTolerance = expectedTolerance
        return self

    # Specifies that this example is expected to fail during compilation or execution.
    def ExpectFailure(self):
        assert self.expectedMultinomialDistributionTolerance == 0
        self.expectFailure = True
        return self

    def DisableDynamicOutputShapeVariation(self):
        """Clear the testDynamicOutputShape flag and return self."""
        self.testDynamicOutputShape = False
        return self

    def DisableLifeTimeVariation(self):
        """Stop Dump from adding the life-time variations; return self."""
        self.testLifeTimeVariation = False
        return self
1303
class FileNames:
    """Class-level registry of the spec files to process and their targets."""
    specFiles = []
    specNames = []
    exampleFiles = []
    specFile = ""
    specName = ""
    exampleFile = ""
    version = ""
    fileIndex = 0

    @staticmethod
    def InitializeFileLists(spec, example):
        """Fill the spec and example file lists.

        Args:
            spec: a spec file, or a directory searched (non-recursively) for
                files ending in ".mod.py".
            example: target argument forwarded to ParseTargetFiles.
        """
        # get all spec files and target files
        if os.path.isfile(spec):
            FileNames.specFiles = [os.path.abspath(spec)]
        elif os.path.isdir(spec):
            found = []
            for entry in os.listdir(spec):
                if entry.endswith(".mod.py"):
                    found.append(os.path.abspath(os.path.join(spec, entry)))
            found.sort()
            FileNames.specFiles = found
        else:
            assert False, "%s is neither a file or a directory"%spec
        # The spec name is the base name with everything from the first dot removed.
        names = []
        for path in FileNames.specFiles:
            names.append(re.sub(r"\..*", "", os.path.basename(path)))
        FileNames.specNames = names
        FileNames.exampleFiles = FileNames.ParseTargetFiles(example, ".example.cpp")

    @staticmethod
    def ParseTargetFiles(arg, ext):
        """Return one target per spec file.

        None yields None for every spec; a directory yields
        "<dir>/<specName><ext>"; "-" is kept as is; any other value is used
        as a single absolute path shared by all specs.
        """
        count = len(FileNames.specFiles)
        if arg is None:
            return [None] * count
        absPath = os.path.abspath(arg)
        if os.path.isdir(arg):
            return [os.path.join(absPath, specName + ext) for specName in FileNames.specNames]
        if arg == "-":
            return ["-"] * count
        return [absPath] * count

    @staticmethod
    def NextFile():
        """Advance to the next spec file and reset the per-file global state.

        Returns:
            False when every spec file has been visited, True otherwise.
        """
        index = FileNames.fileIndex
        if not index < len(FileNames.specFiles):
            return False
        FileNames.specFile = FileNames.specFiles[index]
        FileNames.specName = FileNames.specNames[index]
        FileNames.exampleFile = FileNames.exampleFiles[index]
        FileNames.fileIndex += 1
        NamedObject.existingNames = set()
        NamedVariable.existingNames = set()
        NamedTest.existingNames = set()
        Type.typesMap = dict()
        Model.models = list()
        Example.examples = list()
        Configuration.use_shm_for_weights = False

        # Extract version from absolute file path; only an unambiguous match counts.
        versionMatch = re.findall(r"/V\d_\d/", FileNames.specFile)
        FileNames.version = versionMatch[0].strip('/') if len(versionMatch) == 1 else None
        return True
1365
class Configuration:
    """Class-level switches shared by the generator."""
    # Value reported by useSHM().
    use_shm_for_weights = False
    # Hook mode flag.
    hook_mode = False

    @staticmethod
    def useSHM():
        """Return the current value of use_shm_for_weights."""
        enabled = Configuration.use_shm_for_weights
        return enabled
1373
def GetTestGeneratorMTime():
    """Return the newest modification time of the generator scripts.

    The scripts are looked up next to this file.
    """
    scriptDir = os.path.dirname(__file__)
    newest = None
    for scriptName in ['test_generator.py', 'example_generator.py']:
        modified = os.path.getmtime(os.path.join(scriptDir, scriptName))
        if newest is None or modified > newest:
            newest = modified
    return newest
1379
def MightNeedRegeneration():
    """Return True if the example file of the current spec may be stale.

    The file is considered stale when no example file is configured
    (FileNames.exampleFile is None), when it does not exist, or when it is
    not newer than both the spec file and the generator scripts.
    """
    specTime = os.path.getmtime(FileNames.specFile)
    tgTime = GetTestGeneratorMTime()
    exampleFile = FileNames.exampleFile
    # os.path.exists(None) raises TypeError, so a missing target is handled
    # explicitly: there is no timestamp to compare against.
    if exampleFile is None or not os.path.exists(exampleFile):
        return True
    return os.path.getmtime(exampleFile) <= max(specTime, tgTime)
1385
def Read(filename):
    """Return the whole content of the text file `filename`."""
    with open(filename) as source:
        contents = source.read()
    return contents
1389
def AtomicWrite(filename, data):
    """Write `data` to `filename` through a temporary file and os.replace.

    The temporary file is removed if the write or the replace fails.
    """
    # os.replace(src, dest) may fail if src and dest are on different
    # filesystems, so the temporary file is created next to the target.
    stagingPath = filename + '.tmp'
    committed = False
    try:
        with open(stagingPath, 'w') as writer:
            writer.write(data)
        os.replace(stagingPath, filename)
        committed = True
    finally:
        if not committed and os.path.exists(stagingPath):
            os.remove(stagingPath)
1402
def GetExecScope():
    """Return a new dict mapping each exported name to its object.

    Run passes this dict to exec() as the globals of the spec files.
    """
    return {
        "ActivationConverter": ActivationConverter,
        "AllInputsAsInternalCoverter": AllInputsAsInternalCoverter,
        "AllOutputsAsInternalCoverter": AllOutputsAsInternalCoverter,
        "AllTensorsAsInputsConverter": AllTensorsAsInputsConverter,
        "BoolScalar": BoolScalar,
        "Configuration": Configuration,
        "DataLayoutConverter": DataLayoutConverter,
        "DataTypeConverter": DataTypeConverter,
        "Example": Example,
        "Float16Scalar": Float16Scalar,
        "Float32Scalar": Float32Scalar,
        "Float32Vector": Float32Vector,
        "IgnoredOutput": IgnoredOutput,
        "Input": Input,
        "Int32Scalar": Int32Scalar,
        "Int32Vector": Int32Vector,
        "Internal": Internal,
        "Model": Model,
        "Operand": Operand,
        "Output": Output,
        "Parameter": Parameter,
        "RelaxedModeConverter": RelaxedModeConverter,
        "SubgraphReference": SubgraphReference,
        "SymmPerChannelQuantParams": SymmPerChannelQuantParams,
    }
1429
def ArgumentParser():
    """Build the command line parser: a positional spec and a --hook flag."""
    cliParser = argparse.ArgumentParser()
    cliParser.add_argument("spec", help="the spec file or directory")
    cliParser.add_argument("--hook", action='store_true', help="hook mode")
    return cliParser
1435
def ParseArgs(parser):
    """Parse the command line with `parser` and record the hook mode.

    Returns:
        The parsed argument namespace.
    """
    parsed = parser.parse_args()
    Configuration.hook_mode = parsed.hook
    return parsed
1440
def Run(InitializeFiles=None, DumpExample=None):
    """Process every spec file and regenerate its example file if needed.

    For each spec that might need regeneration, the spec is executed, the
    examples are dumped into an in-memory buffer, and the buffer is written
    to the example file. In hook mode an out-of-date or missing example file
    is reported and the process exits with status 1 instead of writing.

    Args:
        InitializeFiles: optional callable invoked as
            InitializeFiles(example_fd=buffer) before the examples are
            dumped; skipped when None.
        DumpExample: optional callable forwarded to Example.DumpAllExamples.

    Any exception raised while processing a spec is printed and terminates
    the process with a message naming the spec file.
    """
    exec_scope = GetExecScope()
    while FileNames.NextFile():
        try:
            if not MightNeedRegeneration():
                continue
            # NOTE: the spec file is executed as Python code with the
            # generator's objects in scope; only run this on trusted specs.
            exec(Read(FileNames.specFile), exec_scope)
            example_buf = io.StringIO() if FileNames.exampleFile else None
            if InitializeFiles is not None:
                InitializeFiles(example_fd=example_buf)
            Example.DumpAllExamples(DumpExample=DumpExample, example_fd=example_buf)
            # Without a buffer there is nothing to compare or to write.
            if example_buf is None:
                continue
            generated = example_buf.getvalue()
            if Configuration.hook_mode and (not os.path.exists(FileNames.exampleFile) or
                                            Read(FileNames.exampleFile) != generated):
                print(('\n{filename} is out of date. '
                        'Please run {generate_all_tests_sh} before uploading.\n').format(
                                filename=FileNames.exampleFile,
                                generate_all_tests_sh=os.path.abspath(os.path.join(
                                        os.path.dirname(__file__), '..', '..', 'runtime', 'test',
                                        'specs', 'generate_all_tests.sh'))))
                sys.exit(1)
            AtomicWrite(FileNames.exampleFile, generated)
        except Exception:
            traceback.print_exc()
            sys.exit("Exception raised when processing {}".format(FileNames.specFile))
1466