#!/usr/bin/python3

# Copyright 2017, The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NN model compiler

Contains class definitions and utility functions for compiling models and
examples into NDK-based CTS and VTS unit tests.

Used by cts_generator.py, vts_generator.py, and slicing.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import copy
from functools import reduce
import itertools
import math
import os
import re
import struct
import sys
import contextlib
import pprint
import numpy as np


def GetJointStr(l, sep=", ", method=str):
    """Return the items of `l`, each converted with `method`, joined by `sep`."""
    return sep.join(method(i) for i in l)


# Print in C float literal format
def PrettyPrintAsFloat(x):
    """Format `x` as a C float literal, e.g. 1 -> "1.0f", 1e20 -> "1e+20f".

    For finite values Python's str(float) always contains "." or an exponent,
    so the ".0f" fallback is only reached for "inf"/"nan".
    NOTE(review): that fallback yields "inf.0f"/"nan.0f", which is not a valid
    C literal -- confirm non-finite values never reach the generator.
    """
    s = str(float(x))
    if "." in s or "e" in s:
        return s + "f"
    return s + ".0f"


# Transform from original type to float32
def Dequantize(v, ty):
    """Map values stored as type `ty` back to real values.

    Computes (v - zeroPoint) * scale (the scale step is skipped when scale is
    0), then multiplies by the per-channel scales when `ty.extraParams` is a
    SymmPerChannelQuantParams. Augmented assignment is used, so a numpy array
    argument is modified in place; pass a float array the caller owns (an
    integer array cannot absorb the float multiplication in place).
    """
    v -= ty.zeroPoint
    if ty.scale != 0:
        v *= ty.scale
    if isinstance(ty.extraParams, SymmPerChannelQuantParams):
        v *= ty.extraParams.GetScalesBroadcastArray(ty.dimensions)
    return v


# Saturation bounds (low, high) that Quantize() applies per operand type, so the
# result fits the C element type listed in Type.typeLookup (uint8_t, int8_t,
# uint16_t, int16_t, uint32_t). None leaves that side unbounded. The per-channel
# type keeps the symmetric [-127, 127] range used by the original code.
_QUANT_CLAMP_RANGES = {
    "TENSOR_QUANT8_ASYMM": (0, 255),
    "TENSOR_QUANT8_SYMM": (-128, 127),
    "TENSOR_QUANT16_ASYMM": (0, 65535),
    "TENSOR_QUANT16_SYMM": (-32768, 32767),
    "TENSOR_QUANT8_SYMM_PER_CHANNEL": (-127, 127),
    "UINT32": (0, None),
}


# Transform float32 to target data type
def Quantize(v, ty):
    """Map real values to the storage representation of `ty`.

    Divides by the scale (skipped when 0) and by the per-channel scales when
    `ty.extraParams` is a SymmPerChannelQuantParams, adds zeroPoint, rounds to
    int for non-float types, then saturates to the bounds in
    _QUANT_CLAMP_RANGES. `v` may be a scalar or a numpy array; a float array
    may be updated in place by the `/=` and `+=` steps.
    """
    if ty.scale != 0:
        v /= ty.scale
    if isinstance(ty.extraParams, SymmPerChannelQuantParams):
        v = v / ty.extraParams.GetScalesBroadcastArray(ty.dimensions)
    v += ty.zeroPoint
    if not ty.IsFloat():
        v = np.round(v)
        v = int(v) if np.isscalar(v) else v.astype(int)
    bounds = _QUANT_CLAMP_RANGES.get(ty.type)
    if bounds is not None:
        low, high = bounds
        if low is not None:
            v = np.maximum(v, low)
        if high is not None:
            v = np.minimum(v, high)
    return v


@contextlib.contextmanager
def SmartOpen(filename=None, mode="w"):
    """Yield a file handle for `filename`, or sys.stdout when `filename` is
    None, "" or "-". Only a handle opened here is closed on exit."""
    if filename and filename != '-':
        fh = open(filename, mode)
    else:
        fh = sys.stdout

    try:
        yield fh
    finally:
        if fh is not sys.stdout:
            fh.close()


# Tracking objects inside a model with a unique name
class NamedObject:
    """Base class for objects that need a unique name.

    Positional args that are None or "" are dropped and the rest joined with
    `sep`. Unless `skipRenaming` is set, a numeric suffix is appended until the
    name is absent from `existingNames`, and the result is recorded there.
    `existingNames` is looked up through the class hierarchy, so a subclass
    that defines its own set gets a separate namespace while others share the
    nearest ancestor's set.
    """
    existingNames = set()

    def __init__(self, *args, sep="_", showZero=False, startsFrom=0, skipRenaming=False):
        name = GetJointStr([i for i in args if i is not None and i != ""], sep=sep)
        if skipRenaming:
            # Used verbatim and NOT recorded in existingNames.
            self.name = name
            return
        # make the name unique by renaming with a suffix number
        uniqueName = name if showZero is False else name + sep + str(startsFrom)
        while uniqueName in self.__class__.existingNames:
            startsFrom += 1
            uniqueName = name + sep + str(startsFrom)
        self.__class__.existingNames.add(uniqueName)
        self.name = uniqueName

    def __str__(self):
        return self.name
    __repr__ = __str__

    # Since names are unique, objects with the same name are considered equal
    def __eq__(self, other):
        return isinstance(other, NamedObject) and self.name == other.name

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self.name)

    def __lt__(self, other):
        return self.name < other.name


# Types, operands should all have a unique name since they share the same namespace
class NamedVariable(NamedObject):
    existingNames = set()

    def __init__(self, *args, sep="_", showZero=False, startsFrom=0, skipRenaming=False):
        NamedObject.__init__(self, *args, sep=sep, showZero=showZero,
                             startsFrom=startsFrom, skipRenaming=skipRenaming)


# Global variables in the spec namespace such as CreateModel, is_ignored, and examples
class GlobalVariable(NamedVariable):
    def __init__(self, *args, skipRenaming=False):
        # Shares NamedVariable.existingNames; the first clash gets suffix "_2".
        NamedObject.__init__(self, *args, startsFrom=1, skipRenaming=skipRenaming)


# Each test should have a unique name, but will not conflict with variables
class NamedTest(NamedObject):
    existingNames = set()

    def __init__(self, *args, startsFrom=0, skipRenaming=False):
        # NOTE(review): `startsFrom` is accepted but ignored -- numbering always
        # starts from 1 (first clash -> "_2"). Left unchanged so generated test
        # names stay stable; confirm whether any caller passes it.
        NamedObject.__init__(self, *args, startsFrom=1, skipRenaming=skipRenaming)


class Type(NamedVariable):
    """An operand type: type code, dimensions, scale, zeroPoint and optional
    extraParams (e.g. per-channel quantization parameters).

    Obtain instances through GetType(), which caches one Type per distinct
    parameter set (extraParams are compared by their str()).
    """
    # Cache used by GetType(): parameter key -> Type.
    typesMap = dict()
    # Operand type code -> C++ element type used in the generated code.
    typeLookup = {
        "INT32": "int32_t",
        "UINT32": "uint32_t",
        "FLOAT32": "float",
        "FLOAT16": "_Float16",
        "TENSOR_INT32": "int32_t",
        "TENSOR_FLOAT16": "_Float16",
        "TENSOR_FLOAT32": "float",
        "TENSOR_QUANT8_ASYMM": "uint8_t",
        "TENSOR_QUANT8_SYMM": "int8_t",
        "BOOL": "bool8",
        "TENSOR_QUANT16_ASYMM": "uint16_t",
        "TENSOR_QUANT16_SYMM": "int16_t",
        "TENSOR_BOOL8": "bool8",
        "TENSOR_QUANT8_SYMM_PER_CHANNEL": "int8_t",
#       "OEM_SCALAR": this is service-defined.
        "TENSOR_OEM_BYTE": "uint8_t",
    }

    # types are named as "type0", "type1", ...
    def __init__(self, vt, dimensions, scale, zeroPoint, name="type", skipRenaming=False,
                 extraParams=None):
        NamedVariable.__init__(self, name, sep="", showZero=True, skipRenaming=skipRenaming)
        self.type = vt
        self.dimensions = dimensions
        self.scale = float(scale)
        self.zeroPoint = int(zeroPoint)
        self.extraParams = extraParams

    # Factory for Type object, only create a new Type if requested type does
    # not have a match with all existing types
    @staticmethod
    def GetType(vt, dimensions, scale=0, zeroPoint=0, extraParams=None):
        key = ",".join([vt, str(dimensions), str(scale), str(zeroPoint), str(extraParams)])
        if key not in Type.typesMap:
            Type.typesMap[key] = Type(vt, dimensions, scale, zeroPoint, extraParams=extraParams)
        return Type.typesMap[key]

    @staticmethod
    def GetAllTypes():
        # sort to ensure a stable order when dumping the code
        return sorted(Type.typesMap.values())

    # For backward-compatibility
    @staticmethod
    def GetTypeFromString(vt, shape, extraParams=None):
        """Build a Type from a legacy shape string such as "{1, 2}, 0.5f, 127"."""
        dimensions, scale, zeroPoint = Type.GetParsedShape(shape)
        scale = float(scale)
        zeroPoint = int(zeroPoint)
        return Type.GetType(vt, dimensions, scale, zeroPoint, extraParams)

    # For backward-compatibility
    @staticmethod
    def GetParsedShape(shape):
        """Split a legacy shape string into (dimensions, scale, zeroPoint).

        Accepted forms: "" or "{}", "{d0, d1, ...}", "{d0, ...}, scale" and
        "{d0, ...}, scale, zeroPoint"; the dimension list may be empty and the
        scale may carry a C "f" suffix (e.g. "0.5f" or "1.f"). `dimensions`
        is a list of ints; scale and zeroPoint are returned as strings (with
        any surrounding spaces kept), "0" when absent.
        """
        if shape == "" or shape == "{}":
            return [], "0", "0"
        _, _, right = shape.partition('{')
        real_shape, _, right = right.partition('}')
        # Skip empty entries so "{}, 1.f, 0" parses as a scalar with params.
        dimensions = [int(x) for x in real_shape.split(",") if x.strip()]
        # right now looks like ", 0.5f, 127", ", 0.5f" or ""
        scale, _, zero_point = right.rpartition(',')
        if scale == "":
            if zero_point == "":
                return dimensions, "0", "0"
            # A single trailing value is the scale; strip its "f" suffix too.
            return dimensions, zero_point.replace("f", ""), "0"
        _, _, scale = scale.partition(',')
        return dimensions, scale.replace("f", ""), zero_point

    def GetNumberOfElements(self):
        return reduce(lambda x, y: x * y, self.dimensions, 1)

    def GetCppTypeString(self):
        return Type.typeLookup[self.type]

    def IsFloat(self):
        return self.GetCppTypeString() in ["float", "_Float16"]

    def IsBool(self):
        return self.GetCppTypeString() == "bool8"

    def GetElementByteSize(self):
        cppTypeString = self.GetCppTypeString()
        if cppTypeString in ["uint8_t", "int8_t", "bool8"]:
            return 1
        elif cppTypeString in ["int16_t", "uint16_t", "_Float16"]:
            return 2
        else:
            return 4

    def GetByteSize(self):
        return self.GetElementByteSize() * self.GetNumberOfElements()

    def GetDimensionsString(self):
        return "{" + GetJointStr(self.dimensions) + "}"

    def GetSignatureTuple(self):
        return (self.type, self.dimensions, self.scale, self.zeroPoint)

    # For backward-compatibility with slicing.py
    def GetRawShape(self):
        """Inverse of GetParsedShape: "{dims}" or "{dims}, scale, zeroPoint"."""
        if self.scale == 0 and self.zeroPoint == 0:
            return self.GetDimensionsString()
        else:
            return GetJointStr([self.GetDimensionsString(), self.scale, self.zeroPoint])

    def ToUnspecifiedDim(self):
        """Return the Type with every dimension set to 0 (unspecified), keeping
        the rank, scale, zeroPoint and extraParams of this Type."""
        return Type.GetType(self.type, [0] * len(self.dimensions), self.scale,
                            self.zeroPoint, extraParams=self.extraParams)


# To track implicitly convertible parameter types
class ImplicitParameter:
    @staticmethod
    def ImplicitConvertion(value):
        """Return `value` if it is already an Operand; otherwise wrap it in the
        first subclass (in __subclasses__() order) whose IsCompatible(value)
        accepts it, named "param"."""
        if isinstance(value, Operand):
            return value
        for implicitType in ImplicitParameter.__subclasses__():
            if implicitType.IsCompatible(value):
                return implicitType("param", value)
        assert False, "%s not supported for implicit parameter"%value


# ExtraParams with per-channel quantization.
271class SymmPerChannelQuantParams(): 272 def __init__(self, channelDim, scales, hide = False): 273 self.channelDim = channelDim 274 self.scales = scales 275 self.hide = hide 276 277 def GetScalesBroadcastArray(self, dimensions): 278 bshape = [1] * len(dimensions) 279 bshape[self.channelDim] = len(self.scales) 280 return np.array(self.scales).reshape(bshape) 281 282 def GetConstructor(self): 283 return "SymmPerChannelQuantParams({%s},%d)" % ( 284 ", ".join(str(x) + "f" for x in self.scales), self.channelDim) 285 286 def GetVtsSetter(self): 287 return "channelQuant" 288 289 def GetVtsConstructor(self): 290 return "SymmPerChannelQuantParams{.scales={%s}, .channelDim=%d}" % ( 291 ", ".join(str(x) + "f" for x in self.scales), self.channelDim) 292 293 294# An operand that can be fed into operations. Also, an operand is always 295# declared before operations. 296class Operand(NamedVariable): 297 298 def __init__(self, name, opType, value, backward=None, skipRenaming=False, extraParams=None): 299 NamedVariable.__init__(self, name, sep="", skipRenaming=skipRenaming) 300 if type(opType) is str: 301 self.type = Type.GetTypeFromString(opType, value, extraParams) 302 value = backward 303 else: 304 self.type = Type.GetType(*opType, extraParams=extraParams) 305 self.SetValue(value) 306 self.dimensions = self.type.dimensions 307 self.lifetime = "TEMPORARY_VARIABLE" 308 self.ins = [] 309 self.outs = [] 310 311 def SetValue(self, value): 312 self.value = value if type(value) is list or type(value) is tuple else [value] 313 return self 314 315 def SetValueFromNumpy(self, value): 316 self.value = value.flatten().tolist() 317 return self 318 319 def GetValueAsNumpy(self): 320 return np.array(self.value).reshape(self.type.dimensions) 321 322 # Print value as cpp-style list initialization 323 def GetListInitialization(self): 324 assert self.value is not None, \ 325 "Trying to print operand %s with None value"%(str(self)) 326 if self.type.IsFloat(): 327 return 
"{%s}"%(GetJointStr(self.value, method=PrettyPrintAsFloat)) 328 elif self.type.IsBool(): 329 return "{%s}"%(GetJointStr(self.value, method=lambda v: "true" if v else "false")) 330 else: 331 return "{%s}"%(GetJointStr(self.value, method=lambda x: str(int(x)))) 332 333 def ToUnspecifiedDim(self): 334 self.dimensions = self.type.dimensions 335 self.type = self.type.ToUnspecifiedDim() 336 337# Base class of user-defined input/output operand 338class InOut(Operand): 339 340 def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None): 341 Operand.__init__(self, name, opType, backward, None, skipRenaming=skipRenaming, extraParams=extraParams) 342 self.lifetime = "MODEL_INPUT" 343 self.index = 0 344 345 def Feed(self, value): 346 self.SetValue(value[self] if type(value) is dict else value) 347 return self 348 349 def GetListInitialization(self): 350 return "{%d, %s}"%(self.index, super().GetListInitialization()) 351 352# A user-declared input operand 353class Input(InOut): 354 def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None): 355 InOut.__init__(self, name, opType, backward, skipRenaming=skipRenaming, extraParams=extraParams) 356 self.lifetime = "MODEL_INPUT" 357 358# A user-declared output operand 359class Output(InOut): 360 def __init__(self, name, opType, backward=None, skipRenaming=False): 361 InOut.__init__(self, name, opType, backward, skipRenaming=skipRenaming) 362 self.lifetime = "MODEL_OUTPUT" 363 364# An output that we don't want to compare the results 365class IgnoredOutput(Output): 366 def __init__(self, name, opType, backward=None, skipRenaming=False): 367 Output.__init__(self, name, opType, backward, skipRenaming=skipRenaming) 368 self.lifetime = "MODEL_OUTPUT" 369 def Feed(self, value): 370 numElements = reduce(lambda x,y: x*y, self.dimensions, 1) 371 self.value = [0 for x in range(numElements)] 372 return self 373 374# An explicitly declared parameter 375class Parameter(Operand): 376 def 
__init__(self, name, opType, value, backward=None, skipRenaming=False, extraParams=None): 377 Operand.__init__(self, name, opType, value, backward, skipRenaming=skipRenaming, 378 extraParams=extraParams) 379 self.initializer = NamedVariable(str(self) + "_init") 380 self.lifetime = "CONSTANT_REFERENCE" if Configuration.useSHM() else "CONSTANT_COPY" 381 382# A shortcut for parameters of INT32 383class Int32Scalar(Parameter, ImplicitParameter): 384 def __init__(self, name, value): 385 Parameter.__init__(self, name, ("INT32", []), int(value)) 386 @staticmethod 387 def IsCompatible(value): 388 return type(value) is int 389 390# A shortcut for parameters of FLOAT16 391class Float16Scalar(Parameter, ImplicitParameter): 392 def __init__(self, name, value): 393 Parameter.__init__(self, name, ("FLOAT16", []), float(value)) 394 @staticmethod 395 def IsCompatible(value): 396 return False 397 398# A shortcut for parameters of FLOAT32 399class Float32Scalar(Parameter, ImplicitParameter): 400 def __init__(self, name, value): 401 Parameter.__init__(self, name, ("FLOAT32", []), float(value)) 402 @staticmethod 403 def IsCompatible(value): 404 return type(value) is float 405 406# A shortcut for parameters of BOOL 407class BoolScalar(Parameter, ImplicitParameter): 408 def __init__(self, name, value): 409 Parameter.__init__(self, name, ("BOOL", []), bool(value)) 410 @staticmethod 411 def IsCompatible(value): 412 return type(value) is bool 413 414# A shortcut for parameter of 1-D TENSOR_INT32 415class Int32Vector(Parameter, ImplicitParameter): 416 def __init__(self, name, value): 417 Parameter.__init__(self, name, ("TENSOR_INT32", [len(value)]), [int(v) for v in value]) 418 @staticmethod 419 def IsCompatible(value): 420 if type(value) is not list and type(value) is not tuple: 421 return False 422 return all(type(i) is int for i in value) 423 424# A shortcut for parameter of 1-D TENSOR_FLOAT32 425class Float32Vector(Parameter, ImplicitParameter): 426 def __init__(self, name, value): 427 
Parameter.__init__(self, name, ("TENSOR_FLOAT32", [len(value)]), [float(v) for v in value]) 428 @staticmethod 429 def IsCompatible(value): 430 if type(value) is not list and type(value) is not tuple: 431 return False 432 return all(type(i) is float for i in value) 433 434# An explicitly declared intermediate result 435class Internal(Operand): 436 def __init__(self, name, opType, backward=None, skipRenaming=False): 437 Operand.__init__(self, name, opType, backward, None, skipRenaming=skipRenaming) 438 self.lifetime = "TEMPORARY_VARIABLE" 439 440# An operation in a model, does not need a name 441class Operation: 442 443 def __init__(self, optype, ins, outs): 444 self.optype = optype 445 self.SetInputs(ins) 446 self.SetOutputs(outs) 447 448 # for the ease of debugging 449 def __str__(self): 450 insString = GetJointStr(self.ins) 451 outsString = GetJointStr(self.outs) 452 return "Operation %s: [%s] -> [%s]"%(self.optype, insString, outsString) 453 __repr__ = __str__ 454 455 def SetInputs(self, ins): 456 self.ins = [ImplicitParameter.ImplicitConvertion(i) for i in ins] 457 return self 458 459 def SetOutputs(self, outs): 460 self.outs = list(outs) 461 return self 462 463 # For backward-compatibility with slicing.py 464 # Get Python-ish dump for the op 465 def PyDefinition(self): 466 py_op_string = """Operation("{optype}", {inputs}).To({outputs})""" 467 inputs = [str(x) for x in self.ins] 468 inputs = ", ".join(inputs) 469 assert len(self.outs) <= 1 470 outputs = str(self.outs[0]) 471 ops = {"optype": self.optype, "inputs": inputs, "outputs": outputs} 472 return py_op_string.format(**ops) 473 474# Main interface 475class Model: 476 models = list() 477 478 def __init__(self, name=None): 479 self.name = name 480 self.operations = [] 481 self.operands = [] 482 self.isRelaxed = False 483 self.compiled = False 484 self.dumped = False 485 self.hasDynamicOutputShape = False 486 self.version = FileNames.version 487 Model.models.append(self) 488 489 def WithSuffix(self, *args): 
490 self.createFunctionName = GlobalVariable("CreateModel", self.name, *args) 491 self.createTestFunctionName = GlobalVariable("createTestModel", self.name, *args) 492 self.isIgnoredFunctionName = GlobalVariable("is_ignored", self.name, *args) 493 return self 494 495 def AddOperation(self, operation): 496 self.operations.append(operation) 497 for i in operation.ins: 498 if i not in self.operands: 499 self.operands.append(i) 500 for o in operation.outs: 501 if o not in self.operands: 502 self.operands.append(o) 503 return self 504 505 def Operation(self, op_name, *args): 506 return self.AddOperation(Operation(op_name, args, [])) 507 508 def To(self, *args): 509 assert len(self.operations) > 0 510 if type(args[0]) is tuple or type(args[0]) is list: 511 outs = args[0] 512 else: 513 outs = args 514 self.operations[-1].SetOutputs(outs) 515 for o in outs: 516 if o not in self.operands: 517 self.operands.append(o) 518 return self 519 520 def RelaxedExecution(self, isRelaxed): 521 self.isRelaxed = isRelaxed 522 return self 523 524 def TestDynamicOutputShape(self, hasDynamicOutputShape): 525 self.hasDynamicOutputShape = hasDynamicOutputShape 526 return self 527 528 # Sets the version of the model in compliance tests. Set to None to disable the test. 
529 def IntroducedIn(self, ver): 530 self.version = ver 531 return self 532 533 def GetTypes(self): 534 return sorted(list(set(op.type for op in self.operands))) 535 536 def GetInputs(self): 537 return [i for i in self.operands if isinstance(i, Input)] 538 539 def GetOutputs(self): 540 return [o for o in self.operands if isinstance(o, Output)] 541 542 def GetInputsIndex(self): 543 return [i for i,op in enumerate(self.operands) if isinstance(op, Input)] 544 545 def GetOutputsIndex(self): 546 return [o for o,op in enumerate(self.operands) if isinstance(op, Output)] 547 548 def GetIndexOfOperands(self, operands): 549 return [self.operands.index(i) for i in operands] 550 551 def GetIgnoredOutputs(self): 552 return [o for o in self.operands if isinstance(o, IgnoredOutput)] 553 554 def GetParameters(self): 555 return [p for p in self.operands if isinstance(p, Parameter)] 556 557 def GetEquivalentOperands(self, targets): 558 return [self.operands[self.operands.index(t)] for t in targets] 559 560 def UpdateEquivalentOperands(self, targets): 561 for t in targets: 562 self.operands[self.operands.index(t)] = t 563 return self 564 565 def SetInputAndOutputIndex(self): 566 for ind, i in enumerate(self.GetInputs()): 567 i.index = ind 568 for ind, o in enumerate(self.GetOutputs()): 569 o.index = ind 570 return self 571 572 def SetOperandInsAndOuts(self): 573 for op in self.operands: 574 op.ins = list() 575 op.outs = list() 576 for op in self.operations: 577 op.ins = self.GetEquivalentOperands(op.ins) 578 op.outs = self.GetEquivalentOperands(op.outs) 579 for i in op.ins: 580 i.outs.append(op) 581 for o in op.outs: 582 o.ins.append(op) 583 return self 584 585 def TopologicalSortHelper(self, op, deps, visited): 586 if op in visited: 587 assert op not in deps, "Cycle detected in the graph" 588 else: 589 visited.add(op) 590 for i in deps[op]: 591 self.TopologicalSortHelper(i, deps, visited) 592 self.operations.append(op) 593 deps.pop(op) 594 595 # Topological sort of the operations, 
and detect if there is a cycle is the graph 596 def TopologicalSort(self): 597 deps = {op: list() for op in self.operations} 598 [deps[o].append(i) for op in self.operands for o in op.outs for i in op.ins] 599 operations = self.operations.copy() 600 self.operations = [] 601 visited = set() 602 for op in operations: 603 self.TopologicalSortHelper(op, deps, visited) 604 605 def SetOutputUnspecified(self): 606 for op in self.operands: 607 op.dimensions = op.type.dimensions 608 if self.hasDynamicOutputShape: 609 for op in self.GetOutputs(): 610 op.ToUnspecifiedDim() 611 return self 612 613 def Compile(self): 614 if self.compiled: 615 return self 616 self.SetInputAndOutputIndex() 617 self.SetOperandInsAndOuts() 618 self.TopologicalSort() 619 self.SetOutputUnspecified() 620 # Do not check compliance for relaxed mode and dynamic output shape tests. 621 if self.isRelaxed or self.hasDynamicOutputShape: 622 self.IntroducedIn(None) 623 self.compiled = True 624 return self 625 626# To track implicitly convertible variation types 627class ImplicitVariation: 628 @staticmethod 629 def ImplicitConvertion(value): 630 if isinstance(value, ModelVariation): 631 return value 632 for implicitType in ImplicitVariation.__subclasses__(): 633 value = value if type(value) is tuple or type(value) is list else [value] 634 if implicitType.IsCompatible(value[0]): 635 var = implicitType(value[0]) 636 if len(value) > 1: 637 var.Identify(*value[1:]) 638 return var 639 assert False, "%s not supported for implicit variation"%value[0] 640 641# The base class for model variations 642class ModelVariation: 643 644 def __init__(self, name=None): 645 self.targetOperands = {} 646 self.name = name 647 648 def ApplyToHelper(self, model, args, feedDicts, transform): 649 opVarList = [] 650 for op in model.GetEquivalentOperands(sorted(args.keys())): 651 opVar = op 652 feedDictsVar = [] 653 if isinstance(op, Input) or isinstance(op, Output): 654 for feedDict in feedDicts: 655 op_tmp = copy.deepcopy(op) 656 if 
op_tmp in feedDict[0]: 657 opVar = transform(op_tmp.Feed(feedDict[0]), args[op_tmp]) 658 elif op_tmp in feedDict[1]: 659 opVar = transform(op_tmp.Feed(feedDict[1]), args[op_tmp]) 660 else: 661 assert False 662 feedDictsVar.append(opVar.value) 663 assert type(op) == type(opVar), "Can not handle %s -> %s"%(type(op), type(opVar)) 664 else: 665 opVar = transform(op, args[op]) 666 # handle Parameter -> Input 667 if isinstance(opVar, Input) or isinstance(opVar, Output): 668 feedDictsVar = [opVar.value] * len(feedDicts) 669 if isinstance(opVar, Input) or isinstance(opVar, Output): 670 for feedDict, feedDictVar in zip(feedDicts, feedDictsVar): 671 if opVar in feedDict[1]: 672 feedDict[1][opVar] = feedDictVar 673 else: 674 feedDict[0][opVar] = feedDictVar 675 opVarList.append(opVar) 676 return opVarList 677 678 # Make a deepcopy of the model and feedDicts, and apply the change 679 def ApplyTo(self, modelOrigin, feedDictsOrigin): 680 model, feedDicts = copy.deepcopy((modelOrigin, feedDictsOrigin)) 681 model.compiled = False 682 model.dumped = False 683 684 if not self.targetOperands: 685 self.AutoIdentify(model) 686 687 # get transformed operands and update feedDicts 688 operandsVar = self.ApplyToHelper( 689 model, self.targetOperands, feedDicts, self.TransformOperand) 690 691 model = self.TransformModel(model) 692 model.UpdateEquivalentOperands(operandsVar) 693 return model, feedDicts 694 695 def IdentifyOperands(self, args=None): 696 if args is None: 697 return self 698 self.targetOperands = args if type(args) is dict else {i: None for i in args} 699 return self 700 701 def Identify(self, operandArgs=None, paramArgs=None): 702 self.IdentifyOperands(operandArgs) 703 return self 704 705 # Set variation to its default name 706 def SetToDefaultName(self): 707 self.name = "" 708 return self 709 710 # Automatically select the target operand list 711 def AutoIdentify(self, model): 712 return self 713 714 # Transform operands that are marked by IdentifyOperands() 715 def 
TransformOperand(self, op, arg=None): 716 return op 717 718 # Transform the model 719 def TransformModel(self, model): 720 return model 721 722# Default variation that does nothing 723class DefaultVariation(ModelVariation): 724 725 def __init__(self, name=None): 726 ModelVariation.__init__(self, name=name) 727 728# Convert operand data type 729class DataTypeConverter(ModelVariation, ImplicitVariation): 730 731 def __init__(self, targetType=None, name=None): 732 ModelVariation.__init__(self, name=name) 733 if targetType is not None: 734 assert DataTypeConverter.IsCompatible(targetType) 735 self.targetType = targetType 736 737 @staticmethod 738 def IsCompatible(value): 739 return value.lower() in ["float16", "int32"] 740 741 def SetToDefaultName(self): 742 if self.targetType is not None: 743 self.name = self.targetType.lower() 744 return self 745 # get all target types 746 targetTypes = list(zip(*self.targetOperands.values()))[0] 747 if "TENSOR_QUANT8_SYMM_PER_CHANNEL" in targetTypes: 748 self.name = "channelQuant8" 749 elif "TENSOR_QUANT8_ASYMM" in targetTypes: 750 self.name = "quant8" 751 elif "TENSOR_INT32" in targetTypes: 752 self.name = "int32" 753 elif "TENSOR_FLOAT16" in targetTypes: 754 self.name = "float16" 755 else: 756 self.name = "float32" 757 return self 758 759 def AutoIdentify(self, model): 760 if self.targetType is not None: 761 # By default, select all the float32 tensors/scalars 762 targets = {op: ["TENSOR_" + self.targetType.upper()] \ 763 for op in model.operands if op.type.type == "TENSOR_FLOAT32"} 764 targets.update({op: [self.targetType.upper()] \ 765 for op in model.operands if op.type.type == "FLOAT32"}) 766 self.Identify(targets) 767 return self 768 769 def TransformOperand(self, op, arg=None): 770 if len(arg) == 1: 771 typeTuple = (arg[0], op.type.dimensions) 772 else: 773 typeTuple = (arg[0], op.type.dimensions, *arg[1:]) 774 # To handle Internal operands 775 if op.value is None or op.type.GetNumberOfElements() == 0: 776 op.type = 
Type.GetType(*typeTuple) 777 else: 778 v = Dequantize(op.GetValueAsNumpy().astype(np.float32), op.type) 779 op.type = Type.GetType(*typeTuple) 780 v = Quantize(v, op.type) 781 op.SetValueFromNumpy(v) 782 return op 783 784# Convert model to turn on/off relaxed computation 785class RelaxedModeConverter(ModelVariation, ImplicitVariation): 786 787 def __init__(self, isRelaxed=True, name=None): 788 ModelVariation.__init__(self, name=name) 789 if isinstance(isRelaxed, bool): 790 self.isRelaxed = isRelaxed 791 else: 792 assert RelaxedModeConverter.IsCompatible(isRelaxed.lower()) 793 self.isRelaxed = True 794 795 @staticmethod 796 def IsCompatible(value): 797 return value.lower() in ["relaxed"] 798 799 def SetToDefaultName(self): 800 self.name = "relaxed" if self.isRelaxed else "float" 801 return self 802 803 def TransformModel(self, model): 804 model.RelaxedExecution(self.isRelaxed) 805 return model 806 807# Convert data layout between "NHWC" amd "NCHW" 808class DataLayoutConverter(ModelVariation, ImplicitVariation): 809 810 def __init__(self, targetLayout="nchw", name=None): 811 ModelVariation.__init__(self, name=name) 812 self.targetLayout = targetLayout.lower() 813 assert DataLayoutConverter.IsCompatible(self.targetLayout) 814 self.perm = (0, 3, 1, 2) if self.targetLayout == "nchw" else (0, 2, 3, 1) 815 self.param = True if self.targetLayout == "nchw" else False 816 817 @staticmethod 818 def IsCompatible(value): 819 return value.lower() in ["nhwc", "nchw"] 820 821 def SetToDefaultName(self): 822 self.name = self.targetLayout 823 return self 824 825 def TransformOperand(self, op, arg=None): 826 if len(op.type.dimensions) == 4: 827 # To handle Internal operands 828 if op.value is not None and op.type.GetNumberOfElements() != 0: 829 op.SetValueFromNumpy(op.GetValueAsNumpy().transpose(self.perm)) 830 newDim = [op.type.dimensions[i] for i in self.perm] 831 op.type = Type.GetType(op.type.type, newDim, op.type.scale, op.type.zeroPoint) 832 elif len(op.type.dimensions) == 1 
and len(op.value) == 4: 833 op.SetValueFromNumpy(op.GetValueAsNumpy()[list(self.perm)]) 834 elif op.type.type == "BOOL": 835 op.SetValue(self.param) 836 else: 837 assert False, "%s not supported by DataLayoutConverter"%op 838 return op 839 840# Convert data by tansposing and removing axis 841class AxisConverter(ModelVariation): 842 843 def __init__(self, origin, target, dim, drop=[], name=None): 844 ModelVariation.__init__(self, name=name) 845 self.origin = origin 846 self.target = target 847 assert all(i >= -dim and i < dim for i in [self.origin, self.target]) 848 self.dim = dim 849 self.perm = list(range(dim)) 850 self.perm.insert(target if target >= 0 else target + dim, self.perm.pop(origin)) 851 self.drop = [drop] if type(drop) is int else list(drop) 852 assert all(i >= -dim and i < dim for i in self.drop) 853 self.drop = [i if i >= 0 else i + dim for i in self.drop] 854 assert target not in self.drop and target + dim not in self.drop 855 856 def SetToDefaultName(self): 857 axis = self.target if self.target >= 0 else self.target + self.dim 858 axis -= sum(i < axis for i in self.drop) 859 neg = "" if self.target >= 0 else "_neg" 860 self.name = "dim%d_axis%d%s"%(self.dim - len(self.drop), axis, neg) 861 return self 862 863 def TransposeAxis(self, op): 864 if op.type.type == "INT32": 865 op.SetValue(self.target) 866 elif len(op.type.dimensions) == self.dim: 867 # To handle Internal operands 868 if op.value is not None: 869 op.SetValueFromNumpy(op.GetValueAsNumpy().transpose(self.perm)) 870 newDim = [op.type.dimensions[i] for i in self.perm] 871 op.type = Type.GetType(op.type.type, newDim, op.type.scale, op.type.zeroPoint) 872 else: 873 assert False, "%s not supported by AxisConverter"%op 874 return op 875 876 def RemoveAxis(self, op): 877 if op.type.type == "INT32": 878 if op.value[0] >= 0: 879 op.SetValue(op.value[0] - sum(i < op.value[0] for i in self.drop)) 880 else: 881 op.SetValue(op.value[0] + sum(i > (op.value[0] + self.dim) for i in self.drop)) 882 elif 
len(op.type.dimensions) == self.dim: 883 if op.value is not None: 884 val = op.GetValueAsNumpy() 885 for i in sorted(self.drop, reverse=True): 886 val = np.take(val, 0, axis=i) 887 op.SetValueFromNumpy(val) 888 newDim = [op.type.dimensions[i] for i in range(self.dim) if i not in self.drop] 889 op.type = Type.GetType(op.type.type, newDim, op.type.scale, op.type.zeroPoint) 890 else: 891 assert False, "%s not supported by AxisConverter"%op 892 return op 893 894 def TransformOperand(self, op, arg=None): 895 op = self.TransposeAxis(op) 896 op = self.RemoveAxis(op) 897 return op 898 899# Convert a Parameter to Input 900class ParameterAsInputConverter(ModelVariation, ImplicitVariation): 901 902 def __init__(self, arg="as_input", prefix="weight", name=None): 903 ModelVariation.__init__(self, name=name) 904 assert ParameterAsInputConverter.IsCompatible(arg.lower()) 905 self.prefix = prefix 906 907 @staticmethod 908 def IsCompatible(value): 909 return value.lower() in ["as_input"] 910 911 def SetToDefaultName(self): 912 self.name = self.prefix + "_as_input" 913 return self 914 915 def TransformOperand(self, op, arg=None): 916 assert isinstance(op, Parameter), "%s cannot be converted to Input."%type(op) 917 newop = Input(op.name, op.type.GetSignatureTuple(), skipRenaming=True, extraParams=op.type.extraParams) 918 newop.SetValue(op.value) 919 return newop 920 921# Convert Output based on activation 922class ActivationConverter(ModelVariation, ImplicitVariation): 923 # (Enum, low, high) 924 actMap = { 925 "none": (0, None, None), 926 "relu": (1, 0.0, None), 927 "relu1": (2, -1.0, 1.0), 928 "relu6": (3, 0.0, 6.0), 929 } 930 def __init__(self, act="relu", name=None): 931 ModelVariation.__init__(self, name=name) 932 self.act = act.lower() 933 assert ActivationConverter.IsCompatible(self.act) 934 self.enum = ActivationConverter.actMap[self.act][0] 935 self.low = ActivationConverter.actMap[self.act][1] 936 self.high = ActivationConverter.actMap[self.act][2] 937 938 @staticmethod 939 
def IsCompatible(value): 940 return value.lower() in ActivationConverter.actMap.keys() 941 942 def SetToDefaultName(self): 943 self.name = self.act 944 return self 945 946 def TransformOperand(self, op, arg=None): 947 if op.type.type == "INT32": # activation enum 948 return op.SetValue(self.enum) 949 else: 950 assert isinstance(op, Output) 951 v = op.GetValueAsNumpy() 952 if self.low is not None: 953 low = Quantize(self.low, op.type) 954 v = np.maximum(v, low) 955 if self.high is not None: 956 high = Quantize(self.high, op.type) 957 v = np.minimum(v, high) 958 return op.SetValueFromNumpy(v) 959 960class DynamicOutputShapeConverter(ModelVariation): 961 def __init__(self, name=None): 962 ModelVariation.__init__(self, name=name) 963 964 def SetToDefaultName(self): 965 self.name = "dynamic_output_shape" 966 return self 967 968 def TransformModel(self, model): 969 model.TestDynamicOutputShape(True) 970 return model 971 972# An example is always attached to a model, and could have multiple variations 973class Example: 974 examples = [] 975 versionOverrides = {} 976 977 def __init__(self, *args, model=None, name=None): 978 self.model = Model.models[-1] if model is None else model 979 self.name = name 980 self.expectedMultinomialDistributionTolerance = None 981 self.feedDicts = [] 982 for feedDict in args: 983 if type(feedDict) is tuple or type(feedDict) is list: 984 self.feedDicts.append(feedDict) 985 elif type(feedDict) is dict: 986 self.feedDicts.append(( 987 {i: feedDict[i] for i in self.model.GetInputs()}, 988 {o: feedDict[o] for o in self.model.GetOutputs()} 989 )) 990 else: 991 assert False 992 if Configuration.test_dynamic_output_shape: 993 self.variations = [[DefaultVariation(), DynamicOutputShapeConverter()]] 994 else: 995 self.variations = [] 996 Example.examples.append(self) 997 998 @staticmethod 999 def SetVersion(ver, *args): 1000 for name in args: 1001 Example.versionOverrides[name] = ver 1002 1003 # Main entrance of test generator 1004 @staticmethod 1005 
def DumpAllExamples(DumpModel=None, model_fd=None, 1006 DumpExample=None, example_fd=None, 1007 DumpTest=None, test_fd=None): 1008 Example.CombineAllExamples() 1009 for example in Example.examples: 1010 example.Dump(DumpModel, model_fd, DumpExample, example_fd, DumpTest, test_fd) 1011 1012 # Combine examples with the same model, same name, and same set of variations 1013 @staticmethod 1014 def CombineAllExamples(): 1015 modelMap = {} 1016 newExamples = [] 1017 for example in Example.examples: 1018 key = (example.model, example.name, tuple(tuple(e) for e in example.variations)) 1019 if key in modelMap: 1020 modelMap[key].Combine(example) 1021 else: 1022 modelMap[key] = example 1023 newExamples.append(example) 1024 Example.examples = newExamples 1025 1026 def AddVariations(self, *args, includeDefault=True, defaultName=None): 1027 self.variations.append([DefaultVariation(defaultName)] if includeDefault else []) 1028 self.variations[-1].extend(ImplicitVariation.ImplicitConvertion(i) for i in args) 1029 return self 1030 1031 def AddNchw(self, *args, includeDefault=True, defaultName="nhwc"): 1032 var = DataLayoutConverter("nchw").Identify(args) 1033 self.AddVariations(var, includeDefault=includeDefault, defaultName=defaultName) 1034 return self 1035 1036 def AddRelaxed(self, isRelaxed=True, includeDefault=True, defaultName=None): 1037 var = RelaxedModeConverter(isRelaxed) 1038 self.AddVariations(var, includeDefault=includeDefault, defaultName=defaultName) 1039 return self 1040 1041 def AddInput(self, *args, includeDefault=True, defaultName=None): 1042 var = ParameterAsInputConverter().Identify(args) 1043 self.AddVariations(var, includeDefault=includeDefault, defaultName=defaultName) 1044 return self 1045 1046 def AddRelu(self, *args, includeDefault=True, defaultName=None): 1047 var = ActivationConverter("relu").Identify(args) 1048 self.AddVariations(var, includeDefault=includeDefault, defaultName=defaultName) 1049 return self 1050 1051 def AddAllActivations(self, *args): 
1052 var = [ActivationConverter(i).Identify(args) 1053 for i in sorted(ActivationConverter.actMap.keys())] 1054 self.AddVariations(*var, includeDefault=False) 1055 return self 1056 1057 def GuessOriginalAxisAndDim(self, *args): 1058 origin = None 1059 dim = None 1060 for arg in args: 1061 if arg.type.type == "INT32": 1062 origin = arg.value[0] 1063 else: 1064 if dim is None: 1065 dim = len(arg.type.dimensions) 1066 else: 1067 assert dim == len(arg.type.dimensions) 1068 assert dim is not None 1069 origin = dim - 1 if origin is None else origin 1070 origin = origin + dim if origin < 0 else origin 1071 return origin, dim 1072 1073 def AddAxis(self, axis, *args, includeDefault=True, defaultName=None): 1074 origin, dim = self.GuessOriginalAxisAndDim(*args) 1075 axis = [axis] if type(axis) is int else list(axis) 1076 var = [AxisConverter(origin, a, dim).Identify(args) for a in axis] 1077 self.AddVariations(*var, includeDefault=includeDefault, defaultName=defaultName) 1078 return self 1079 1080 def AddAllPositiveAxis(self, *args): 1081 origin, dim = self.GuessOriginalAxisAndDim(*args) 1082 var = [AxisConverter(origin, a, dim).Identify(args) for a in range(dim)] 1083 self.AddVariations(*var, includeDefault=False) 1084 return self 1085 1086 def AddAllAxis(self, *args): 1087 origin, dim = self.GuessOriginalAxisAndDim(*args) 1088 var = [AxisConverter(origin, a, dim).Identify(args) for a in range(-dim, dim)] 1089 self.AddVariations(*var, includeDefault=False) 1090 return self 1091 1092 def AddDims(self, dims, *args, includeDefault=True, defaultName=None): 1093 origin, dim = self.GuessOriginalAxisAndDim(*args) 1094 dims = [dims] if type(dims) is int else list(dims) 1095 drop = list(range(dim)) 1096 drop.pop(origin) 1097 var = [AxisConverter(origin, origin, dim, drop[0:(dim-i)]).Identify(args) for i in dims] 1098 self.AddVariations(*var, includeDefault=includeDefault, defaultName=defaultName) 1099 return self 1100 1101 def AddAllDims(self, *args): 1102 origin, dim = 
self.GuessOriginalAxisAndDim(*args) 1103 drop = list(range(dim)) 1104 drop.pop(origin) 1105 var = [AxisConverter(origin, origin, dim, drop[0:i]).Identify(args) for i in range(dim)] 1106 self.AddVariations(*var, includeDefault=False) 1107 return self 1108 1109 def AddAllDimsAndPositiveAxis(self, *args): 1110 origin, dim = self.GuessOriginalAxisAndDim(*args) 1111 var = [AxisConverter(origin, j, dim, range(i)).Identify(args) \ 1112 for i in range(dim) for j in range(i, dim)] 1113 self.AddVariations(*var, includeDefault=False) 1114 return self 1115 1116 def AddAllDimsAndAxis(self, *args): 1117 origin, dim = self.GuessOriginalAxisAndDim(*args) 1118 var = [AxisConverter(origin, k, dim, range(i)).Identify(args) \ 1119 for i in range(dim) for j in range(i, dim) for k in [j, j - dim]] 1120 self.AddVariations(*var, includeDefault=False) 1121 return self 1122 1123 def Combine(self, other): 1124 assert self.model is other.model, "Only examples targetting the same model can be combined" 1125 assert tuple(self.variations) == tuple(other.variations), \ 1126 "Only examples with the same set of variations can be combined" 1127 assert self.name == other.name, "Only examples with the same name can be combined" 1128 self.feedDicts.extend(other.feedDicts) 1129 return self 1130 1131 def Dump(self, DumpModel, model_fd, DumpExample, example_fd, DumpTest, test_fd): 1132 [v.SetToDefaultName() for vs in self.variations for v in vs if v.name is None] 1133 for variationList in itertools.product(*self.variations): 1134 # Apply variations 1135 modelOrigin, feedDictsOrigin = self.model, self.feedDicts 1136 self.model, self.feedDicts = copy.deepcopy((self.model, self.feedDicts)) 1137 for variation in variationList: 1138 self.model, self.feedDicts = variation.ApplyTo(self.model, self.feedDicts) 1139 # Concat names for test and examples 1140 varNames = [v.name for v in variationList] 1141 self.testName = NamedTest(FileNames.specName, self.model.name, self.name, *varNames) 1142 self.examplesName = 
GlobalVariable("examples", self.model.name, self.name, *varNames) 1143 if str(self.testName) in Example.versionOverrides: 1144 self.model.IntroducedIn(Example.versionOverrides[str(self.testName)]) 1145 self.model.WithSuffix(*varNames).Compile() 1146 # Dump files 1147 if DumpModel is not None and model_fd is not None: 1148 DumpModel(self.model, model_fd) 1149 if DumpExample is not None and example_fd is not None: 1150 DumpExample(self, example_fd) 1151 if DumpTest is not None and test_fd is not None: 1152 DumpTest(self, test_fd) 1153 # Restore model and feedDicts before variation 1154 self.model = modelOrigin 1155 self.feedDicts = feedDictsOrigin 1156 return self 1157 1158 # Specifies the RANDOM_MULTINOMIAL distribution tolerance. 1159 # If set to greater than zero, the input is compared as log-probabilities 1160 # to the output and must be within this tolerance to pass. 1161 def WithMultinomialDistributionTolerance(self, expectedTolerance): 1162 self.expectedMultinomialDistributionTolerance = expectedTolerance 1163 return self 1164 1165 # For backward-compatibility with slicing.py 1166 # Similar to dump_dict, but in python. Used by the slicing tool 1167 # if referenced is not None, only print operands that are present there 1168 @staticmethod 1169 def py_dump_dict(d, referenced): 1170 ret = [] 1171 for k, v in d.items(): 1172 if referenced != None and k not in referenced: 1173 continue 1174 key = str(k) 1175 init = pprint.pformat(v) 1176 ret.append("%s: %s" % (key, init)) 1177 return ", ".join(ret) 1178 1179 # For backward-compatibility with slicing.py 1180 # similar to dump, but in python. 
Used by the slicing tool 1181 # if referenced is not None, only print operands that are present there 1182 @staticmethod 1183 def py_dump(example_file, override, referenced): 1184 Example.CombineAllExamples() 1185 if len(Example.examples[0].feedDicts) > 0: 1186 example_no = 0 1187 example_template = """\ 1188input{no} = {{{inputs}}} 1189# Only executed during data collection phase 1190if collecting_data is True: 1191 Example((input{no}, {{{outputs}}})) 1192""" 1193 for i, o in Example.examples[0].feedDicts: 1194 print ('# Begin of an example', file = example_file) 1195 inputs = Example.py_dump_dict(i, referenced) 1196 output_list = [] 1197 for k, v in override.items(): 1198 output_list.append("%s: [0] * %d" % (k, v)) 1199 outputs = ",".join(output_list) 1200 1201 # TODO: handle >1 outputs 1202 for k, v in o.items(): 1203 assert k.index == 0 1204 example_contents = { 1205 'no': example_no, 1206 'inputs': inputs, 1207 'outputs': outputs 1208 } 1209 print (example_template.format(**example_contents), file = example_file) 1210 1211class FileNames: 1212 specFiles = [] 1213 specNames = [] 1214 modelFiles = [] 1215 exampleFiles = [] 1216 testFiles = [] 1217 specFile = "" 1218 specName = "" 1219 modelFile = "" 1220 exampleFile = "" 1221 testFile = "" 1222 ctsFile = "" 1223 logFile = "" 1224 version = "" 1225 fileIndex = 0 1226 1227 @staticmethod 1228 def InitializeFileLists(spec, model, example, test, cts="-", log=""): 1229 # get all spec files and target files 1230 if os.path.isfile(spec): 1231 FileNames.specFiles = [os.path.abspath(spec)] 1232 elif os.path.isdir(spec): 1233 FileNames.specFiles = sorted([os.path.abspath(os.path.join(spec, f)) 1234 for f in os.listdir(spec) if f.endswith(".mod.py")]) 1235 else: 1236 assert False, "%s is neither a file or a directory"%spec 1237 FileNames.specNames = [re.sub(r"\..*", "", os.path.basename(f)) 1238 for f in FileNames.specFiles] 1239 FileNames.modelFiles = FileNames.ParseTargetFiles(model, ".model.cpp") 1240 
FileNames.exampleFiles = FileNames.ParseTargetFiles(example, ".example.cpp") 1241 FileNames.testFiles = FileNames.ParseTargetFiles(test, ".mod.py.cpp") 1242 FileNames.ctsFile = os.path.abspath(cts) if cts != "-" else "-" 1243 FileNames.logFile = ", \"%s\""%log if log != "" else "" 1244 1245 @staticmethod 1246 def ParseTargetFiles(arg, ext): 1247 numFiles = len(FileNames.specFiles) 1248 absPath = os.path.abspath(arg) 1249 if os.path.isdir(arg): 1250 target = [os.path.join(absPath, f + ext) for f in FileNames.specNames] 1251 elif arg == "-": 1252 target = ["-"] * numFiles 1253 else: 1254 target = [absPath] * numFiles 1255 return target 1256 1257 @staticmethod 1258 def NextFile(): 1259 if FileNames.fileIndex >= len(FileNames.specFiles): 1260 return False 1261 FileNames.specFile = FileNames.specFiles[FileNames.fileIndex] 1262 FileNames.specName = FileNames.specNames[FileNames.fileIndex] 1263 FileNames.modelFile = FileNames.modelFiles[FileNames.fileIndex] 1264 FileNames.exampleFile = FileNames.exampleFiles[FileNames.fileIndex] 1265 FileNames.testFile = FileNames.testFiles[FileNames.fileIndex] 1266 FileNames.fileIndex += 1 1267 NamedObject.existingNames = set() 1268 NamedVariable.existingNames = set() 1269 NamedTest.existingNames = set() 1270 Type.typesMap = dict() 1271 Model.models = list() 1272 Example.examples = list() 1273 Configuration.use_shm_for_weights = False 1274 1275 # Extract version from absolute file path. 1276 versionMatch = re.findall(r"/V\d_\d/", FileNames.specFile) 1277 if len(versionMatch) == 1: 1278 FileNames.version = versionMatch[0].strip('/') 1279 else: 1280 FileNames.version = None 1281 return True 1282 1283class Configuration: 1284 use_shm_for_weights = False 1285 force_regenerate = False 1286 test_dynamic_output_shape = True 1287 1288 @staticmethod 1289 def useSHM(): 1290 return Configuration.use_shm_for_weights 1291