#!/usr/bin/python3

# Copyright 2017, The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NN model compiler

Contains class definitions and utility functions for compiling models and
examples into NDK-based CTS and VTS unit tests.

Used by cts_generator.py, vts_generator.py, and slicing.py.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import copy
from functools import reduce
import itertools
import math
import os
import re
import struct
import sys
import contextlib
import pprint
import numpy as np

def GetJointStr(l, sep=", ", method=str):
    return sep.join([method(i) for i in l])

# Print in C float literal format
def PrettyPrintAsFloat(x):
    s = str(float(x))
    if s.find(".") >= 0 or s.find("e") >= 0:
        return s + "f"
    else:
        return s + ".0f"
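# Example (illustration only):
#   PrettyPrintAsFloat(0.5) -> "0.5f"
#   PrettyPrintAsFloat(2)   -> "2.0f"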

# Transform from original type to float32
def Dequantize(v, ty):
    v -= ty.zeroPoint
    if ty.scale != 0:
        v *= ty.scale
    if isinstance(ty.extraParams, SymmPerChannelQuantParams):
        v *= ty.extraParams.GetScalesBroadcastArray(ty.dimensions)
    return v

# Transform float32 to target data type
def Quantize(v, ty):
    if ty.scale != 0:
        v /= ty.scale
    if isinstance(ty.extraParams, SymmPerChannelQuantParams):
        v = v / ty.extraParams.GetScalesBroadcastArray(ty.dimensions)
    v += ty.zeroPoint
    if not ty.IsFloat():
        v = np.round(v)
        v = int(v) if np.isscalar(v) else v.astype(int)
    if ty.type == "TENSOR_QUANT8_ASYMM":
        v = np.minimum(np.maximum(v, 0), 255)
    elif ty.type == "TENSOR_QUANT16_ASYMM":
        v = np.minimum(np.maximum(v, 0), 65535)
    elif ty.type == "TENSOR_QUANT8_SYMM_PER_CHANNEL":
        v = np.minimum(np.maximum(v, -127), 127)
    elif ty.type == "UINT32":
        v = np.maximum(v, 0)
    return v
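# Example (illustration only): for a TENSOR_QUANT8_ASYMM type with
# scale=0.005 and zeroPoint=128, Quantize(0.5, ty) yields
# round(0.5 / 0.005) + 128 = 228 (clamped to [0, 255]), and
# Dequantize(228, ty) maps it back to (228 - 128) * 0.005 = 0.5.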

@contextlib.contextmanager
def SmartOpen(filename=None, mode="w"):
    if filename and filename != '-':
        fh = open(filename, mode)
    else:
        fh = sys.stdout

    try:
        yield fh
    finally:
        if fh is not sys.stdout:
            fh.close()
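# Example usage (illustration only; "foo.cpp" is a hypothetical path):
#   with SmartOpen("foo.cpp") as f:  # writes to foo.cpp
#       print("// ...", file=f)
#   with SmartOpen("-") as f:        # "-" or None falls back to stdout
#       print("// ...", file=f)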

# Tracks objects inside a model with unique names
class NamedObject:
    existingNames = set()

    def __init__(self, *args, sep="_", showZero=False, startsFrom=0, skipRenaming=False):
        name = GetJointStr([i for i in args if i is not None and i != ""], sep=sep)
        if skipRenaming:
            self.name = name
            return
        # make the name unique by renaming with a suffix number
        uniqueName = name if showZero is False else name + sep + str(startsFrom)
        while uniqueName in self.__class__.existingNames:
            startsFrom += 1
            uniqueName = name + sep + str(startsFrom)
        self.__class__.existingNames.add(uniqueName)
        self.name = uniqueName

    def __str__(self):
        return self.name
    __repr__ = __str__

    # Since names are unique, objects with the same name are considered equal
    def __eq__(self, other):
        return isinstance(other, NamedObject) and self.name == other.name

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self.name)

    def __lt__(self, other):
        return self.name < other.name
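    # Illustration of the renaming scheme (assuming a fresh namespace):
    #   NamedObject("conv", "out").name  -> "conv_out"
    #   NamedObject("conv", "out").name  -> "conv_out_1"  (renamed to be unique)
    #   NamedObject("type", showZero=True).name -> "type_0"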

# Types and operands should all have unique names since they share the same namespace
class NamedVariable(NamedObject):
    existingNames = set()
    def __init__(self, *args, sep="_", showZero=False, startsFrom=0, skipRenaming=False):
        NamedObject.__init__(self, *args, sep=sep, showZero=showZero,
            startsFrom=startsFrom, skipRenaming=skipRenaming)

# Global variables in the spec namespace, such as CreateModel, is_ignored, and examples
class GlobalVariable(NamedVariable):
    def __init__(self, *args, skipRenaming=False):
        NamedObject.__init__(self, *args, startsFrom=1, skipRenaming=skipRenaming)

# Each test should have a unique name, but will not conflict with variables
class NamedTest(NamedObject):
    existingNames = set()
    def __init__(self, *args, startsFrom=0, skipRenaming=False):
        NamedObject.__init__(self, *args, startsFrom=1, skipRenaming=skipRenaming)

class Type(NamedVariable):
    typesMap = dict()
    typeLookup = {
        "INT32": "int32_t",
        "UINT32": "uint32_t",
        "FLOAT32": "float",
        "FLOAT16": "_Float16",
        "TENSOR_INT32": "int32_t",
        "TENSOR_FLOAT16": "_Float16",
        "TENSOR_FLOAT32": "float",
        "TENSOR_QUANT8_ASYMM": "uint8_t",
        "TENSOR_QUANT8_SYMM": "int8_t",
        "BOOL": "bool8",
        "TENSOR_QUANT16_ASYMM": "uint16_t",
        "TENSOR_QUANT16_SYMM": "int16_t",
        "TENSOR_BOOL8": "bool8",
        "TENSOR_QUANT8_SYMM_PER_CHANNEL": "int8_t",
        # "OEM_SCALAR": this is service-defined.
        "TENSOR_OEM_BYTE": "uint8_t",
    }

    # types are named as "type0", "type1", ...
    def __init__(self, vt, dimensions, scale, zeroPoint, name="type", skipRenaming=False,
                 extraParams=None):
        NamedVariable.__init__(self, name, sep="", showZero=True, skipRenaming=skipRenaming)
        self.type = vt
        self.dimensions = dimensions
        self.scale = float(scale)
        self.zeroPoint = int(zeroPoint)
        self.extraParams = extraParams

    # Factory for Type objects; only creates a new Type if the requested type
    # does not match any existing one
    @staticmethod
    def GetType(vt, dimensions, scale=0, zeroPoint=0, extraParams=None):
        key = ",".join([vt, str(dimensions), str(scale), str(zeroPoint), str(extraParams)])
        if key not in Type.typesMap:
            Type.typesMap[key] = Type(vt, dimensions, scale, zeroPoint, extraParams=extraParams)
        return Type.typesMap[key]

    @staticmethod
    def GetAllTypes():
        # sort to ensure a stable order when dumping the code
        return sorted(Type.typesMap.values())

    # For backward-compatibility
    @staticmethod
    def GetTypeFromString(vt, shape, extraParams=None):
        dimensions, scale, zeroPoint = Type.GetParsedShape(shape)
        scale = float(scale)
        zeroPoint = int(zeroPoint)
        return Type.GetType(vt, dimensions, scale, zeroPoint, extraParams)

    # For backward-compatibility
    @staticmethod
    def GetParsedShape(shape):
        # Parse shape
        if (shape != "" and shape != "{}"):
            left, sep, right = shape.partition('{')
            real_shape, sep, right = right.partition('}')
            shape = [int(x) for x in real_shape.split(",")]
            # left now looks like "0.0f, 127.5f, "
            scale, sep, zero_point = right.rpartition(',')
            if scale == "":
                if zero_point == "":
                    return shape, "0", "0"
                return shape, zero_point, "0"
            left, sep, scale = scale.partition(',')
            return shape, scale.replace("f", ""), zero_point
        else:
            return [], "0", "0"
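    # Example (illustration only): Type.GetParsedShape("{2, 3}, 0.5f, 127")
    # returns roughly ([2, 3], "0.5", "127") (modulo surrounding whitespace),
    # while Type.GetParsedShape("{2, 3}") returns ([2, 3], "0", "0").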

    def GetNumberOfElements(self):
        return reduce(lambda x, y: x * y, self.dimensions, 1)

    def GetCppTypeString(self):
        return Type.typeLookup[self.type]

    def IsFloat(self):
        return self.GetCppTypeString() in ["float", "_Float16"]

    def IsBool(self):
        return self.GetCppTypeString() == "bool8"

    def GetElementByteSize(self):
        cppTypeString = self.GetCppTypeString()
        if cppTypeString in ["uint8_t", "int8_t", "bool8"]:
            return 1
        elif cppTypeString in ["int16_t", "uint16_t", "_Float16"]:
            return 2
        else:
            return 4

    def GetByteSize(self):
        return self.GetElementByteSize() * self.GetNumberOfElements()

    def GetDimensionsString(self):
        return "{" + GetJointStr(self.dimensions) + "}"

    def GetSignatureTuple(self):
        return (self.type, self.dimensions, self.scale, self.zeroPoint)

    # For backward-compatibility with slicing.py
    def GetRawShape(self):
        if self.scale == 0 and self.zeroPoint == 0:
            return self.GetDimensionsString()
        else:
            return GetJointStr([self.GetDimensionsString(), self.scale, self.zeroPoint])

    def ToUnspecifiedDim(self):
        return Type.GetType(self.type, [0] * len(self.dimensions), self.scale, self.zeroPoint)
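    # Example (illustration only): repeated requests with the same signature
    # share one cached Type instance, so each type is declared exactly once in
    # the generated code:
    #   t0 = Type.GetType("TENSOR_FLOAT32", [1, 2, 2, 1])
    #   t1 = Type.GetType("TENSOR_FLOAT32", [1, 2, 2, 1])
    #   assert t0 is t1  # same object, named "type0" in a fresh namespace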

# To track implicitly convertible parameter types
class ImplicitParameter():
    @staticmethod
    def ImplicitConvertion(value):
        if isinstance(value, Operand):
            return value
        for implicitType in ImplicitParameter.__subclasses__():
            if implicitType.IsCompatible(value):
                return implicitType("param", value)
        assert False, "%s not supported for implicit parameter"%value


# ExtraParams with per-channel quantization.
class SymmPerChannelQuantParams():
    def __init__(self, channelDim, scales, hide=False):
        self.channelDim = channelDim
        self.scales = scales
        self.hide = hide

    def GetScalesBroadcastArray(self, dimensions):
        bshape = [1] * len(dimensions)
        bshape[self.channelDim] = len(self.scales)
        return np.array(self.scales).reshape(bshape)

    def GetConstructor(self):
        return "SymmPerChannelQuantParams({%s},%d)" % (
            ", ".join(str(x) + "f" for x in self.scales), self.channelDim)

    def GetVtsSetter(self):
        return "channelQuant"

    def GetVtsConstructor(self):
        return "SymmPerChannelQuantParams{.scales={%s}, .channelDim=%d}" % (
            ", ".join(str(x) + "f" for x in self.scales), self.channelDim)
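    # Example (illustration only): for a filter of dimensions [2, 2, 2, 3]
    # quantized along channelDim=3 with scales [0.5, 1.0, 2.0],
    # GetScalesBroadcastArray([2, 2, 2, 3]) returns an array of shape
    # [1, 1, 1, 3], so the per-channel scales broadcast across all other axes.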


# An operand that can be fed into operations. Also, an operand is always
# declared before operations.
class Operand(NamedVariable):

    def __init__(self, name, opType, value, backward=None, skipRenaming=False, extraParams=None):
        NamedVariable.__init__(self, name, sep="", skipRenaming=skipRenaming)
        if type(opType) is str:
            self.type = Type.GetTypeFromString(opType, value, extraParams)
            value = backward
        else:
            self.type = Type.GetType(*opType, extraParams=extraParams)
        self.SetValue(value)
        self.dimensions = self.type.dimensions
        self.lifetime = "TEMPORARY_VARIABLE"
        self.ins = []
        self.outs = []

    def SetValue(self, value):
        self.value = value if type(value) is list or type(value) is tuple else [value]
        return self

    def SetValueFromNumpy(self, value):
        self.value = value.flatten().tolist()
        return self

    def GetValueAsNumpy(self):
        return np.array(self.value).reshape(self.type.dimensions)

    # Print value as cpp-style list initialization
    def GetListInitialization(self):
        assert self.value is not None, \
            "Trying to print operand %s with None value"%(str(self))
        if self.type.IsFloat():
            return "{%s}"%(GetJointStr(self.value, method=PrettyPrintAsFloat))
        elif self.type.IsBool():
            return "{%s}"%(GetJointStr(self.value, method=lambda v: "true" if v else "false"))
        else:
            return "{%s}"%(GetJointStr(self.value, method=lambda x: str(int(x))))

    def ToUnspecifiedDim(self):
        self.dimensions = self.type.dimensions
        self.type = self.type.ToUnspecifiedDim()

# Base class of user-defined input/output operands
class InOut(Operand):

    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        Operand.__init__(self, name, opType, backward, None, skipRenaming=skipRenaming, extraParams=extraParams)
        self.lifetime = "MODEL_INPUT"
        self.index = 0

    def Feed(self, value):
        self.SetValue(value[self] if type(value) is dict else value)
        return self

    def GetListInitialization(self):
        return "{%d, %s}"%(self.index, super().GetListInitialization())

# A user-declared input operand
class Input(InOut):
    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        InOut.__init__(self, name, opType, backward, skipRenaming=skipRenaming, extraParams=extraParams)
        self.lifetime = "MODEL_INPUT"

# A user-declared output operand
class Output(InOut):
    def __init__(self, name, opType, backward=None, skipRenaming=False):
        InOut.__init__(self, name, opType, backward, skipRenaming=skipRenaming)
        self.lifetime = "MODEL_OUTPUT"
# An output whose results we do not want to compare
class IgnoredOutput(Output):
    def __init__(self, name, opType, backward=None, skipRenaming=False):
        Output.__init__(self, name, opType, backward, skipRenaming=skipRenaming)
        self.lifetime = "MODEL_OUTPUT"
    def Feed(self, value):
        numElements = reduce(lambda x, y: x * y, self.dimensions, 1)
        self.value = [0 for x in range(numElements)]
        return self

# An explicitly declared parameter
class Parameter(Operand):
    def __init__(self, name, opType, value, backward=None, skipRenaming=False, extraParams=None):
        Operand.__init__(self, name, opType, value, backward, skipRenaming=skipRenaming,
                         extraParams=extraParams)
        self.initializer = NamedVariable(str(self) + "_init")
        self.lifetime = "CONSTANT_REFERENCE" if Configuration.useSHM() else "CONSTANT_COPY"

# A shortcut for parameters of INT32
class Int32Scalar(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("INT32", []), int(value))
    @staticmethod
    def IsCompatible(value):
        return type(value) is int

# A shortcut for parameters of FLOAT16
class Float16Scalar(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("FLOAT16", []), float(value))
    @staticmethod
    def IsCompatible(value):
        return False

# A shortcut for parameters of FLOAT32
class Float32Scalar(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("FLOAT32", []), float(value))
    @staticmethod
    def IsCompatible(value):
        return type(value) is float

# A shortcut for parameters of BOOL
class BoolScalar(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("BOOL", []), bool(value))
    @staticmethod
    def IsCompatible(value):
        return type(value) is bool

# A shortcut for parameters of 1-D TENSOR_INT32
class Int32Vector(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("TENSOR_INT32", [len(value)]), [int(v) for v in value])
    @staticmethod
    def IsCompatible(value):
        if type(value) is not list and type(value) is not tuple:
            return False
        return all(type(i) is int for i in value)

# A shortcut for parameters of 1-D TENSOR_FLOAT32
class Float32Vector(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("TENSOR_FLOAT32", [len(value)]), [float(v) for v in value])
    @staticmethod
    def IsCompatible(value):
        if type(value) is not list and type(value) is not tuple:
            return False
        return all(type(i) is float for i in value)

# An explicitly declared intermediate result
class Internal(Operand):
    def __init__(self, name, opType, backward=None, skipRenaming=False):
        Operand.__init__(self, name, opType, backward, None, skipRenaming=skipRenaming)
        self.lifetime = "TEMPORARY_VARIABLE"

# An operation in a model; operations do not need unique names
class Operation:

    def __init__(self, optype, ins, outs):
        self.optype = optype
        self.SetInputs(ins)
        self.SetOutputs(outs)

    # For ease of debugging
    def __str__(self):
        insString = GetJointStr(self.ins)
        outsString = GetJointStr(self.outs)
        return "Operation %s: [%s] -> [%s]"%(self.optype, insString, outsString)
    __repr__ = __str__

    def SetInputs(self, ins):
        self.ins = [ImplicitParameter.ImplicitConvertion(i) for i in ins]
        return self

    def SetOutputs(self, outs):
        self.outs = list(outs)
        return self

    # For backward-compatibility with slicing.py
    # Get a Python-ish dump for the op
    def PyDefinition(self):
        py_op_string = """Operation("{optype}", {inputs}).To({outputs})"""
        inputs = [str(x) for x in self.ins]
        inputs = ", ".join(inputs)
        assert len(self.outs) <= 1
        outputs = str(self.outs[0])
        ops = {"optype": self.optype, "inputs": inputs, "outputs": outputs}
        return py_op_string.format(**ops)

# Main interface
class Model:
    models = list()

    def __init__(self, name=None):
        self.name = name
        self.operations = []
        self.operands = []
        self.isRelaxed = False
        self.compiled = False
        self.dumped = False
        self.hasDynamicOutputShape = False
        self.version = FileNames.version
        Model.models.append(self)

    def WithSuffix(self, *args):
        self.createFunctionName = GlobalVariable("CreateModel", self.name, *args)
        self.createTestFunctionName = GlobalVariable("createTestModel", self.name, *args)
        self.isIgnoredFunctionName = GlobalVariable("is_ignored", self.name, *args)
        return self

    def AddOperation(self, operation):
        self.operations.append(operation)
        for i in operation.ins:
            if i not in self.operands:
                self.operands.append(i)
        for o in operation.outs:
            if o not in self.operands:
                self.operands.append(o)
        return self

    def Operation(self, op_name, *args):
        return self.AddOperation(Operation(op_name, args, []))

    def To(self, *args):
        assert len(self.operations) > 0
        if type(args[0]) is tuple or type(args[0]) is list:
            outs = args[0]
        else:
            outs = args
        self.operations[-1].SetOutputs(outs)
        for o in outs:
            if o not in self.operands:
                self.operands.append(o)
        return self

    def RelaxedExecution(self, isRelaxed):
        self.isRelaxed = isRelaxed
        return self

    def TestDynamicOutputShape(self, hasDynamicOutputShape):
        self.hasDynamicOutputShape = hasDynamicOutputShape
        return self

    # Sets the version of the model in compliance tests. Set to None to disable the test.
    def IntroducedIn(self, ver):
        self.version = ver
        return self

    def GetTypes(self):
        return sorted(list(set(op.type for op in self.operands)))

    def GetInputs(self):
        return [i for i in self.operands if isinstance(i, Input)]

    def GetOutputs(self):
        return [o for o in self.operands if isinstance(o, Output)]

    def GetInputsIndex(self):
        return [i for i, op in enumerate(self.operands) if isinstance(op, Input)]

    def GetOutputsIndex(self):
        return [o for o, op in enumerate(self.operands) if isinstance(op, Output)]

    def GetIndexOfOperands(self, operands):
        return [self.operands.index(i) for i in operands]

    def GetIgnoredOutputs(self):
        return [o for o in self.operands if isinstance(o, IgnoredOutput)]

    def GetParameters(self):
        return [p for p in self.operands if isinstance(p, Parameter)]

    def GetEquivalentOperands(self, targets):
        return [self.operands[self.operands.index(t)] for t in targets]

    def UpdateEquivalentOperands(self, targets):
        for t in targets:
            self.operands[self.operands.index(t)] = t
        return self

    def SetInputAndOutputIndex(self):
        for ind, i in enumerate(self.GetInputs()):
            i.index = ind
        for ind, o in enumerate(self.GetOutputs()):
            o.index = ind
        return self

    def SetOperandInsAndOuts(self):
        for op in self.operands:
            op.ins = list()
            op.outs = list()
        for op in self.operations:
            op.ins = self.GetEquivalentOperands(op.ins)
            op.outs = self.GetEquivalentOperands(op.outs)
            for i in op.ins:
                i.outs.append(op)
            for o in op.outs:
                o.ins.append(op)
        return self

    def TopologicalSortHelper(self, op, deps, visited):
        if op in visited:
            assert op not in deps, "Cycle detected in the graph"
        else:
            visited.add(op)
            for i in deps[op]:
                self.TopologicalSortHelper(i, deps, visited)
            self.operations.append(op)
            deps.pop(op)

    # Topologically sort the operations, and detect if there is a cycle in the graph
    def TopologicalSort(self):
        deps = {op: list() for op in self.operations}
        [deps[o].append(i) for op in self.operands for o in op.outs for i in op.ins]
        operations = self.operations.copy()
        self.operations = []
        visited = set()
        for op in operations:
            self.TopologicalSortHelper(op, deps, visited)

    def SetOutputUnspecified(self):
        for op in self.operands:
            op.dimensions = op.type.dimensions
        if self.hasDynamicOutputShape:
            for op in self.GetOutputs():
                op.ToUnspecifiedDim()
        return self

    def Compile(self):
        if self.compiled:
            return self
        self.SetInputAndOutputIndex()
        self.SetOperandInsAndOuts()
        self.TopologicalSort()
        self.SetOutputUnspecified()
        # Do not check compliance for relaxed mode and dynamic output shape tests.
        if self.isRelaxed or self.hasDynamicOutputShape:
            self.IntroducedIn(None)
        self.compiled = True
        return self
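    # Typical spec usage (a minimal sketch; the operand names are
    # hypothetical):
    #   i1 = Input("op1", "TENSOR_FLOAT32", "{1, 2, 2, 1}")
    #   o1 = Output("op2", "TENSOR_FLOAT32", "{1, 2, 2, 1}")
    #   model = Model().Operation("RELU", i1).To(o1)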

# To track implicitly convertible variation types
class ImplicitVariation:
    @staticmethod
    def ImplicitConvertion(value):
        if isinstance(value, ModelVariation):
            return value
        for implicitType in ImplicitVariation.__subclasses__():
            value = value if type(value) is tuple or type(value) is list else [value]
            if implicitType.IsCompatible(value[0]):
                var = implicitType(value[0])
                if len(value) > 1:
                    var.Identify(*value[1:])
                return var
        assert False, "%s not supported for implicit variation"%value[0]

# The base class for model variations
class ModelVariation:

    def __init__(self, name=None):
        self.targetOperands = {}
        self.name = name

    def ApplyToHelper(self, model, args, feedDicts, transform):
        opVarList = []
        for op in model.GetEquivalentOperands(sorted(args.keys())):
            opVar = op
            feedDictsVar = []
            if isinstance(op, Input) or isinstance(op, Output):
                for feedDict in feedDicts:
                    op_tmp = copy.deepcopy(op)
                    if op_tmp in feedDict[0]:
                        opVar = transform(op_tmp.Feed(feedDict[0]), args[op_tmp])
                    elif op_tmp in feedDict[1]:
                        opVar = transform(op_tmp.Feed(feedDict[1]), args[op_tmp])
                    else:
                        assert False
                    feedDictsVar.append(opVar.value)
                assert type(op) == type(opVar), "Can not handle %s -> %s"%(type(op), type(opVar))
            else:
                opVar = transform(op, args[op])
                # handle Parameter -> Input
                if isinstance(opVar, Input) or isinstance(opVar, Output):
                    feedDictsVar = [opVar.value] * len(feedDicts)
            if isinstance(opVar, Input) or isinstance(opVar, Output):
                for feedDict, feedDictVar in zip(feedDicts, feedDictsVar):
                    if opVar in feedDict[1]:
                        feedDict[1][opVar] = feedDictVar
                    else:
                        feedDict[0][opVar] = feedDictVar
            opVarList.append(opVar)
        return opVarList

    # Make a deepcopy of the model and feedDicts, and apply the change
    def ApplyTo(self, modelOrigin, feedDictsOrigin):
        model, feedDicts = copy.deepcopy((modelOrigin, feedDictsOrigin))
        model.compiled = False
        model.dumped = False

        if not self.targetOperands:
            self.AutoIdentify(model)

        # get transformed operands and update feedDicts
        operandsVar = self.ApplyToHelper(
            model, self.targetOperands, feedDicts, self.TransformOperand)

        model = self.TransformModel(model)
        model.UpdateEquivalentOperands(operandsVar)
        return model, feedDicts

    def IdentifyOperands(self, args=None):
        if args is None:
            return self
        self.targetOperands = args if type(args) is dict else {i: None for i in args}
        return self

    def Identify(self, operandArgs=None, paramArgs=None):
        self.IdentifyOperands(operandArgs)
        return self

    # Set variation to its default name
    def SetToDefaultName(self):
        self.name = ""
        return self

    # Automatically select the target operand list
    def AutoIdentify(self, model):
        return self

    # Transform operands that are marked by IdentifyOperands()
    def TransformOperand(self, op, arg=None):
        return op

    # Transform the model
    def TransformModel(self, model):
        return model

# Default variation that does nothing
class DefaultVariation(ModelVariation):

    def __init__(self, name=None):
        ModelVariation.__init__(self, name=name)

# Convert operand data type
class DataTypeConverter(ModelVariation, ImplicitVariation):

    def __init__(self, targetType=None, name=None):
        ModelVariation.__init__(self, name=name)
        if targetType is not None:
            assert DataTypeConverter.IsCompatible(targetType)
        self.targetType = targetType

    @staticmethod
    def IsCompatible(value):
        return value.lower() in ["float16", "int32"]

    def SetToDefaultName(self):
        if self.targetType is not None:
            self.name = self.targetType.lower()
            return self
        # get all target types
        targetTypes = list(zip(*self.targetOperands.values()))[0]
        if "TENSOR_QUANT8_SYMM_PER_CHANNEL" in targetTypes:
            self.name = "channelQuant8"
        elif "TENSOR_QUANT8_ASYMM" in targetTypes:
            self.name = "quant8"
        elif "TENSOR_INT32" in targetTypes:
            self.name = "int32"
        elif "TENSOR_FLOAT16" in targetTypes:
            self.name = "float16"
        else:
            self.name = "float32"
        return self

    def AutoIdentify(self, model):
        if self.targetType is not None:
            # By default, select all the float32 tensors/scalars
            targets = {op: ["TENSOR_" + self.targetType.upper()] \
                    for op in model.operands if op.type.type == "TENSOR_FLOAT32"}
            targets.update({op: [self.targetType.upper()] \
                    for op in model.operands if op.type.type == "FLOAT32"})
            self.Identify(targets)
        return self

    def TransformOperand(self, op, arg=None):
        if len(arg) == 1:
            typeTuple = (arg[0], op.type.dimensions)
        else:
            typeTuple = (arg[0], op.type.dimensions, *arg[1:])
        # To handle Internal operands
        if op.value is None or op.type.GetNumberOfElements() == 0:
            op.type = Type.GetType(*typeTuple)
        else:
            v = Dequantize(op.GetValueAsNumpy().astype(np.float32), op.type)
            op.type = Type.GetType(*typeTuple)
            v = Quantize(v, op.type)
            op.SetValueFromNumpy(v)
        return op
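    # Example usage (illustration only): a spec would typically apply this
    # converter via example.AddVariations(DataTypeConverter("float16")), or
    # the implicit string form example.AddVariations("float16"), which retypes
    # all (TENSOR_)FLOAT32 operands as (TENSOR_)FLOAT16.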

# Convert model to turn on/off relaxed computation
class RelaxedModeConverter(ModelVariation, ImplicitVariation):

    def __init__(self, isRelaxed=True, name=None):
        ModelVariation.__init__(self, name=name)
        if isinstance(isRelaxed, bool):
            self.isRelaxed = isRelaxed
        else:
            assert RelaxedModeConverter.IsCompatible(isRelaxed.lower())
            self.isRelaxed = True

    @staticmethod
    def IsCompatible(value):
        return value.lower() in ["relaxed"]

    def SetToDefaultName(self):
        self.name = "relaxed" if self.isRelaxed else "float"
        return self

    def TransformModel(self, model):
        model.RelaxedExecution(self.isRelaxed)
        return model

# Convert data layout between "NHWC" and "NCHW"
class DataLayoutConverter(ModelVariation, ImplicitVariation):

    def __init__(self, targetLayout="nchw", name=None):
        ModelVariation.__init__(self, name=name)
        self.targetLayout = targetLayout.lower()
        assert DataLayoutConverter.IsCompatible(self.targetLayout)
        self.perm = (0, 3, 1, 2) if self.targetLayout == "nchw" else (0, 2, 3, 1)
        self.param = True if self.targetLayout == "nchw" else False

    @staticmethod
    def IsCompatible(value):
        return value.lower() in ["nhwc", "nchw"]

    def SetToDefaultName(self):
        self.name = self.targetLayout
        return self

    def TransformOperand(self, op, arg=None):
        if len(op.type.dimensions) == 4:
            # To handle Internal operands
            if op.value is not None and op.type.GetNumberOfElements() != 0:
                op.SetValueFromNumpy(op.GetValueAsNumpy().transpose(self.perm))
            newDim = [op.type.dimensions[i] for i in self.perm]
            op.type = Type.GetType(op.type.type, newDim, op.type.scale, op.type.zeroPoint)
        elif len(op.type.dimensions) == 1 and len(op.value) == 4:
            op.SetValueFromNumpy(op.GetValueAsNumpy()[list(self.perm)])
        elif op.type.type == "BOOL":
            op.SetValue(self.param)
        else:
            assert False, "%s not supported by DataLayoutConverter"%op
        return op
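    # Example (illustration only): with targetLayout="nchw", a data tensor of
    # dimensions [1, 4, 4, 3] (NHWC) is transposed with perm (0, 3, 1, 2) into
    # [1, 3, 4, 4] (NCHW); a 4-element 1-D shape operand is permuted the same
    # way, and a BOOL layout flag is set to true.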

# Convert data by transposing and removing axes
class AxisConverter(ModelVariation):

    def __init__(self, origin, target, dim, drop=[], name=None):
        ModelVariation.__init__(self, name=name)
        self.origin = origin
        self.target = target
        assert all(i >= -dim and i < dim for i in [self.origin, self.target])
        self.dim = dim
        self.perm = list(range(dim))
        self.perm.insert(target if target >= 0 else target + dim, self.perm.pop(origin))
        self.drop = [drop] if type(drop) is int else list(drop)
        assert all(i >= -dim and i < dim for i in self.drop)
        self.drop = [i if i >= 0 else i + dim for i in self.drop]
        assert target not in self.drop and target + dim not in self.drop

    def SetToDefaultName(self):
        axis = self.target if self.target >= 0 else self.target + self.dim
        axis -= sum(i < axis for i in self.drop)
        neg = "" if self.target >= 0 else "_neg"
        self.name = "dim%d_axis%d%s"%(self.dim - len(self.drop), axis, neg)
        return self

    def TransposeAxis(self, op):
        if op.type.type == "INT32":
            op.SetValue(self.target)
        elif len(op.type.dimensions) == self.dim:
            # To handle Internal operands
            if op.value is not None:
                op.SetValueFromNumpy(op.GetValueAsNumpy().transpose(self.perm))
            newDim = [op.type.dimensions[i] for i in self.perm]
            op.type = Type.GetType(op.type.type, newDim, op.type.scale, op.type.zeroPoint)
        else:
            assert False, "%s not supported by AxisConverter"%op
        return op

    def RemoveAxis(self, op):
        if op.type.type == "INT32":
            if op.value[0] >= 0:
                op.SetValue(op.value[0] - sum(i < op.value[0] for i in self.drop))
            else:
                op.SetValue(op.value[0] + sum(i > (op.value[0] + self.dim) for i in self.drop))
        elif len(op.type.dimensions) == self.dim:
            if op.value is not None:
                val = op.GetValueAsNumpy()
                for i in sorted(self.drop, reverse=True):
                    val = np.take(val, 0, axis=i)
                op.SetValueFromNumpy(val)
            newDim = [op.type.dimensions[i] for i in range(self.dim) if i not in self.drop]
            op.type = Type.GetType(op.type.type, newDim, op.type.scale, op.type.zeroPoint)
        else:
            assert False, "%s not supported by AxisConverter"%op
        return op

    def TransformOperand(self, op, arg=None):
        op = self.TransposeAxis(op)
        op = self.RemoveAxis(op)
        return op
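    # Example (illustration only): AxisConverter(origin=-1, target=0, dim=4)
    # moves the last axis to the front, so an operand of dimensions
    # [2, 3, 4, 5] becomes [5, 2, 3, 4]; an INT32 axis scalar is simply
    # rewritten to the target value.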

# Convert a Parameter to an Input
class ParameterAsInputConverter(ModelVariation, ImplicitVariation):

    def __init__(self, arg="as_input", prefix="weight", name=None):
        ModelVariation.__init__(self, name=name)
        assert ParameterAsInputConverter.IsCompatible(arg.lower())
        self.prefix = prefix

    @staticmethod
    def IsCompatible(value):
        return value.lower() in ["as_input"]

    def SetToDefaultName(self):
        self.name = self.prefix + "_as_input"
        return self

    def TransformOperand(self, op, arg=None):
        assert isinstance(op, Parameter), "%s cannot be converted to Input."%type(op)
        newop = Input(op.name, op.type.GetSignatureTuple(), skipRenaming=True, extraParams=op.type.extraParams)
        newop.SetValue(op.value)
        return newop

# Convert the output based on an activation
class ActivationConverter(ModelVariation, ImplicitVariation):
    # (Enum, low, high)
    actMap = {
        "none": (0, None, None),
        "relu": (1, 0.0, None),
        "relu1": (2, -1.0, 1.0),
        "relu6": (3, 0.0, 6.0),
    }
    def __init__(self, act="relu", name=None):
        ModelVariation.__init__(self, name=name)
        self.act = act.lower()
        assert ActivationConverter.IsCompatible(self.act)
        self.enum = ActivationConverter.actMap[self.act][0]
        self.low = ActivationConverter.actMap[self.act][1]
        self.high = ActivationConverter.actMap[self.act][2]

    @staticmethod
    def IsCompatible(value):
        return value.lower() in ActivationConverter.actMap.keys()

    def SetToDefaultName(self):
        self.name = self.act
        return self

    def TransformOperand(self, op, arg=None):
        if op.type.type == "INT32": # activation enum
            return op.SetValue(self.enum)
        else:
            assert isinstance(op, Output)
            v = op.GetValueAsNumpy()
            if self.low is not None:
                low = Quantize(self.low, op.type)
                v = np.maximum(v, low)
            if self.high is not None:
                high = Quantize(self.high, op.type)
                v = np.minimum(v, high)
            return op.SetValueFromNumpy(v)
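    # Example (illustration only): ActivationConverter("relu6") rewrites the
    # fused activation scalar to 3 and clamps the expected output values to
    # [0.0, 6.0], running the bounds through Quantize() for quantized outputs.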

class DynamicOutputShapeConverter(ModelVariation):
    def __init__(self, name=None):
        ModelVariation.__init__(self, name=name)

    def SetToDefaultName(self):
        self.name = "dynamic_output_shape"
        return self

    def TransformModel(self, model):
        model.TestDynamicOutputShape(True)
        return model

# An example is always attached to a model, and could have multiple variations
class Example:
    examples = []
    versionOverrides = {}

    def __init__(self, *args, model=None, name=None):
        self.model = Model.models[-1] if model is None else model
        self.name = name
        self.expectedMultinomialDistributionTolerance = None
        self.feedDicts = []
        for feedDict in args:
            if type(feedDict) is tuple or type(feedDict) is list:
                self.feedDicts.append(feedDict)
            elif type(feedDict) is dict:
                self.feedDicts.append((
                    {i: feedDict[i] for i in self.model.GetInputs()},
                    {o: feedDict[o] for o in self.model.GetOutputs()}
                ))
            else:
                assert False
        if Configuration.test_dynamic_output_shape:
            self.variations = [[DefaultVariation(), DynamicOutputShapeConverter()]]
        else:
            self.variations = []
        Example.examples.append(self)

    @staticmethod
    def SetVersion(ver, *args):
        for name in args:
            Example.versionOverrides[name] = ver

    # Main entry point of the test generator
    @staticmethod
    def DumpAllExamples(DumpModel=None, model_fd=None,
                        DumpExample=None, example_fd=None,
                        DumpTest=None, test_fd=None):
        Example.CombineAllExamples()
        for example in Example.examples:
            example.Dump(DumpModel, model_fd, DumpExample, example_fd, DumpTest, test_fd)

    # Combine examples with the same model, same name, and same set of variations
    @staticmethod
    def CombineAllExamples():
        modelMap = {}
        newExamples = []
        for example in Example.examples:
            key = (example.model, example.name, tuple(tuple(e) for e in example.variations))
            if key in modelMap:
                modelMap[key].Combine(example)
            else:
                modelMap[key] = example
                newExamples.append(example)
        Example.examples = newExamples

    def AddVariations(self, *args, includeDefault=True, defaultName=None):
        self.variations.append([DefaultVariation(defaultName)] if includeDefault else [])
        self.variations[-1].extend(ImplicitVariation.ImplicitConvertion(i) for i in args)
        return self

    def AddNchw(self, *args, includeDefault=True, defaultName="nhwc"):
        var = DataLayoutConverter("nchw").Identify(args)
        self.AddVariations(var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddRelaxed(self, isRelaxed=True, includeDefault=True, defaultName=None):
        var = RelaxedModeConverter(isRelaxed)
        self.AddVariations(var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddInput(self, *args, includeDefault=True, defaultName=None):
        var = ParameterAsInputConverter().Identify(args)
        self.AddVariations(var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddRelu(self, *args, includeDefault=True, defaultName=None):
        var = ActivationConverter("relu").Identify(args)
        self.AddVariations(var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddAllActivations(self, *args):
        var = [ActivationConverter(i).Identify(args)
            for i in sorted(ActivationConverter.actMap.keys())]
        self.AddVariations(*var, includeDefault=False)
        return self

    def GuessOriginalAxisAndDim(self, *args):
        origin = None
        dim = None
        for arg in args:
            if arg.type.type == "INT32":
                origin = arg.value[0]
            else:
                if dim is None:
                    dim = len(arg.type.dimensions)
                else:
                    assert dim == len(arg.type.dimensions)
        assert dim is not None
        origin = dim - 1 if origin is None else origin
        origin = origin + dim if origin < 0 else origin
        return origin, dim

    def AddAxis(self, axis, *args, includeDefault=True, defaultName=None):
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        axis = [axis] if type(axis) is int else list(axis)
        var = [AxisConverter(origin, a, dim).Identify(args) for a in axis]
        self.AddVariations(*var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddAllPositiveAxis(self, *args):
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        var = [AxisConverter(origin, a, dim).Identify(args) for a in range(dim)]
        self.AddVariations(*var, includeDefault=False)
        return self

    def AddAllAxis(self, *args):
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        var = [AxisConverter(origin, a, dim).Identify(args) for a in range(-dim, dim)]
        self.AddVariations(*var, includeDefault=False)
        return self

    def AddDims(self, dims, *args, includeDefault=True, defaultName=None):
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        dims = [dims] if type(dims) is int else list(dims)
        drop = list(range(dim))
        drop.pop(origin)
        var = [AxisConverter(origin, origin, dim, drop[0:(dim-i)]).Identify(args) for i in dims]
        self.AddVariations(*var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddAllDims(self, *args):
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        drop = list(range(dim))
        drop.pop(origin)
        var = [AxisConverter(origin, origin, dim, drop[0:i]).Identify(args) for i in range(dim)]
        self.AddVariations(*var, includeDefault=False)
        return self

    def AddAllDimsAndPositiveAxis(self, *args):
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        var = [AxisConverter(origin, j, dim, range(i)).Identify(args) \
                for i in range(dim) for j in range(i, dim)]
        self.AddVariations(*var, includeDefault=False)
        return self

    def AddAllDimsAndAxis(self, *args):
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        var = [AxisConverter(origin, k, dim, range(i)).Identify(args) \
                for i in range(dim) for j in range(i, dim) for k in [j, j - dim]]
        self.AddVariations(*var, includeDefault=False)
        return self
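    # Typical variation usage in a spec (a minimal sketch; i1 and o1 are
    # hypothetical operands):
    #   example = Example({i1: [1.0, 2.0], o1: [1.0, 2.0]})
    #   example.AddVariations("relaxed", "float16")  # implicit converters
    #   example.AddNchw(i1, o1)                      # data layout variation
    # Dump() then emits one test case per combination of added variations.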

    def Combine(self, other):
        assert self.model is other.model, "Only examples targeting the same model can be combined"
        assert tuple(self.variations) == tuple(other.variations), \
            "Only examples with the same set of variations can be combined"
        assert self.name == other.name, "Only examples with the same name can be combined"
        self.feedDicts.extend(other.feedDicts)
        return self

    def Dump(self, DumpModel, model_fd, DumpExample, example_fd, DumpTest, test_fd):
        [v.SetToDefaultName() for vs in self.variations for v in vs if v.name is None]
        for variationList in itertools.product(*self.variations):
            # Apply variations
            modelOrigin, feedDictsOrigin = self.model, self.feedDicts
            self.model, self.feedDicts = copy.deepcopy((self.model, self.feedDicts))
            for variation in variationList:
                self.model, self.feedDicts = variation.ApplyTo(self.model, self.feedDicts)
            # Concat names for test and examples
            varNames = [v.name for v in variationList]
            self.testName = NamedTest(FileNames.specName, self.model.name, self.name, *varNames)
            self.examplesName = GlobalVariable("examples", self.model.name, self.name, *varNames)
            if str(self.testName) in Example.versionOverrides:
                self.model.IntroducedIn(Example.versionOverrides[str(self.testName)])
            self.model.WithSuffix(*varNames).Compile()
            # Dump files
            if DumpModel is not None and model_fd is not None:
                DumpModel(self.model, model_fd)
            if DumpExample is not None and example_fd is not None:
                DumpExample(self, example_fd)
            if DumpTest is not None and test_fd is not None:
                DumpTest(self, test_fd)
            # Restore model and feedDicts before variation
            self.model = modelOrigin
            self.feedDicts = feedDictsOrigin
        return self

    # Specifies the RANDOM_MULTINOMIAL distribution tolerance.
    # If set to greater than zero, the input is compared as log-probabilities
    # to the output and must be within this tolerance to pass.
    def WithMultinomialDistributionTolerance(self, expectedTolerance):
        self.expectedMultinomialDistributionTolerance = expectedTolerance
        return self

    # For backward-compatibility with slicing.py
    # Similar to dump_dict, but in Python. Used by the slicing tool.
    # If referenced is not None, only print operands that are present there.
    @staticmethod
    def py_dump_dict(d, referenced):
        ret = []
        for k, v in d.items():
            if referenced is not None and k not in referenced:
                continue
            key = str(k)
            init = pprint.pformat(v)
            ret.append("%s: %s" % (key, init))
        return ", ".join(ret)

    # For backward-compatibility with slicing.py
    # Similar to dump, but in Python. Used by the slicing tool.
    # If referenced is not None, only print operands that are present there.
    @staticmethod
    def py_dump(example_file, override, referenced):
        Example.CombineAllExamples()
        if len(Example.examples[0].feedDicts) > 0:
            example_no = 0
            example_template = """\
input{no} = {{{inputs}}}
# Only executed during data collection phase
if collecting_data is True:
  Example((input{no}, {{{outputs}}}))
"""
        for i, o in Example.examples[0].feedDicts:
            print('# Begin of an example', file=example_file)
            inputs = Example.py_dump_dict(i, referenced)
            output_list = []
            for k, v in override.items():
                output_list.append("%s: [0] * %d" % (k, v))
            outputs = ",".join(output_list)

            # TODO: handle >1 outputs
            for k, v in o.items():
                assert k.index == 0
            example_contents = {
                'no': example_no,
                'inputs': inputs,
                'outputs': outputs
            }
            print(example_template.format(**example_contents), file=example_file)

class FileNames:
    specFiles = []
    specNames = []
    modelFiles = []
    exampleFiles = []
    testFiles = []
    specFile = ""
    specName = ""
    modelFile = ""
    exampleFile = ""
    testFile = ""
    ctsFile = ""
    logFile = ""
    version = ""
    fileIndex = 0

    @staticmethod
    def InitializeFileLists(spec, model, example, test, cts="-", log=""):
        # get all spec files and target files
        if os.path.isfile(spec):
            FileNames.specFiles = [os.path.abspath(spec)]
        elif os.path.isdir(spec):
            FileNames.specFiles = sorted([os.path.abspath(os.path.join(spec, f))
                for f in os.listdir(spec) if f.endswith(".mod.py")])
        else:
            assert False, "%s is neither a file nor a directory"%spec
        FileNames.specNames = [re.sub(r"\..*", "", os.path.basename(f))
            for f in FileNames.specFiles]
        FileNames.modelFiles = FileNames.ParseTargetFiles(model, ".model.cpp")
        FileNames.exampleFiles = FileNames.ParseTargetFiles(example, ".example.cpp")
        FileNames.testFiles = FileNames.ParseTargetFiles(test, ".mod.py.cpp")
        FileNames.ctsFile = os.path.abspath(cts) if cts != "-" else "-"
        FileNames.logFile = ", \"%s\""%log if log != "" else ""

    @staticmethod
    def ParseTargetFiles(arg, ext):
        numFiles = len(FileNames.specFiles)
        absPath = os.path.abspath(arg)
        if os.path.isdir(arg):
            target = [os.path.join(absPath, f + ext) for f in FileNames.specNames]
        elif arg == "-":
            target = ["-"] * numFiles
        else:
            target = [absPath] * numFiles
        return target

    @staticmethod
    def NextFile():
        if FileNames.fileIndex >= len(FileNames.specFiles):
            return False
        FileNames.specFile = FileNames.specFiles[FileNames.fileIndex]
        FileNames.specName = FileNames.specNames[FileNames.fileIndex]
        FileNames.modelFile = FileNames.modelFiles[FileNames.fileIndex]
        FileNames.exampleFile = FileNames.exampleFiles[FileNames.fileIndex]
        FileNames.testFile = FileNames.testFiles[FileNames.fileIndex]
        FileNames.fileIndex += 1
        NamedObject.existingNames = set()
        NamedVariable.existingNames = set()
        NamedTest.existingNames = set()
        Type.typesMap = dict()
        Model.models = list()
        Example.examples = list()
        Configuration.use_shm_for_weights = False

        # Extract version from absolute file path.
        versionMatch = re.findall(r"/V\d_\d/", FileNames.specFile)
        if len(versionMatch) == 1:
            FileNames.version = versionMatch[0].strip('/')
        else:
            FileNames.version = None
        return True

class Configuration:
    use_shm_for_weights = False
    force_regenerate = False
    test_dynamic_output_shape = True

    @staticmethod
    def useSHM():
        return Configuration.use_shm_for_weights