# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Global configuration support."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum


# TODO(mdan): For better performance, allow each rule to take a set of names.


class Rule(object):
  """Base class for conversion rules.

  A rule is keyed on a module name prefix. It matches the module whose name
  equals the prefix, as well as every module nested under it (names that
  start with the prefix followed by a dot). Concrete rules decide which
  `Action` a module receives by overriding `get_action`.
  """

  def __init__(self, module_prefix):
    """Creates a rule.

    Args:
      module_prefix: str, the fully qualified module name this rule applies
        to, e.g. 'foo.bar'.
    """
    self._prefix = module_prefix

  def matches(self, module_name):
    """Returns whether `module_name` is the rule's prefix or nested under it.

    The dot separator is required so that, for example, a rule for 'foo'
    matches 'foo' and 'foo.bar' but not 'foobar'.

    Args:
      module_name: str, a fully qualified module name.

    Returns:
      bool
    """
    return (module_name.startswith(self._prefix + '.') or
            module_name == self._prefix)

  def get_action(self, module):
    """Returns the `Action` this rule prescribes for `module`.

    Args:
      module: a module object; the concrete rules in this file consult only
        its `__name__` attribute.

    Returns:
      An `Action`.

    Raises:
      NotImplementedError: always; concrete rules must override this method.
    """
    raise NotImplementedError(
        '{} must implement get_action'.format(type(self).__name__))


class Action(enum.Enum):
  """Outcome of applying a `Rule` to a module."""
  NONE = 0  # The rule does not match the module.
  CONVERT = 1  # The rule matches and requests conversion.
  DO_NOT_CONVERT = 2  # The rule matches and forbids conversion.


class DoNotConvert(Rule):
  """Indicates that this module should not be converted."""

  def __str__(self):
    return 'DoNotConvert rule for {}'.format(self._prefix)

  def get_action(self, module):
    """Returns DO_NOT_CONVERT if `module.__name__` matches, else NONE."""
    if self.matches(module.__name__):
      return Action.DO_NOT_CONVERT
    return Action.NONE


class Convert(Rule):
  """Indicates that this module should be converted."""

  def __str__(self):
    return 'Convert rule for {}'.format(self._prefix)

  def get_action(self, module):
    """Returns CONVERT if `module.__name__` matches, else NONE."""
    if self.matches(module.__name__):
      return Action.CONVERT
    return Action.NONE