"""Generic visitor pattern implementation for Python objects."""

import enum


class Visitor(object):
    """Base class for visitors that walk arbitrary Python object graphs.

    Do not register functions on this class directly; subclass it and attach
    visit functions to the subclass with the register(), register_attr() or
    register_attrs() class-method decorators. Registered functions are looked
    up by the exact type of the visited object, searching the registries of
    the visitor class and its bases in method-resolution order.
    """

    # If True, a registered visit function that returns None (or nothing)
    # stops the visitor from descending further into that object/attribute.
    defaultStop = False

    @classmethod
    def _register(celf, clazzes_attrs):
        """Return a decorator that registers a function for (classes, attrs).

        clazzes_attrs is an iterable of (clazzes, attrs) pairs: clazzes is a
        class or a tuple of classes; attrs is an attribute name or an
        iterable of names, where the name None stands for the object-level
        visitor used by visit(). The decorated function must be named
        "visit"; the decorator returns None, so the decorated name is not
        bound to the function afterwards.
        """
        assert celf != Visitor, "Subclass Visitor instead."
        # Each subclass gets its own registry so it never writes into a base's.
        if "_visitors" not in celf.__dict__:
            celf._visitors = {}

        def wrapper(method):
            assert method.__name__ == "visit"
            for clazzes, attrs in clazzes_attrs:
                if not isinstance(clazzes, tuple):
                    clazzes = (clazzes,)
                if isinstance(attrs, str):
                    attrs = (attrs,)
                else:
                    # Materialize so a one-shot iterable serves every class.
                    attrs = tuple(attrs)
                for clazz in clazzes:
                    _visitors = celf._visitors.setdefault(clazz, {})
                    for attr in attrs:
                        assert attr not in _visitors, (
                            "Oops, class '%s' has visitor function for '%s' defined already."
                            % (clazz.__name__, attr)
                        )
                        _visitors[attr] = method
            return None

        return wrapper

    @classmethod
    def register(celf, clazzes):
        """Decorator registering visit(visitor, obj, *args, **kwargs) for
        objects whose type is clazzes (a class or tuple of classes)."""
        return celf._register([(clazzes, (None,))])

    @classmethod
    def register_attr(celf, clazzes, attrs):
        """Decorator registering visit(visitor, obj, attr, value, *args,
        **kwargs) for attribute(s) attrs (a name or iterable of names) of
        objects whose type is clazzes. The name "*" is used as the fallback
        for attributes that have no visitor of their own."""
        return celf._register([(clazzes, attrs)])

    @classmethod
    def register_attrs(celf, clazzes_attrs):
        """Decorator like register_attr(), taking a list of (clazzes, attrs)
        pairs instead."""
        return celf._register(clazzes_attrs)

    @classmethod
    def _visitorsFor(celf, thing, _default=None):
        """Return the {attr: visit_function} registry for type(thing).

        Every class in celf's MRO is examined in order, consulting only the
        registry that class defines itself; classes without one (e.g. plain
        mixins) are skipped rather than ending the search. Returns _default,
        or a new empty dict when _default is not given, if no registry has
        an entry for the exact type of thing.
        """
        typ = type(thing)
        for klass in celf.mro():
            registry = klass.__dict__.get("_visitors")
            if registry is None:
                continue
            found = registry.get(typ, None)
            if found is not None:
                return found
        # A fresh dict rather than a shared mutable default, so a caller that
        # mutates the result cannot affect later lookups.
        return {} if _default is None else _default

    def visitObject(self, obj, *args, **kwargs):
        """Called to visit an object. This function loops over all non-private
        attributes of the object, in sorted order, and calls any user-registered
        (via @register_attr() or @register_attrs()) visit() functions.

        If there is no user-registered visit function, or if there is and it
        returns True, or it returns None (or doesn't return anything) and
        visitor.defaultStop is False (default), then the visitor will proceed
        to call self.visitAttr()"""

        # Snapshot the keys so visitors may mutate obj while we iterate.
        keys = sorted(vars(obj))
        _visitors = self._visitorsFor(obj)
        defaultVisitor = _visitors.get("*", None)
        for key in keys:
            # startswith() (unlike key[0]) also copes with an empty name.
            if key.startswith("_"):
                continue
            value = getattr(obj, key)
            visitorFunc = _visitors.get(key, defaultVisitor)
            if visitorFunc is not None:
                ret = visitorFunc(self, obj, key, value, *args, **kwargs)
                # "==" rather than "is": any value equal to False stops too.
                if ret == False or (ret is None and self.defaultStop):
                    continue
            self.visitAttr(obj, key, value, *args, **kwargs)

    def visitAttr(self, obj, attr, value, *args, **kwargs):
        """Called to visit an attribute of an object."""
        self.visit(value, *args, **kwargs)

    def visitList(self, obj, *args, **kwargs):
        """Called to visit any value that is a list."""
        for value in obj:
            self.visit(value, *args, **kwargs)

    def visitDict(self, obj, *args, **kwargs):
        """Called to visit any value that is a dictionary."""
        for value in obj.values():
            self.visit(value, *args, **kwargs)

    def visitLeaf(self, obj, *args, **kwargs):
        """Called to visit any value that is not an object, list,
        or dictionary."""
        pass

    def visit(self, obj, *args, **kwargs):
        """This is the main entry to the visitor. The visitor will visit object
        obj.

        The visitor will first determine if there is a registered (via
        @register()) visit function for the type of object. If there is, it
        will be called, and (visitor, obj, *args, **kwargs) will be passed to
        the user visit function.

        If there is no user-registered visit function, or if there is and it
        returns True, or it returns None (or doesn't return anything) and
        visitor.defaultStop is False (default), then the visitor will proceed
        to dispatch to one of self.visitObject(), self.visitList(),
        self.visitDict(), or self.visitLeaf() (any of which can be overridden
        in a subclass)."""

        visitorFunc = self._visitorsFor(obj).get(None, None)
        if visitorFunc is not None:
            ret = visitorFunc(self, obj, *args, **kwargs)
            if ret == False or (ret is None and self.defaultStop):
                return
        # Enum members carry a __dict__ but are treated as leaves.
        if hasattr(obj, "__dict__") and not isinstance(obj, enum.Enum):
            self.visitObject(obj, *args, **kwargs)
        elif isinstance(obj, list):
            self.visitList(obj, *args, **kwargs)
        elif isinstance(obj, dict):
            self.visitDict(obj, *args, **kwargs)
        else:
            self.visitLeaf(obj, *args, **kwargs)