• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3
2
3# Copyright 2017, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""NN model compiler
18
19Contains class definitions and utility functions for compiling models and
20examples into NDK-based CTS and VTS unit tests.
21
22Used by example_generator.py and spec_visualizer.py
23"""
24
25from __future__ import absolute_import
26from __future__ import division
27from __future__ import print_function
28import copy
29from functools import reduce
30import argparse
31import io
32import itertools
33import os
34import re
35import sys
36import traceback
37import numpy as np
38
def GetJointStr(l, sep=", ", method=str):
    """Return the items of `l` converted with `method` and joined by `sep`."""
    return sep.join(map(method, l))
41
42# Print in C float literal format
def PrettyPrintAsFloat(x):
    """Return `x` formatted as a float literal carrying an "f" suffix.

    The text comes from str(float(x)). When it contains neither "." nor "e",
    ".0" is inserted before the suffix.
    """
    text = str(float(x))
    looksLikeFloat = "." in text or "e" in text
    return text + ("f" if looksLikeFloat else ".0f")
49
50# Transform from original type to float32
def Dequantize(v, ty):
    """Map values stored with type `ty` back to real values.

    Subtracts ty.zeroPoint, multiplies by ty.scale when it is non-zero, and
    multiplies by the per-channel scales when ty.extraParams is a
    SymmPerChannelQuantParams. All steps use augmented assignment, so an
    array passed as `v` is modified in place; the result is also returned.
    """
    v -= ty.zeroPoint
    scale = ty.scale
    if scale != 0:
        v *= scale
    channelParams = ty.extraParams
    if isinstance(channelParams, SymmPerChannelQuantParams):
        v *= channelParams.GetScalesBroadcastArray(ty.dimensions)
    return v
58
59# Transform float32 to target data type
def Quantize(v, ty):
    """Convert real values `v` to the representation described by type `ty`.

    Divides by ty.scale when it is non-zero (in place), divides by the
    per-channel scales when ty.extraParams is a SymmPerChannelQuantParams,
    then adds ty.zeroPoint. Non-float types are rounded and cast to int, and
    the types listed in `clampBounds` are saturated to their range.
    """
    if ty.scale != 0:
        v /= ty.scale
    if isinstance(ty.extraParams, SymmPerChannelQuantParams):
        v = v / ty.extraParams.GetScalesBroadcastArray(ty.dimensions)
    v += ty.zeroPoint
    if not ty.IsFloat():
        v = np.round(v).astype(int)

    # (lower, upper) saturation bounds per type; an upper bound of None means
    # only the lower bound is enforced.
    clampBounds = {
        "TENSOR_QUANT8_ASYMM": (0, 255),
        "TENSOR_QUANT16_ASYMM": (0, 65535),
        "TENSOR_QUANT8_SYMM_PER_CHANNEL": (-127, 127),
        "UINT32": (0, None),
        "TENSOR_QUANT8_ASYMM_SIGNED": (-128, 127),
    }
    bounds = clampBounds.get(ty.type)
    if bounds is not None:
        lower, upper = bounds
        v = np.maximum(v, lower)
        if upper is not None:
            v = np.minimum(v, upper)
    return v
81
82# Tracking objects inside a model with a unique name
class NamedObject:
    """Base class for objects identified by a unique name.

    Used names are recorded in the `existingNames` set found on the object's
    class, so a subclass that defines its own set has a separate namespace.
    """
    existingNames = set()

    def __init__(self, *args, sep="_", showZero=False, startsFrom=0, skipRenaming=False):
        """Build the name from `args` and make it unique.

        `args` that are None or "" are dropped; the rest are joined by `sep`.
        With `skipRenaming` the joined name is used as is and not recorded.
        Otherwise a numeric suffix, counted up from `startsFrom`, is appended
        until the name is unused. `showZero` forces a suffix on the first
        candidate as well.
        """
        parts = [i for i in args if i is not None and i != ""]
        name = GetJointStr(parts, sep=sep)
        if skipRenaming:
            self.name = name
            return
        taken = self.__class__.existingNames
        if showZero is False:
            candidate = name
        else:
            candidate = name + sep + str(startsFrom)
        # bump the suffix until the candidate is free
        while candidate in taken:
            startsFrom += 1
            candidate = name + sep + str(startsFrom)
        taken.add(candidate)
        self.name = candidate

    def __str__(self):
        return self.name
    __repr__ = __str__

    # Names are unique, so equality, hashing and ordering all go by name
    def __eq__(self, other):
        if not isinstance(other, NamedObject):
            return False
        return self.name == other.name

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self.name)

    def __lt__(self, other):
        return self.name < other.name
115
116# Types, operands should all have a unique name since they share the same namespace
class NamedVariable(NamedObject):
    """A uniquely named variable.

    Has its own `existingNames` set, which subclasses that do not override it
    share.
    """
    existingNames = set()

    def __init__(self, *args, sep="_", showZero=False, startsFrom=0, skipRenaming=False):
        """Forward all naming options unchanged to NamedObject."""
        NamedObject.__init__(
            self, *args,
            sep=sep,
            showZero=showZero,
            startsFrom=startsFrom,
            skipRenaming=skipRenaming)
122
123# Global variables in the spec namespace such as CreateModel, is_ignored, and examples
class GlobalVariable(NamedVariable):
    """A variable name in the NamedVariable namespace, numbered from 1.

    Because suffix counting starts at 1, the first duplicate of a name gets
    the suffix "_2".
    """

    def __init__(self, *args, skipRenaming=False):
        NamedObject.__init__(self, *args, skipRenaming=skipRenaming, startsFrom=1)
127
128# Each test should have a unique name, but will not conflict with variables
class NamedTest(NamedObject):
    """Name of a test.

    Uses its own `existingNames` set, so test names are unique among tests
    but never conflict with variable names.
    """
    existingNames = set()

    def __init__(self, *args, startsFrom=0, skipRenaming=False):
        # NOTE(review): `startsFrom` is accepted but not forwarded; the base
        # initializer always receives startsFrom=1 -- confirm this is intended.
        NamedObject.__init__(self, *args, skipRenaming=skipRenaming, startsFrom=1)
133
class Type(NamedVariable):
    """Data type of an operand: element type, dimensions and quantization.

    Obtain instances through GetType() rather than the constructor. It caches
    them in `typesMap` and creates a new Type (named "type0", "type1", ...)
    only when no cached one has the same signature.
    """
    # Cache of the types created by GetType(), keyed by its signature string
    typesMap = dict()
    # Operand type name -> element type string returned by GetCppTypeString()
    typeLookup = {
        "INT32": "int32_t",
        "UINT32": "uint32_t",
        "FLOAT32": "float",
        "FLOAT16": "_Float16",
        "TENSOR_INT32": "int32_t",
        "TENSOR_FLOAT16": "_Float16",
        "TENSOR_FLOAT32": "float",
        "TENSOR_QUANT8_ASYMM": "uint8_t",
        "TENSOR_QUANT8_SYMM": "int8_t",
        "BOOL": "bool8",
        "TENSOR_QUANT16_ASYMM": "uint16_t",
        "TENSOR_QUANT16_SYMM": "int16_t",
        "TENSOR_BOOL8": "bool8",
        "TENSOR_QUANT8_SYMM_PER_CHANNEL": "int8_t",
        "TENSOR_QUANT8_ASYMM_SIGNED": "int8_t",
    #     "OEM_SCALAR": this is service-defined.
        "TENSOR_OEM_BYTE": "uint8_t",
        "SUBGRAPH": "uint32_t",  # Index into TestModel::referenced.
    }

    def __init__(self, vt, dimensions, scale, zeroPoint, name="type", skipRenaming=False,
                 extraParams=None):
        """Create a type; `scale` is stored as float and `zeroPoint` as int."""
        # sep="" together with showZero=True yields "type0", "type1", ...
        NamedVariable.__init__(self, name, sep="", showZero=True, skipRenaming=skipRenaming)
        self.type = vt
        self.dimensions = dimensions
        self.scale = float(scale)
        self.zeroPoint = int(zeroPoint)
        self.extraParams = extraParams

    @staticmethod
    def GetType(vt, dimensions, scale=0, zeroPoint=0, extraParams=None):
        """Return the cached Type for this signature, creating it if needed.

        The cache key is built from the string form of every argument,
        including str(extraParams).
        """
        assert isinstance(dimensions, (list, tuple)), \
            'dimensions must be a list or tuple, got {}'.format(type(dimensions))
        key = ",".join([vt, str(dimensions), str(scale), str(zeroPoint), str(extraParams)])
        if key not in Type.typesMap:
            Type.typesMap[key] = Type(vt, dimensions, scale, zeroPoint, extraParams=extraParams)
        return Type.typesMap[key]

    @staticmethod
    def GetAllTypes():
        """Return every cached type, sorted for a stable dump order."""
        return sorted(Type.typesMap.values())

    # For backward-compatibility
    @staticmethod
    def GetTypeFromString(vt, shape, extraParams=None):
        """Return the Type described by `vt` and a shape string.

        See GetParsedShape() for the accepted shape string formats.
        """
        dimensions, scale, zeroPoint = Type.GetParsedShape(shape)
        return Type.GetType(vt, dimensions, float(scale), int(zeroPoint), extraParams)

    # For backward-compatibility
    @staticmethod
    def GetParsedShape(shape):
        """Split a shape string into (dimensions, scale, zeroPoint).

        Accepted forms are "", "{}", "{d0, d1, ...}", "{...}, scale" and
        "{...}, scale, zeroPoint", where the scale may carry a C-style "f"
        suffix. `dimensions` is a list of ints. `scale` and `zeroPoint` are
        returned as strings and default to "0" when absent.
        """
        if shape == "" or shape == "{}":
            return [], "0", "0"
        _, _, rest = shape.partition('{')
        dimText, _, rest = rest.partition('}')
        # An empty brace pair followed by quantization fields means "no
        # dimensions"; int("") would otherwise raise.
        dimensions = [int(x) for x in dimText.split(",")] if dimText.strip() else []
        # rest now looks like ", 0.5f, 127", ", 0.5f" or ""
        scale, _, zeroPoint = rest.rpartition(',')
        if scale == "":
            if zeroPoint == "":
                return dimensions, "0", "0"
            # Only one field follows the dimensions: it is the scale. Drop the
            # "f" suffix like the two-field case below, so float() accepts it.
            return dimensions, zeroPoint.replace("f", ""), "0"
        _, _, scale = scale.partition(',')
        return dimensions, scale.replace("f", ""), zeroPoint

    def GetNumberOfElements(self):
        """Return the product of all dimensions (1 when there are none)."""
        return reduce(lambda x, y: x * y, self.dimensions, 1)

    def GetCppTypeString(self):
        """Return the `typeLookup` entry for this type; KeyError if unknown."""
        return Type.typeLookup[self.type]

    def IsFloat(self):
        return self.GetCppTypeString() in ["float", "_Float16"]

    def IsBool(self):
        return self.GetCppTypeString() == "bool8"

    def IsScalar(self):
        return not self.type.startswith("TENSOR_")

    def GetSignatureTuple(self):
        """Return (type, dimensions, scale, zeroPoint), without extraParams."""
        return (self.type, self.dimensions, self.scale, self.zeroPoint)
227
228# To track implicitly convertible parameter types
class ImplicitParameter():
    """Marker base for parameter classes that plain values convert to.

    Each direct subclass provides IsCompatible(value) and a constructor
    taking (name, value).
    """

    @staticmethod
    def ImplicitConvertion(value):
        """Return `value` as an operand.

        Operand instances are returned unchanged. Any other value is wrapped,
        under the name "param", by the first direct subclass whose
        IsCompatible() accepts it. Fails an assertion when none does.
        """
        if isinstance(value, Operand):
            return value
        compatible = (t for t in ImplicitParameter.__subclasses__() if t.IsCompatible(value))
        implicitType = next(compatible, None)
        if implicitType is not None:
            return implicitType("param", value)
        assert False, "%s not supported for implicit parameter"%value
238
239
240# ExtraParams with per-channel quantization.
class SymmPerChannelQuantParams():
    """Extra type parameters for per-channel quantization."""

    def __init__(self, channelDim, scales, hide = False):
        self.hide = hide
        self.scales = scales
        self.channelDim = channelDim

    def GetScalesBroadcastArray(self, dimensions):
        """Return the scales as an array that broadcasts over `dimensions`.

        The result has one axis per entry of `dimensions`, each of size 1
        except axis `channelDim`, whose size is the number of scales.
        """
        rank = len(dimensions)
        broadcastShape = [1 for _ in range(rank)]
        broadcastShape[self.channelDim] = len(self.scales)
        scales = np.array(self.scales)
        return scales.reshape(broadcastShape)
251
252
253# An operand that can be fed into operations. Also, an operand is always
254# declared before operations.
class Operand(NamedVariable):
    """A named, typed value that can be fed into operations.

    Attributes set by the initializer: `type` (a Type), `value` (None, a list
    or a tuple), `lifetime`, `model_index`, `ins`, `outs` and `mayBeInternal`.
    """

    def __init__(self, name, opType, value, backward=None, skipRenaming=False, extraParams=None):
        """Create the operand.

        `opType` is either a tuple of Type.GetType() arguments or, in the
        backward-compatible form, a type name string. In the string form
        `value` holds the shape string and `backward` holds the actual value.
        """
        NamedVariable.__init__(self, name, sep="", skipRenaming=skipRenaming)
        if type(opType) is str:
            self.type = Type.GetTypeFromString(opType, value, extraParams)
            value = backward
        else:
            self.type = Type.GetType(*opType, extraParams=extraParams)
        self.SetValue(value)
        self.lifetime = "TEMPORARY_VARIABLE"
        self.model_index = None
        self.ins = []
        self.outs = []
        self.mayBeInternal = True

    def SetValue(self, value):
        """Store `value`; anything but None, a list or a tuple is wrapped in a list."""
        if value is None or type(value) is list or type(value) is tuple:
            self.value = value
        else:
            self.value = [value]
        return self

    def SetValueFromNumpy(self, value):
        """Store the flattened contents of array `value` as a list."""
        self.value = value.flatten().tolist()
        return self

    def GetValueAsNumpy(self):
        """Return the value as an array shaped like the type's dimensions."""
        return np.array(self.value).reshape(self.type.dimensions)

    # Print value as cpp-style list initialization
    def GetListInitialization(self):
        """Return the value as a brace-enclosed, comma-separated list."""
        if self.value is None:
            return "{}"
        if self.type.IsFloat():
            formatter = PrettyPrintAsFloat
        elif self.type.IsBool():
            formatter = lambda v: "true" if v else "false"
        else:
            formatter = lambda x: str(int(x))
        return "{%s}"%(GetJointStr(self.value, method=formatter))

    def ConvertTo(self, DerivedClass, name=None):
        """Return a new `DerivedClass` operand with this operand's type.

        The new operand keeps this operand's name unless `name` is given, and
        is created with skipRenaming=True. The value is copied unless
        `DerivedClass` is an Internal. An operand marked as never internal
        cannot be converted to Internal, and passes that mark on.
        """
        assert issubclass(DerivedClass, Operand)
        if name is None:
            name = self.name
        newop = DerivedClass(name, self.type.GetSignatureTuple(), skipRenaming=True,
                             extraParams=self.type.extraParams)
        becomesInternal = issubclass(DerivedClass, Internal)
        if not becomesInternal:
            newop.SetValue(self.value)
        if not self.mayBeInternal:
            assert not becomesInternal
            newop.ShouldNeverBeInternal()
        return newop

    def ShouldNeverBeInternal(self):
        """Clear `mayBeInternal` and return self."""
        self.mayBeInternal = False
        return self
309
310# Base class of user-defined input/output operand
class InOut(Operand):
    """Base class of user-defined input/output operands."""

    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        # `backward` is handed to Operand in the `value` position; the
        # operand's own `backward` argument is always None.
        Operand.__init__(self, name, opType, backward, None,
                         skipRenaming=skipRenaming, extraParams=extraParams)
        self.lifetime = "SUBGRAPH_INPUT"
        self.index = 0

    def Feed(self, value):
        """Set the value; a dict is looked up with this operand as the key."""
        fed = value[self] if type(value) is dict else value
        self.SetValue(fed)
        return self
321
322# A user-declared input operand
class Input(InOut):
    """A user-declared input operand; its lifetime is "SUBGRAPH_INPUT"."""

    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        InOut.__init__(self, name, opType, backward,
                       skipRenaming=skipRenaming, extraParams=extraParams)
        self.lifetime = "SUBGRAPH_INPUT"
327
328# A user-declared output operand
class Output(InOut):
    """A user-declared output operand; its lifetime is "SUBGRAPH_OUTPUT"."""

    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        InOut.__init__(self, name, opType, backward,
                       skipRenaming=skipRenaming, extraParams=extraParams)
        self.lifetime = "SUBGRAPH_OUTPUT"
333
334# An output that we don't want to compare the results
class IgnoredOutput(Output):
    """An output whose results are not meant to be compared."""

    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        Output.__init__(self, name, opType, backward,
                        skipRenaming=skipRenaming, extraParams=extraParams)
        self.lifetime = "SUBGRAPH_OUTPUT"

    def Feed(self, value):
        """Ignore `value` and fill the operand with one zero per element."""
        numElements = 1
        for dim in self.type.dimensions:
            numElements = numElements * dim
        self.value = [0 for _ in range(numElements)]
        return self
343
344# An explicitly declared parameter
class Parameter(Operand):
    """An explicitly declared parameter operand.

    Also owns `initializer`, a NamedVariable named "<operand name>_init".
    """

    def __init__(self, name, opType, value, backward=None, skipRenaming=False, extraParams=None):
        Operand.__init__(self, name, opType, value, backward, skipRenaming=skipRenaming,
                         extraParams=extraParams)
        self.initializer = NamedVariable(str(self) + "_init")
        # The lifetime depends on the `value` argument as passed in, and
        # otherwise on Configuration.useSHM().
        if value is None:
            lifetime = "NO_VALUE"
        else:
            lifetime = "CONSTANT_REFERENCE" if Configuration.useSHM() else "CONSTANT_COPY"
        self.lifetime = lifetime
356
357# A shortcut for parameters of INT32
class Int32Scalar(Parameter, ImplicitParameter):
    """Shortcut for an INT32 scalar parameter."""

    def __init__(self, name, value):
        scalarType = ("INT32", [])
        Parameter.__init__(self, name, scalarType, int(value))

    @staticmethod
    def IsCompatible(value):
        # exact type check: bool is a subclass of int and must not match
        return type(value) is int
364
365# A shortcut for parameters of FLOAT16
class Float16Scalar(Parameter, ImplicitParameter):
    """Shortcut for a FLOAT16 scalar parameter."""

    def __init__(self, name, value):
        scalarType = ("FLOAT16", [])
        Parameter.__init__(self, name, scalarType, float(value))

    @staticmethod
    def IsCompatible(value):
        # never selected by implicit conversion
        return False
372
373# A shortcut for parameters of FLOAT32
class Float32Scalar(Parameter, ImplicitParameter):
    """Shortcut for a FLOAT32 scalar parameter."""

    def __init__(self, name, value):
        scalarType = ("FLOAT32", [])
        Parameter.__init__(self, name, scalarType, float(value))

    @staticmethod
    def IsCompatible(value):
        return type(value) is float
380
381# A shortcut for parameters of BOOL
class BoolScalar(Parameter, ImplicitParameter):
    """Shortcut for a BOOL scalar parameter."""

    def __init__(self, name, value):
        scalarType = ("BOOL", [])
        Parameter.__init__(self, name, scalarType, bool(value))

    @staticmethod
    def IsCompatible(value):
        return type(value) is bool
388
389# A shortcut for parameter of 1-D TENSOR_INT32
class Int32Vector(Parameter, ImplicitParameter):
    """Shortcut for a 1-D TENSOR_INT32 parameter."""

    def __init__(self, name, value):
        vectorType = ("TENSOR_INT32", [len(value)])
        Parameter.__init__(self, name, vectorType, [int(v) for v in value])

    @staticmethod
    def IsCompatible(value):
        """Accept a list or tuple whose items are all exactly of type int."""
        isSequence = type(value) is list or type(value) is tuple
        return isSequence and all(type(i) is int for i in value)
398
399# A shortcut for parameter of 1-D TENSOR_FLOAT32
class Float32Vector(Parameter, ImplicitParameter):
    """Shortcut for a 1-D TENSOR_FLOAT32 parameter."""

    def __init__(self, name, value):
        vectorType = ("TENSOR_FLOAT32", [len(value)])
        Parameter.__init__(self, name, vectorType, [float(v) for v in value])

    @staticmethod
    def IsCompatible(value):
        """Accept a list or tuple whose items are all exactly of type float."""
        isSequence = type(value) is list or type(value) is tuple
        return isSequence and all(type(i) is float for i in value)
408
409# A shortcut for a SUBGRAPH parameter
class SubgraphReference(Parameter, ImplicitParameter):
    """Shortcut for a SUBGRAPH parameter whose value is a Model."""

    def __init__(self, name, model):
        Parameter.__init__(self, name, ("SUBGRAPH", []), model)
        self.lifetime = "SUBGRAPH"
        # an unnamed model takes the name of the operand referencing it
        if model.name is None:
            model.name = name

    @staticmethod
    def IsCompatible(value):
        return type(value) is Model
419
420# An explicitly declared intermediate result
class Internal(Operand):
    """An explicitly declared intermediate result."""

    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        # `backward` is handed to Operand in the `value` position
        Operand.__init__(self, name, opType, backward, None,
                         skipRenaming=skipRenaming, extraParams=extraParams)
        self.lifetime = "TEMPORARY_VARIABLE"
426
427# An operation in a model, does not need a name
class Operation:
    """An operation in a model: an op type plus input and output operands."""

    def __init__(self, optype, ins, outs):
        self.optype = optype
        self.SetInputs(ins)
        self.SetOutputs(outs)

    # for the ease of debugging
    def __str__(self):
        return "Operation %s: [%s] -> [%s]"%(
            self.optype, GetJointStr(self.ins), GetJointStr(self.outs))
    __repr__ = __str__

    def SetInputs(self, ins):
        """Store the inputs, converting plain values to implicit parameters."""
        converted = []
        for i in ins:
            converted.append(ImplicitParameter.ImplicitConvertion(i))
        self.ins = converted
        return self

    def SetOutputs(self, outs):
        """Store a list copy of `outs`."""
        self.outs = [o for o in outs]
        return self
449
450# Main interface
class Model:
    """Main interface: a graph of operations over operands.

    Most methods return self so that calls can be chained. Every created
    model is also appended to the class-level `models` list.
    """
    models = list()

    def __init__(self, name=None):
        self.name = name
        self.operations = []
        self.operands = []
        self.isRelaxed = False
        self.compiled = False
        self.dumped = False
        self.version = FileNames.version
        self.referenced_models = None
        Model.models.append(self)

    def AddOperand(self, operand):
        """Append `operand` unless an equal one is already in the model."""
        if operand in self.operands:
            return self
        self.operands.append(operand)
        return self

    # Makes sure the model contains all (and only) the given inputs in the
    # specified order.
    def IdentifyInputs(self, *args):
        for operand in args:
            self.AddOperand(operand)
        inputs = tuple(self.GetInputs())
        assert inputs == args, '{} vs {}'.format(inputs, args)
        return self

    # Makes sure the model contains all (and only) the given outputs in the
    # specified order.
    def IdentifyOutputs(self, *args):
        for operand in args:
            self.AddOperand(operand)
        outputs = tuple(self.GetOutputs())
        assert outputs == args, '{} vs {}'.format(outputs, args)
        return self

    def AddOperation(self, operation):
        """Append `operation` and register its input and output operands."""
        self.operations.append(operation)
        for operand in operation.ins:
            self.AddOperand(operand)
        for operand in operation.outs:
            self.AddOperand(operand)
        return self

    def Operation(self, op_name, *args):
        """Add an operation with inputs `args`; outputs are set later by To()."""
        newOperation = Operation(op_name, args, [])
        return self.AddOperation(newOperation)

    def To(self, *args):
        """Set the outputs of the most recently added operation.

        Outputs may be passed as separate arguments or as one list/tuple.
        """
        assert len(self.operations) > 0
        first = args[0]
        outs = first if type(first) is tuple or type(first) is list else args
        self.operations[-1].SetOutputs(outs)
        for operand in outs:
            self.AddOperand(operand)
        return self

    def RelaxedExecution(self, isRelaxed):
        self.isRelaxed = isRelaxed
        return self

    # Sets the version of the model in compliance tests. Set to None to disable the test.
    def IntroducedIn(self, ver):
        self.version = ver
        return self

    def GetTypes(self):
        """Return the distinct operand types, sorted."""
        return sorted({op.type for op in self.operands})

    def GetInputs(self):
        return [op for op in self.operands if isinstance(op, Input)]

    def GetOutputs(self):
        return [op for op in self.operands if isinstance(op, Output)]

    def GetInputsIndex(self):
        """Return the positions in `operands` of the Input operands."""
        return [pos for pos, op in enumerate(self.operands) if isinstance(op, Input)]

    def GetOutputsIndex(self):
        """Return the positions in `operands` of the Output operands."""
        return [pos for pos, op in enumerate(self.operands) if isinstance(op, Output)]

    def GetIndexOfOperands(self, operands):
        return [self.operands.index(op) for op in operands]

    def GetIgnoredOutputs(self):
        return [op for op in self.operands if isinstance(op, IgnoredOutput)]

    def GetParameters(self):
        return [op for op in self.operands if isinstance(op, Parameter)]

    def GetReferencedModels(self):
        assert self.compiled
        return self.referenced_models

    def GetEquivalentOperands(self, targets):
        """Return the model's own operands that compare equal to `targets`."""
        found = []
        for target in targets:
            found.append(self.operands[self.operands.index(target)])
        return found

    def UpdateEquivalentOperands(self, targets):
        """Replace each model operand equal to a target with that target."""
        for target in targets:
            position = self.operands.index(target)
            self.operands[position] = target
        return self

    def SetOperandIndex(self):
        """Number inputs and outputs (`index`) and all operands (`model_index`)."""
        for position, operand in enumerate(self.GetInputs()):
            operand.index = position
        for position, operand in enumerate(self.GetOutputs()):
            operand.index = position
        for position, operand in enumerate(self.operands):
            operand.model_index = position
        return self

    def SetOperandInsAndOuts(self):
        """Rebuild the operand <-> operation links.

        Each operation's ins/outs are replaced by the model's own equivalent
        operands. Every operand then lists in `outs` the operations reading
        it and in `ins` the operations writing it.
        """
        for operand in self.operands:
            operand.ins = list()
            operand.outs = list()
        for operation in self.operations:
            operation.ins = self.GetEquivalentOperands(operation.ins)
            operation.outs = self.GetEquivalentOperands(operation.outs)
            for consumed in operation.ins:
                consumed.outs.append(operation)
            for produced in operation.outs:
                produced.ins.append(operation)
        return self

    def TopologicalSortHelper(self, op, deps, visited):
        """Append `op` to `operations` after everything it depends on.

        An operation is removed from `deps` once it has been appended, so a
        visited operation that is still in `deps` indicates a cycle.
        """
        if op in visited:
            assert op not in deps, "Cycle detected in the graph"
            return
        visited.add(op)
        for dependency in deps[op]:
            self.TopologicalSortHelper(dependency, deps, visited)
        self.operations.append(op)
        deps.pop(op)

    # Topological sort of the operations, and detect if there is a cycle in the graph
    def TopologicalSort(self):
        deps = {op: list() for op in self.operations}
        # an operation reading an operand depends on every operation writing it
        for operand in self.operands:
            for consumer in operand.outs:
                for producer in operand.ins:
                    deps[consumer].append(producer)
        pending = self.operations.copy()
        self.operations = []
        visited = set()
        for op in pending:
            self.TopologicalSortHelper(op, deps, visited)

    def CompileReferencedModels(self, referenced_models, referenced_model_to_index):
        """Compile the models referenced by SUBGRAPH operands.

        Each newly seen model is appended to `referenced_models` and compiled.
        The operand's value is then replaced by the model's index in that list.
        """
        for operand in self.operands:
            if operand.lifetime == "SUBGRAPH":
                model = operand.value[0]
                key = id(model)
                if key not in referenced_model_to_index:
                    referenced_model_to_index[key] = len(referenced_model_to_index)
                    referenced_models.append(model)
                    model.Compile(referenced_models, referenced_model_to_index)
                operand.value = [referenced_model_to_index[key]]

    def Compile(self, referenced_models=None, referenced_model_to_index=None):
        """Index the operands, link and sort the operations, compile subgraphs.

        Does nothing when the model is already compiled. A call without
        `referenced_models` is treated as the main model and owns the list of
        referenced models.
        """
        if self.compiled:
            return self
        if referenced_models is None:
            # This is the main model.
            referenced_models = []
            referenced_model_to_index = {}
            self.referenced_models = referenced_models
        self.SetOperandIndex()
        self.SetOperandInsAndOuts()
        self.TopologicalSort()
        self.CompileReferencedModels(referenced_models, referenced_model_to_index)
        # Do not check compliance for relaxed mode tests.
        if self.isRelaxed:
            self.IntroducedIn(None)
        self.compiled = True
        return self

    def Feed(self, feedDict):
        """Feed inputs from feedDict[0] and outputs from feedDict[1]."""
        inputValues = feedDict[0]
        for operand in self.GetInputs():
            operand.Feed(inputValues)
        outputValues = feedDict[1]
        for operand in self.GetOutputs():
            operand.Feed(outputValues)
        return self
633
634# To track implicitly convertible variation types
class ImplicitVariation:
    """Marker base for variation classes that plain values convert to.

    Each direct subclass provides IsCompatible(value) and a one-argument
    constructor.
    """

    @staticmethod
    def ImplicitConvertion(value):
        """Return `value` as a ModelVariation.

        ModelVariation instances are returned unchanged. Otherwise `value` is
        treated as a list or tuple (a single value is wrapped in a list).
        Element 0 selects the first direct subclass whose IsCompatible()
        accepts it, and any remaining elements are passed to the new
        variation's Identify(). Fails an assertion when no subclass is
        compatible.
        """
        if isinstance(value, ModelVariation):
            return value
        for implicitType in ImplicitVariation.__subclasses__():
            if type(value) is not tuple and type(value) is not list:
                value = [value]
            if not implicitType.IsCompatible(value[0]):
                continue
            var = implicitType(value[0])
            if len(value) > 1:
                var.Identify(*value[1:])
            return var
        assert False, "%s not supported for implicit variation"%value[0]
648
649# An exception indicating that the current variation list should be skipped.
class SkipVariation(Exception):
    """Raised to indicate that the current variation list should be skipped."""
652
653# The base class for model variations
class ModelVariation:
    """Base class for model variations.

    Subclasses customise a variation by overriding AutoIdentify(),
    TransformOperand(), TransformModel() and SetToDefaultName(). The base
    versions leave the model and its operands unchanged.
    """
    # Whether ApplyTo() accepts models containing SUBGRAPH operands
    supportsSubgraphs = False

    def __init__(self, name=None):
        self.name = name
        # maps each target operand to the argument given to TransformOperand()
        self.targetOperands = {}

    # Apply the model variation.
    def ApplyTo(self, model):
        """Transform the target operands, then the model; return the model.

        The model must be neither compiled nor dumped. When no target
        operands were identified explicitly, AutoIdentify() runs first.
        """
        assert not model.compiled
        assert not model.dumped

        if not self.supportsSubgraphs:
            for operand in model.operands:
                assert operand.lifetime != "SUBGRAPH", \
                    "Variation {} does not support subgraphs".format(self.__class__.__name__)

        if not self.targetOperands:
            self.AutoIdentify(model)

        # Transform operands and model.
        targets = model.GetEquivalentOperands(sorted(self.targetOperands))
        transformed = []
        for op in targets:
            transformed.append(self.TransformOperand(op, self.targetOperands[op]))
        model.UpdateEquivalentOperands(transformed)
        return self.TransformModel(model)

    def IdentifyOperands(self, args=None):
        """Set the target operands from a dict, or from an iterable of operands.

        An iterable maps every operand to None; None leaves the targets as is.
        """
        if args is None:
            return self
        if type(args) is dict:
            self.targetOperands = args
        else:
            self.targetOperands = {i: None for i in args}
        return self

    def Identify(self, operandArgs=None, paramArgs=None):
        """Set the target operands; `paramArgs` is not used here."""
        self.IdentifyOperands(operandArgs)
        return self

    # Set variation to its default name
    def SetToDefaultName(self):
        self.name = ""
        return self

    # Automatically select the target operand list
    def AutoIdentify(self, model):
        return self

    # Transform operands that are marked by IdentifyOperands()
    def TransformOperand(self, op, arg=None):
        return op

    # Transform the model
    def TransformModel(self, model):
        return model
707
708# Default variation that does nothing
class DefaultVariation(ModelVariation):
    """Variation that changes nothing; models with subgraphs are accepted."""
    supportsSubgraphs = True

    def __init__(self, name=None):
        ModelVariation.__init__(self, name=name)
714
715# Convert operand data type
class DataTypeConverter(ModelVariation, ImplicitVariation):
    """Variation that converts operand data types.

    Either pass a `targetType` ("float16", "int32", "quant8" or
    "quant8_signed", in any letter case) and let AutoIdentify() pick the
    operands, or leave it None and identify the operands explicitly.
    """
    supportsSubgraphs = True

    def __init__(self, targetType=None, name=None, scale=None, zeroPoint=None):
        """Create the converter.

        `scale` and `zeroPoint` must be given for the quant8 target types;
        AutoIdentify() asserts this.
        """
        ModelVariation.__init__(self, name=name)
        if targetType is not None:
            assert DataTypeConverter.IsCompatible(targetType)
            # IsCompatible() ignores letter case, but AutoIdentify() compares
            # against lower-case names. Without this, "Quant8" would take the
            # generic branch there and yield a "TENSOR_QUANT8" type.
            targetType = targetType.lower()
        self.targetType = targetType
        self.scale = scale
        self.zeroPoint = zeroPoint

    @staticmethod
    def IsCompatible(value):
        """Return whether `value` names a supported target type (case-insensitive)."""
        return value.lower() in ["float16", "int32", "quant8", "quant8_signed"]

    def SetToDefaultName(self):
        """Name the variation after its target type.

        Without a target type, the name is derived from the type names in the
        explicitly identified operand arguments and defaults to "float32".
        """
        if self.targetType is not None:
            self.name = self.targetType.lower()
            return self
        # first element (the type name) of every non-nested target argument
        targetTypes = list(zip(*(arg for arg in self.targetOperands.values()
                                 if type(arg) is not DataTypeConverter)))[0]
        if "TENSOR_QUANT8_SYMM_PER_CHANNEL" in targetTypes:
            self.name = "channelQuant8"
        elif "TENSOR_QUANT8_ASYMM" in targetTypes:
            self.name = "quant8"
        elif "TENSOR_QUANT8_ASYMM_SIGNED" in targetTypes:
            self.name = "quant8_signed"
        elif "TENSOR_INT32" in targetTypes:
            self.name = "int32"
        elif "TENSOR_FLOAT16" in targetTypes:
            self.name = "float16"
        else:
            self.name = "float32"
        return self

    def AutoIdentify(self, model):
        """Select the operands to convert when a target type was given.

        TENSOR_FLOAT32 operands are mapped to the target tensor type, FLOAT32
        scalars to the target scalar type (not for the quant8 targets), and
        SUBGRAPH operands to a nested converter with the same settings.
        """
        if self.targetType is not None:
            if self.targetType == "quant8" or self.targetType == "quant8_signed":
                if self.targetType == "quant8":
                    tensorType = "TENSOR_QUANT8_ASYMM"
                else:
                    tensorType = "TENSOR_QUANT8_ASYMM_SIGNED"
                assert self.scale is not None
                assert self.zeroPoint is not None
                tensorType = [tensorType, self.scale, self.zeroPoint]
                scalarType = None  # Not supported.
            else:
                tensorType = ["TENSOR_" + self.targetType.upper()]
                scalarType = [self.targetType.upper()]
            # By default, select all the float32 tensors/scalars
            targets = dict()
            targets.update({op: DataTypeConverter(self.targetType, self.name,
                                                  self.scale, self.zeroPoint)
                            for op in model.operands if op.type.type == "SUBGRAPH"})
            targets.update({op: tensorType
                            for op in model.operands if op.type.type == "TENSOR_FLOAT32"})
            if scalarType is not None:
                targets.update({op: scalarType
                                for op in model.operands if op.type.type == "FLOAT32"})
            self.Identify(targets)
        return self

    def TransformOperand(self, op, arg=None):
        """Convert `op` as described by `arg` and return it.

        `arg` is either a nested DataTypeConverter, which is applied to the
        model referenced by a SUBGRAPH operand, or a sequence whose first
        item is the new type name and whose remaining items are passed on to
        Type.GetType() after the dimensions. Existing values are dequantized
        from the old type and quantized to the new one.
        """
        if type(arg) is DataTypeConverter:
            # Handle nested SUBGRAPHs
            assert len(op.value) == 1
            assert type(op.value[0]) is Model
            op.value[0] = arg.ApplyTo(op.value[0])
            return op
        typeTuple = (arg[0], op.type.dimensions, *arg[1:])
        # To handle Internal operands
        if op.value is None or op.type.GetNumberOfElements() == 0:
            op.type = Type.GetType(*typeTuple)
        else:
            v = Dequantize(op.GetValueAsNumpy().astype(np.float32), op.type)
            op.type = Type.GetType(*typeTuple)
            v = Quantize(v, op.type)
            op.SetValueFromNumpy(v)
        return op
798
799# Convert model to turn on/off relaxed computation
class RelaxedModeConverter(ModelVariation, ImplicitVariation):
    """Variation that turns relaxed computation on or off for a model."""
    supportsSubgraphs = True

    def __init__(self, isRelaxed=True, name=None):
        """`isRelaxed` is a bool, or the string "relaxed" in any letter case."""
        ModelVariation.__init__(self, name=name)
        if not isinstance(isRelaxed, bool):
            assert RelaxedModeConverter.IsCompatible(isRelaxed.lower())
            isRelaxed = True
        self.isRelaxed = isRelaxed

    @staticmethod
    def IsCompatible(value):
        return value.lower() in ["relaxed"]

    def SetToDefaultName(self):
        if self.isRelaxed:
            self.name = "relaxed"
        else:
            self.name = "float"
        return self

    def TransformModel(self, model):
        model.RelaxedExecution(self.isRelaxed)
        return model
822
# Convert data layout between "NHWC" and "NCHW"
class DataLayoutConverter(ModelVariation, ImplicitVariation):
    """Model variation that converts identified operands to "nchw" or "nhwc".

    4-D operands are transposed, 1-D operands holding four values are
    reordered with the same permutation, and BOOL operands are set to True
    for "nchw" and False for "nhwc".
    """

    def __init__(self, targetLayout="nchw", name=None):
        ModelVariation.__init__(self, name=name)
        self.targetLayout = targetLayout.lower()
        assert DataLayoutConverter.IsCompatible(self.targetLayout)
        toNchw = self.targetLayout == "nchw"
        # Axis permutation applied to the data and dimensions
        self.perm = (0, 3, 1, 2) if toNchw else (0, 2, 3, 1)
        # Value written into BOOL operands
        self.param = toNchw

    @staticmethod
    def IsCompatible(value):
        """Return True when value, lower-cased, is "nhwc" or "nchw"."""
        return value.lower() in ("nhwc", "nchw")

    def SetToDefaultName(self):
        """Use the target layout as the variation name and return self."""
        self.name = self.targetLayout
        return self

    def TransformOperand(self, op, arg=None):
        """Convert one operand to the target layout and return it."""
        rank = len(op.type.dimensions)
        if rank == 4:
            # Operands without data (e.g. Internal operands) only get new
            # dimensions.
            if op.value is not None and op.type.GetNumberOfElements() != 0:
                op.SetValueFromNumpy(op.GetValueAsNumpy().transpose(self.perm))
            permutedDims = [op.type.dimensions[axis] for axis in self.perm]
            op.type = Type.GetType(op.type.type, permutedDims, op.type.scale, op.type.zeroPoint)
            return op
        if rank == 1 and len(op.value) == 4:
            # A vector with one entry per axis: reorder the entries.
            op.SetValueFromNumpy(op.GetValueAsNumpy()[list(self.perm)])
            return op
        if op.type.type == "BOOL":
            op.SetValue(self.param)
            return op
        assert False, "%s not supported by DataLayoutConverter"%op
        return op
855
# Convert data by transposing and removing axes
class AxisConverter(ModelVariation):
    """Model variation that moves one axis and optionally removes others.

    INT32 operands are treated as the axis parameter and receive the new
    axis index. Operands whose rank equals dim have their data and
    dimensions transposed first and the dropped axes removed afterwards.
    Any other operand is rejected with an assertion.
    """

    def __init__(self, origin, target, dim, drop=(), name=None):
        """Set up the permutation and the list of axes to drop.

        Args:
            origin: index of the axis to move.
            target: index the axis is moved to; may be negative.
            dim: rank of the tensors being converted.
            drop: an int or an iterable of ints naming the axes to remove.
                Indices refer to the layout after the transpose, because
                RemoveAxis runs after TransposeAxis. Negative values are
                normalized. The default is an immutable empty tuple so no
                mutable object is shared between calls.
            name: optional variation name.
        """
        ModelVariation.__init__(self, name=name)
        self.origin = origin
        self.target = target
        assert all(i >= -dim and i < dim for i in [self.origin, self.target])
        self.dim = dim
        # Permutation that takes axis `origin` and re-inserts it at `target`
        self.perm = list(range(dim))
        self.perm.insert(target if target >= 0 else target + dim, self.perm.pop(origin))
        self.drop = [drop] if type(drop) is int else list(drop)
        assert all(i >= -dim and i < dim for i in self.drop)
        # Store dropped axes as non-negative indices
        self.drop = [i if i >= 0 else i + dim for i in self.drop]
        # The target axis itself must survive
        assert target not in self.drop and target + dim not in self.drop

    def SetToDefaultName(self):
        """Name the variation "dim<rank>_axis<axis>[_neg]" and return self.

        Rank and axis are the values after the dropped axes are removed;
        "_neg" is appended when the target was given as a negative index.
        """
        axis = self.target if self.target >= 0 else self.target + self.dim
        axis -= sum(i < axis for i in self.drop)
        neg = "" if self.target >= 0 else "_neg"
        self.name = "dim%d_axis%d%s"%(self.dim - len(self.drop), axis, neg)
        return self

    def TransposeAxis(self, op):
        """Apply the axis move to one operand and return it."""
        if op.type.type == "INT32":
            # The axis parameter now points at the target position
            op.SetValue(self.target)
        elif len(op.type.dimensions) == self.dim:
            # To handle Internal operands
            if op.value is not None:
                op.SetValueFromNumpy(op.GetValueAsNumpy().transpose(self.perm))
            newDim = [op.type.dimensions[i] for i in self.perm]
            op.type = Type.GetType(op.type.type, newDim, op.type.scale, op.type.zeroPoint)
        else:
            assert False, "%s not supported by AxisConverter"%op
        return op

    def RemoveAxis(self, op):
        """Remove the dropped axes from one operand and return it."""
        if op.type.type == "INT32":
            if op.value[0] >= 0:
                # Shift down by the number of dropped axes in front of it
                op.SetValue(op.value[0] - sum(i < op.value[0] for i in self.drop))
            else:
                # A negative index counts from the end, so it moves up by the
                # number of dropped axes behind it
                op.SetValue(op.value[0] + sum(i > (op.value[0] + self.dim) for i in self.drop))
        elif len(op.type.dimensions) == self.dim:
            if op.value is not None:
                val = op.GetValueAsNumpy()
                # Highest axis first so the remaining indices stay valid;
                # index 0 is kept along each dropped axis
                for i in sorted(self.drop, reverse=True):
                    val = np.take(val, 0, axis=i)
                op.SetValueFromNumpy(val)
            newDim = [op.type.dimensions[i] for i in range(self.dim) if i not in self.drop]
            op.type = Type.GetType(op.type.type, newDim, op.type.scale, op.type.zeroPoint)
        else:
            assert False, "%s not supported by AxisConverter"%op
        return op

    def TransformOperand(self, op, arg=None):
        """Transpose the operand, then remove the dropped axes; return it."""
        op = self.TransposeAxis(op)
        op = self.RemoveAxis(op)
        return op
914
915# Convert Output based on activation
class ActivationConverter(ModelVariation, ImplicitVariation):
    """Model variation that applies a fused activation to identified operands.

    INT32 operands receive the activation enum; every other operand must be
    an Output and has its value clamped to the activation's range.
    """
    # Activation name -> (Enum, low, high); a bound of None is not applied
    actMap = {
        "none": (0, None, None),
        "relu": (1, 0.0, None),
        "relu1": (2, -1.0, 1.0),
        "relu6": (3, 0.0, 6.0),
    }

    def __init__(self, act="relu", name=None):
        ModelVariation.__init__(self, name=name)
        self.act = act.lower()
        assert ActivationConverter.IsCompatible(self.act)
        self.enum, self.low, self.high = ActivationConverter.actMap[self.act]

    @staticmethod
    def IsCompatible(value):
        """Return True when value, lower-cased, is a key of actMap."""
        return value.lower() in ActivationConverter.actMap

    def SetToDefaultName(self):
        """Use the activation name as the variation name and return self."""
        self.name = self.act
        return self

    def TransformOperand(self, op, arg=None):
        """Set the activation enum or clamp an output; return the setter result."""
        if op.type.type == "INT32": # activation enum
            return op.SetValue(self.enum)

        assert isinstance(op, Output)
        clamped = op.GetValueAsNumpy()
        # Bounds are converted into the operand's own type before clamping
        if self.low is not None:
            clamped = np.maximum(clamped, Quantize(self.low, op.type))
        if self.high is not None:
            clamped = np.minimum(clamped, Quantize(self.high, op.type))
        return op.SetValueFromNumpy(clamped)
953
954# Convert all constant tensors as model inputs.
class AllTensorsAsInputsConverter(ModelVariation):
    """Model variation that turns every constant tensor into a model input.

    Only applies to models with exactly one operation that have at least one
    non-scalar Parameter with a value; otherwise SkipVariation is raised.
    """
    supportsSubgraphs = True

    def __init__(self, name=None):
        ModelVariation.__init__(self, name=name)

    def SetToDefaultName(self):
        """Set the name to "all_tensors_as_inputs" and return self."""
        self.name = "all_tensors_as_inputs"
        return self

    def TransformModel(self, model):
        """Replace the model's constant tensors with Inputs and return it."""
        if len(model.operations) != 1:
            raise SkipVariation

        # Collect the constant tensors: Parameters that are not scalars and
        # actually carry a value.
        constantTensors = []
        for operand in model.operands:
            if type(operand) is not Parameter:
                continue
            if operand.type.IsScalar() or operand.value is None:
                continue
            constantTensors.append(operand)
        if not constantTensors:
            raise SkipVariation

        # Swap each one for an equivalent Input.
        asInputs = [operand.ConvertTo(Input) for operand in constantTensors]
        model.UpdateEquivalentOperands(asInputs)
        return model
980
def CompatibleWithADD(op):
    """Return True when op can take part in a placeholder ADD operation.

    The operand must have at most four dimensions, a non-empty value, and one
    of the float or 8-bit asymmetric quantized tensor types listed below.
    """
    if len(op.type.dimensions) > 4:
        return False
    if len(op.value) <= 0:
        return False
    supportedTypes = ("TENSOR_FLOAT32", "TENSOR_QUANT8_ASYMM",
                      "TENSOR_FLOAT16", "TENSOR_QUANT8_ASYMM_SIGNED")
    return op.type.type in supportedTypes
986
987# Add a placeholder ADD operation before each model input to make it as an internal operand.
class AllInputsAsInternalCoverter(ModelVariation):
    """Model variation that makes eligible model inputs internal operands.

    Each eligible input becomes the output of a new ADD operation fed by a
    fresh input (named "<name>_new") and a one-element placeholder parameter
    holding the operand's zero point. Only applies to single-operation
    models with at least one eligible input; otherwise SkipVariation is
    raised.
    """
    supportsSubgraphs = True

    def __init__(self, name=None):
        ModelVariation.__init__(self, name=name)

    def SetToDefaultName(self):
        """Set the name to "all_inputs_as_internal" and return self."""
        self.name = "all_inputs_as_internal"
        return self

    def TransformModel(self, model):
        """Insert the placeholder ADD operations and return the model."""
        if len(model.operations) != 1:
            raise SkipVariation

        # Inputs that an ADD can produce and that are allowed to be internal.
        eligible = [operand for operand in model.GetInputs()
                    if CompatibleWithADD(operand) and operand.mayBeInternal]
        if not eligible:
            raise SkipVariation

        # input_new ADD placeholder = input
        for operand in eligible:
            replacementInput = operand.ConvertTo(Input, name=operand.name + "_new")
            placeholderType = (operand.type.type, [1], operand.type.scale, operand.type.zeroPoint)
            placeholder = Parameter("placeholder", placeholderType, [operand.type.zeroPoint])
            model.Operation("ADD", replacementInput, placeholder, 0).To(operand)

        # The original inputs are now produced inside the model.
        asInternal = [operand.ConvertTo(Internal) for operand in eligible]
        model.UpdateEquivalentOperands(asInternal)
        return model
1018
1019# Add a placeholder ADD operation after each model output to make it as an internal operand.
class AllOutputsAsInternalCoverter(ModelVariation):
    """Model variation that makes eligible model outputs internal operands.

    Each eligible output is fed, together with a one-element placeholder
    parameter holding the operand's zero point, into a new ADD operation
    whose result is a fresh output named "<name>_new". Only applies to
    single-operation models with at least one eligible output; otherwise
    SkipVariation is raised.
    """
    supportsSubgraphs = True

    def __init__(self, name=None):
        ModelVariation.__init__(self, name=name)

    def SetToDefaultName(self):
        """Set the name to "all_outputs_as_internal" and return self."""
        self.name = "all_outputs_as_internal"
        return self

    def TransformModel(self, model):
        """Append the placeholder ADD operations and return the model."""
        if len(model.operations) != 1:
            raise SkipVariation

        # Outputs that can serve as an input of an ADD operation.
        eligible = [operand for operand in model.GetOutputs() if CompatibleWithADD(operand)]
        if not eligible:
            raise SkipVariation

        # output ADD placeholder = output_new
        for operand in eligible:
            replacementOutput = operand.ConvertTo(Output, name=operand.name + "_new")
            placeholderType = (operand.type.type, [1], operand.type.scale, operand.type.zeroPoint)
            placeholder = Parameter("placeholder", placeholderType, [operand.type.zeroPoint])
            model.Operation("ADD", operand, placeholder, 0).To(replacementOutput)

        # The original outputs are now consumed inside the model.
        asInternal = [operand.ConvertTo(Internal) for operand in eligible]
        model.UpdateEquivalentOperands(asInternal)
        return model
1050
1051# An example is always attached to a model, and could have multiple variations
class Example:
    """Feed data attached to one model, together with the variations to test.

    Every instance adds itself to Example.examples when it is constructed.
    Most methods return self so calls can be chained.
    """
    # Every example created so far; rebuilt by CombineAllExamples
    examples = []
    # Test name -> version, filled in by SetVersion and consulted in Dump
    versionOverrides = {}

    def __init__(self, *args, model=None, name=None):
        # Without an explicit model, attach to the last entry of Model.models
        self.model = Model.models[-1] if model is None else model
        self.name = name
        self.expectedMultinomialDistributionTolerance = 0
        self.expectFailure = False
        self.testLifeTimeVariation = True
        self.feedDicts = []
        for feed in args:
            feedType = type(feed)
            if feedType is tuple or feedType is list:
                # Stored exactly as given
                self.feedDicts.append(feed)
            elif feedType is dict:
                # One flat dict is split into an (inputs, outputs) pair
                inputFeed = {i: feed[i] for i in self.model.GetInputs()}
                outputFeed = {o: feed[o] for o in self.model.GetOutputs()}
                self.feedDicts.append((inputFeed, outputFeed))
            else:
                assert False
        self.variations = []
        Example.examples.append(self)

    @staticmethod
    def SetVersion(ver, *args):
        """Record ver as the version override for each test name in args."""
        Example.versionOverrides.update((testName, ver) for testName in args)

    # Main entrance of test generator
    @staticmethod
    def DumpAllExamples(DumpModel=None, model_fd=None,
                        DumpExample=None, example_fd=None,
                        DumpTest=None, test_fd=None):
        """Combine equivalent examples, then dump each remaining one."""
        Example.CombineAllExamples()
        dumpArgs = (DumpModel, model_fd, DumpExample, example_fd, DumpTest, test_fd)
        for example in Example.examples:
            example.Dump(*dumpArgs)

    # Combine examples with the same model, same name, and same set of variations
    @staticmethod
    def CombineAllExamples():
        """Merge duplicates into their first occurrence, keeping list order."""
        firstSeen = {}
        survivors = []
        for example in Example.examples:
            variationKey = tuple(tuple(group) for group in example.variations)
            key = (example.model, example.name, variationKey)
            earlier = firstSeen.get(key)
            if earlier is None:
                firstSeen[key] = example
                survivors.append(example)
            else:
                earlier.Combine(example)
        Example.examples = survivors

    def AddVariations(self, *args, includeDefault=True, defaultName=None):
        """Append one group of variations; args go through ImplicitConvertion."""
        group = [DefaultVariation(defaultName)] if includeDefault else []
        self.variations.append(group)
        group.extend(ImplicitVariation.ImplicitConvertion(item) for item in args)
        return self

    def AddNchw(self, *args, includeDefault=True, defaultName="nhwc"):
        """Add an "nchw" layout variation for the operands in args."""
        self.AddVariations(DataLayoutConverter("nchw").Identify(args),
                           includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddRelaxed(self, isRelaxed=True, includeDefault=True, defaultName=None):
        """Add a relaxed-computation variation."""
        self.AddVariations(RelaxedModeConverter(isRelaxed),
                           includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddRelu(self, *args, includeDefault=True, defaultName=None):
        """Add a "relu" activation variation for the operands in args."""
        self.AddVariations(ActivationConverter("relu").Identify(args),
                           includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddAllActivations(self, *args):
        """Add one variation per entry of ActivationConverter.actMap, sorted by name."""
        converters = []
        for activation in sorted(ActivationConverter.actMap):
            converters.append(ActivationConverter(activation).Identify(args))
        self.AddVariations(*converters, includeDefault=False)
        return self

    def GuessOriginalAxisAndDim(self, *args):
        """Derive (axis, rank) from the given operands.

        An INT32 operand supplies the axis through its first value; all other
        operands must share one rank. With no INT32 operand the last axis is
        used. The returned axis is never negative.
        """
        axis, rank = None, None
        for operand in args:
            if operand.type.type == "INT32":
                axis = operand.value[0]
                continue
            operandRank = len(operand.type.dimensions)
            if rank is None:
                rank = operandRank
            else:
                assert rank == operandRank
        assert rank is not None
        if axis is None:
            axis = rank - 1
        if axis < 0:
            axis = axis + rank
        return axis, rank

    def AddAxis(self, axis, *args, includeDefault=True, defaultName=None):
        """Add variations that move the original axis to each axis given."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        targets = [axis] if type(axis) is int else list(axis)
        converters = [AxisConverter(origin, target, dim).Identify(args) for target in targets]
        self.AddVariations(*converters, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddAllPositiveAxis(self, *args):
        """Add variations for every non-negative target axis."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        converters = []
        for target in range(dim):
            converters.append(AxisConverter(origin, target, dim).Identify(args))
        self.AddVariations(*converters, includeDefault=False)
        return self

    def AddAllAxis(self, *args):
        """Add variations for every target axis from -dim up to dim - 1."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        converters = []
        for target in range(-dim, dim):
            converters.append(AxisConverter(origin, target, dim).Identify(args))
        self.AddVariations(*converters, includeDefault=False)
        return self

    def AddDims(self, dims, *args, includeDefault=True, defaultName=None):
        """Add variations that reduce the operands to each rank in dims."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        ranks = [dims] if type(dims) is int else list(dims)
        # Every axis except the original one may be dropped
        droppable = list(range(dim))
        del droppable[origin]
        converters = [AxisConverter(origin, origin, dim, droppable[0:(dim-rank)]).Identify(args)
                      for rank in ranks]
        self.AddVariations(*converters, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddAllDims(self, *args):
        """Add variations that drop 0, 1, ... dim - 1 of the other axes."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        droppable = list(range(dim))
        del droppable[origin]
        converters = []
        for numDropped in range(dim):
            converters.append(
                AxisConverter(origin, origin, dim, droppable[0:numDropped]).Identify(args))
        self.AddVariations(*converters, includeDefault=False)
        return self

    def AddAllDimsAndPositiveAxis(self, *args):
        """Combine dropping leading axes with every remaining positive target."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        converters = []
        for numDropped in range(dim):
            for target in range(numDropped, dim):
                converters.append(
                    AxisConverter(origin, target, dim, range(numDropped)).Identify(args))
        self.AddVariations(*converters, includeDefault=False)
        return self

    def AddAllDimsAndAxis(self, *args):
        """Like AddAllDimsAndPositiveAxis, also using each target's negative form."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        converters = []
        for numDropped in range(dim):
            for positive in range(numDropped, dim):
                for target in (positive, positive - dim):
                    converters.append(
                        AxisConverter(origin, target, dim, range(numDropped)).Identify(args))
        self.AddVariations(*converters, includeDefault=False)
        return self

    def Combine(self, other):
        """Take over the feeds of an equivalent example and return self."""
        assert self.model is other.model, "Only examples targetting the same model can be combined"
        assert tuple(self.variations) == tuple(other.variations), \
            "Only examples with the same set of variations can be combined"
        assert self.name == other.name, "Only examples with the same name can be combined"
        self.feedDicts += other.feedDicts
        return self

    def Dump(self, DumpModel, model_fd, DumpExample, example_fd, DumpTest, test_fd):
        """Emit this example once per feed and per combination of variations.

        DumpModel and model_fd are accepted but not used here. DumpExample and
        DumpTest are each called with (self, fd) when both the callable and
        its fd are not None.
        """
        # Single-operation models also get the operand life-time variations,
        # unless disabled or a multinomial tolerance is set.
        if (self.testLifeTimeVariation and len(self.model.operations) == 1 and
                self.expectedMultinomialDistributionTolerance == 0):
            self.AddVariations(AllTensorsAsInputsConverter())
            self.AddVariations(AllInputsAsInternalCoverter())
        for group in self.variations:
            for variation in group:
                if variation.name is None:
                    variation.SetToDefaultName()

        for feed in self.feedDicts:
            self.model.Feed(feed)
            for combination in itertools.product(*self.variations):
                # Variations run on a deep copy so the original stays intact
                pristineModel = self.model
                self.model = copy.deepcopy(self.model)

                try:
                    for variation in combination:
                        self.model = variation.ApplyTo(self.model)
                except SkipVariation:
                    self.model = pristineModel
                    continue

                # Test and example names are built from the variation names
                variationNames = [variation.name for variation in combination]
                self.testName = NamedTest(FileNames.specName, self.model.name, self.name,
                                          *variationNames)
                self.examplesName = GlobalVariable("test_model", self.model.name, self.name,
                                                   *variationNames)
                if str(self.testName) in Example.versionOverrides:
                    self.model.IntroducedIn(Example.versionOverrides[str(self.testName)])
                self.model.Compile()

                if DumpExample is not None and example_fd is not None:
                    DumpExample(self, example_fd)
                if DumpTest is not None and test_fd is not None:
                    DumpTest(self, test_fd)

                # Back to the model as it was before the variations
                self.model = pristineModel
        return self

    # Specifies the RANDOM_MULTINOMIAL distribution tolerance.
    # If set to greater than zero, the input is compared as log-probabilities
    # to the output and must be within this tolerance to pass.
    def WithMultinomialDistributionTolerance(self, expectedTolerance):
        assert self.expectFailure is False
        self.expectedMultinomialDistributionTolerance = expectedTolerance
        return self

    # Specifies that this example is expected to fail during compilation or execution.
    def ExpectFailure(self):
        assert self.expectedMultinomialDistributionTolerance == 0
        self.expectFailure = True
        return self

    def DisableLifeTimeVariation(self):
        """Turn off the automatic life-time variations added by Dump."""
        self.testLifeTimeVariation = False
        return self
1261
class FileNames:
    """Class-level bookkeeping of the spec files and their example targets.

    The three lists are parallel (one entry per spec file); specFile,
    specName and exampleFile hold the entries selected by the last
    successful NextFile call.
    """
    specFiles = []
    specNames = []
    exampleFiles = []
    specFile = ""
    specName = ""
    exampleFile = ""
    version = ""
    fileIndex = 0

    @staticmethod
    def InitializeFileLists(spec, example):
        """Fill the file lists from a spec file or directory and a target.

        A directory contributes its "*.mod.py" entries in sorted order. Spec
        names are the base names cut at their first dot.
        """
        if os.path.isfile(spec):
            FileNames.specFiles = [os.path.abspath(spec)]
        elif os.path.isdir(spec):
            modEntries = (entry for entry in os.listdir(spec) if entry.endswith(".mod.py"))
            FileNames.specFiles = sorted(
                os.path.abspath(os.path.join(spec, entry)) for entry in modEntries)
        else:
            assert False, "%s is neither a file or a directory"%spec
        names = []
        for path in FileNames.specFiles:
            names.append(re.sub(r"\..*", "", os.path.basename(path)))
        FileNames.specNames = names
        FileNames.exampleFiles = FileNames.ParseTargetFiles(example, ".example.cpp")

    @staticmethod
    def ParseTargetFiles(arg, ext):
        """Return one target per spec file.

        None yields a list of None; a directory yields
        "<dir>/<specName><ext>" per spec; "-" is repeated as is; any other
        value is treated as a single file path shared by all specs.
        """
        count = len(FileNames.specFiles)
        if arg is None:
            return [None] * count
        absPath = os.path.abspath(arg)
        if os.path.isdir(arg):
            return [os.path.join(absPath, specName + ext) for specName in FileNames.specNames]
        shared = "-" if arg == "-" else absPath
        return [shared] * count

    @staticmethod
    def NextFile():
        """Select the next spec file and reset per-file state.

        Returns False when every spec file has been visited, True otherwise.
        """
        index = FileNames.fileIndex
        if index >= len(FileNames.specFiles):
            return False
        FileNames.specFile = FileNames.specFiles[index]
        FileNames.specName = FileNames.specNames[index]
        FileNames.exampleFile = FileNames.exampleFiles[index]
        FileNames.fileIndex = index + 1

        # Start the new spec with empty name registries, types, models and
        # examples, and with shared memory for weights switched off.
        NamedObject.existingNames = set()
        NamedVariable.existingNames = set()
        NamedTest.existingNames = set()
        Type.typesMap = {}
        Model.models = []
        Example.examples = []
        Configuration.use_shm_for_weights = False

        # The version comes from a single "/V<digit>_<digit>/" component of
        # the absolute spec path; anything else leaves it unset.
        found = re.findall(r"/V\d_\d/", FileNames.specFile)
        FileNames.version = found[0].strip('/') if len(found) == 1 else None
        return True
1323
class Configuration:
    """Class-level switches shared by the generator."""
    # Read through useSHM(); reset to False for every spec file
    use_shm_for_weights = False
    # Set from the --hook command line flag
    hook_mode = False

    @staticmethod
    def useSHM():
        """Return the current value of use_shm_for_weights."""
        useShm = Configuration.use_shm_for_weights
        return useShm
1331
def GetTestGeneratorMTime():
    """Return the newest modification time of the generator scripts.

    The scripts are looked up next to this module.
    """
    generatorDir = os.path.dirname(__file__)
    scriptPaths = [os.path.join(generatorDir, scriptName)
                   for scriptName in ('test_generator.py', 'example_generator.py')]
    return max(map(os.path.getmtime, scriptPaths))
1337
def MightNeedRegeneration():
    """Tell whether the current spec file has to be processed again.

    Returns:
        True when no example file is configured, when the example file does
        not exist, or when it is not newer than both the spec file and the
        generator scripts; False otherwise.
    """
    specTime = os.path.getmtime(FileNames.specFile)
    tgTime = GetTestGeneratorMTime()
    exampleFile = FileNames.exampleFile
    if exampleFile is None:
        # FileNames.ParseTargetFiles yields None entries when no example
        # target was given, and os.path.exists(None) raises TypeError. With
        # nothing to compare against, the spec is reported as needing
        # processing so it is still executed.
        return True
    return not os.path.exists(exampleFile) or \
           os.path.getmtime(exampleFile) <= max(specTime, tgTime)
1343
def Read(filename):
    """Return the whole content of filename, opened in text mode."""
    with open(filename) as source:
        contents = source.read()
    return contents
1347
def AtomicWrite(filename, data):
    """Write data to filename through a temporary "<filename>.tmp" file.

    The data is first written to the temporary file, which then replaces
    filename. If anything fails before the replacement, a leftover temporary
    file is removed.
    """
    # The temporary file sits next to the target because os.replace(src, dest)
    # may fail if src and dest are on different filesystems.
    scratch = filename + '.tmp'
    committed = False
    try:
        with open(scratch, 'w') as writer:
            writer.write(data)
        os.replace(scratch, filename)
        committed = True
    finally:
        if not committed and os.path.exists(scratch):
            os.remove(scratch)
1360
def GetExecScope():
    """Return a new dict of the names made available to executed spec files."""
    return {
        "ActivationConverter": ActivationConverter,
        "AllInputsAsInternalCoverter": AllInputsAsInternalCoverter,
        "AllOutputsAsInternalCoverter": AllOutputsAsInternalCoverter,
        "AllTensorsAsInputsConverter": AllTensorsAsInputsConverter,
        "BoolScalar": BoolScalar,
        "Configuration": Configuration,
        "DataLayoutConverter": DataLayoutConverter,
        "DataTypeConverter": DataTypeConverter,
        "Example": Example,
        "Float16Scalar": Float16Scalar,
        "Float32Scalar": Float32Scalar,
        "Float32Vector": Float32Vector,
        "IgnoredOutput": IgnoredOutput,
        "Input": Input,
        "Int32Scalar": Int32Scalar,
        "Int32Vector": Int32Vector,
        "Internal": Internal,
        "Model": Model,
        "Operand": Operand,
        "Output": Output,
        "Parameter": Parameter,
        "RelaxedModeConverter": RelaxedModeConverter,
        "SubgraphReference": SubgraphReference,
        "SymmPerChannelQuantParams": SymmPerChannelQuantParams,
    }
1387
def ArgumentParser():
    """Build the command line parser: a positional spec and a --hook flag."""
    cmdline = argparse.ArgumentParser()
    argumentSpecs = (
        (("spec",), {"help": "the spec file or directory"}),
        (("--hook",), {"help": "hook mode", "action": 'store_true'}),
    )
    for flags, options in argumentSpecs:
        cmdline.add_argument(*flags, **options)
    return cmdline
1393
def ParseArgs(parser):
    """Parse the command line and copy the hook flag into Configuration.

    Returns the namespace produced by parser.parse_args().
    """
    parsed = parser.parse_args()
    Configuration.hook_mode = parsed.hook
    return parsed
1398
def Run(InitializeFiles=None, DumpExample=None):
    """Process every spec file selected through FileNames.NextFile.

    For each spec that might need regeneration, the spec source is executed
    in the scope from GetExecScope, its examples are dumped into an in-memory
    buffer, and the buffer is written to the example file. In hook mode a
    missing or differing example file is reported and the process exits with
    status 1 before anything is written.

    Args:
        InitializeFiles: optional callable, invoked as
            InitializeFiles(example_fd=buffer) before the examples are
            dumped. It is skipped when None.
        DumpExample: callable forwarded to Example.DumpAllExamples.

    Any Exception raised while handling a spec is printed with its traceback
    and ends the process through sys.exit with a message naming the spec.
    """
    exec_scope = GetExecScope()
    while FileNames.NextFile():
        try:
            if not MightNeedRegeneration():
                continue
            # NOTE(review): the spec file is executed as Python code, so only
            # trusted spec files must be passed to the generator.
            exec(Read(FileNames.specFile), exec_scope)
            example_buf = io.StringIO() if FileNames.exampleFile else None
            # None is the declared default; calling it unguarded would raise
            # TypeError for callers that rely on the default.
            if InitializeFiles is not None:
                InitializeFiles(example_fd=example_buf)
            Example.DumpAllExamples(DumpExample=DumpExample, example_fd=example_buf)
            if FileNames.exampleFile is None:
                # Nothing to write: the spec was only executed and checked
                continue
            if Configuration.hook_mode and (not os.path.exists(FileNames.exampleFile) or
                                            Read(FileNames.exampleFile) != example_buf.getvalue()):
                print(('\n{filename} is out of date. '
                        'Please run {generate_all_tests_sh} before uploading.\n').format(
                                filename=FileNames.exampleFile,
                                generate_all_tests_sh=os.path.abspath(os.path.join(
                                        os.path.dirname(__file__), '..', '..', 'runtime', 'test',
                                        'specs', 'generate_all_tests.sh'))))
                # SystemExit is not an Exception, so it passes the handler below
                sys.exit(1)
            AtomicWrite(FileNames.exampleFile, example_buf.getvalue())
        except Exception:
            traceback.print_exc()
            sys.exit("Exception raised when processing {}".format(FileNames.specFile))
1424