• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Generic visitor pattern implementation for Python objects."""
2
3import enum
4
5
class Visitor(object):
    """Base class for visitors that walk arbitrary Python object graphs.

    Subclass Visitor (registering on Visitor itself is rejected) and attach
    visit functions with the class-level decorators register(),
    register_attr() and register_attrs(). Each subclass keeps its own registry
    in its class ``__dict__`` under ``_visitors``, mapping a class to a dict
    of ``{attr: visit_function}``. The attr key ``None`` holds the
    whole-object visitor used by visit(), and ``"*"`` holds the fallback
    attribute visitor used by visitObject().
    """

    # When a user visit function returns None, stop descending iff this is true.
    defaultStop = False

    @classmethod
    def _register(celf, clazzes_attrs):
        """Return a decorator that registers a ``visit`` function on celf.

        clazzes_attrs is an iterable of ``(clazzes, attrs)`` pairs, where
        clazzes is a class or a tuple of classes and attrs is an attribute
        name or a tuple of names (``None`` = whole-object visitor, ``"*"`` =
        fallback for any attribute). The decorated function must be named
        ``visit``. The decorator returns None, so the decorated name ends up
        bound to None rather than to the function.
        """
        assert celf != Visitor, "Subclass Visitor instead."
        if "_visitors" not in celf.__dict__:
            # Give this subclass its own registry rather than mutating one
            # inherited from a parent visitor class.
            celf._visitors = {}

        def wrapper(method):
            assert method.__name__ == "visit"
            for clazzes, attrs in clazzes_attrs:
                if type(clazzes) != tuple:
                    clazzes = (clazzes,)
                if type(attrs) == str:
                    attrs = (attrs,)
                for clazz in clazzes:
                    _visitors = celf._visitors.setdefault(clazz, {})
                    for attr in attrs:
                        assert attr not in _visitors, (
                            "Oops, class '%s' has visitor function for '%s' defined already."
                            % (clazz.__name__, attr)
                        )
                        _visitors[attr] = method
            return None

        return wrapper

    @classmethod
    def register(celf, clazzes):
        """Decorator registering a whole-object visitor for clazzes.

        The function is called as ``visit(visitor, obj, *args, **kwargs)``.
        """
        if type(clazzes) != tuple:
            clazzes = (clazzes,)
        return celf._register([(clazzes, (None,))])

    @classmethod
    def register_attr(celf, clazzes, attrs):
        """Decorator registering an attribute visitor for attrs of clazzes.

        The function is called as
        ``visit(visitor, obj, attr, value, *args, **kwargs)``.
        """
        if type(clazzes) != tuple:
            clazzes = (clazzes,)
        if type(attrs) == str:
            attrs = (attrs,)
        return celf._register([(clazz, attrs) for clazz in clazzes])

    @classmethod
    def register_attrs(celf, clazzes_attrs):
        """Decorator registering one visitor for several (clazzes, attrs) pairs."""
        return celf._register(clazzes_attrs)

    @classmethod
    def _visitorsFor(celf, thing, _default=None):
        """Return the ``{attr: visit_function}`` dict that applies to thing.

        An exact registration for ``type(thing)`` anywhere in the visitor's
        class hierarchy wins. Failing that, the base classes of
        ``type(thing)`` are tried in MRO order, so instances of subclasses of
        a registered class are dispatched too. Returns ``_default`` (a fresh
        empty dict when not given) if nothing matches.
        """
        # Read each class's own __dict__ (as _register writes it) instead of
        # getattr(): inherited registries are not re-checked, and a class
        # without a registry (e.g. a plain mixin) is skipped instead of
        # ending the search.
        registries = [
            klass.__dict__["_visitors"]
            for klass in celf.__mro__
            if "_visitors" in klass.__dict__
        ]
        # __mro__ rather than mro(): type.mro() needs an argument, which
        # would break when the visited thing is itself a class.
        for typ in type(thing).__mro__:
            for registry in registries:
                visitors = registry.get(typ)
                if visitors is not None:
                    return visitors
        return {} if _default is None else _default

    def visitObject(self, obj, *args, **kwargs):
        """Called to visit an object. This function loops, in sorted order,
        over all non-private attributes (names not starting with "_") of the
        object and calls any user-registered (via @register_attr() or
        @register_attrs()) visit() functions, falling back to one registered
        for "*".

        A user visit function's return value controls descent: a value equal
        to False (e.g. False or 0) skips the attribute; None skips it only
        when visitor.defaultStop is true; anything else (e.g. True) continues.
        If there is no user-registered visit function, or descent continues,
        the visitor calls self.visitAttr()."""

        # Snapshot the keys first so visit functions may modify obj's
        # attributes without disturbing the iteration.
        keys = sorted(vars(obj))
        _visitors = self._visitorsFor(obj)
        defaultVisitor = _visitors.get("*", None)
        for key in keys:
            # startswith() instead of key[0]: an empty attribute name
            # (possible via setattr) must not raise IndexError.
            if key.startswith("_"):
                continue
            value = getattr(obj, key)
            visitorFunc = _visitors.get(key, defaultVisitor)
            if visitorFunc is not None:
                ret = visitorFunc(self, obj, key, value, *args, **kwargs)
                # "== False" is deliberate: 0 also stops the descent.
                if ret == False or (ret is None and self.defaultStop):
                    continue
            self.visitAttr(obj, key, value, *args, **kwargs)

    def visitAttr(self, obj, attr, value, *args, **kwargs):
        """Called to visit an attribute of an object; visits its value."""
        self.visit(value, *args, **kwargs)

    def visitList(self, obj, *args, **kwargs):
        """Called to visit any value that is a list; visits each element."""
        for value in obj:
            self.visit(value, *args, **kwargs)

    def visitDict(self, obj, *args, **kwargs):
        """Called to visit any value that is a dictionary; visits each value
        (keys are not visited)."""
        for value in obj.values():
            self.visit(value, *args, **kwargs)

    def visitLeaf(self, obj, *args, **kwargs):
        """Called to visit any value that is not an object, list,
        or dictionary. Does nothing by default."""
        pass

    def visit(self, obj, *args, **kwargs):
        """This is the main entry to the visitor. The visitor will visit object
        obj.

        The visitor will first determine if there is a registered (via
        @register()) visit function for the type of the object, or failing
        that for one of its base classes. If there is, it will be called, and
        (visitor, obj, *args, **kwargs) will be passed to the user visit
        function.

        The user function's return value controls descent: a value equal to
        False (e.g. False or 0) stops it; None (or no return) stops it only
        when visitor.defaultStop is true; anything else (e.g. True) continues.
        If there is no user-registered visit function, or descent continues,
        the visitor dispatches to one of self.visitObject(), self.visitList(),
        self.visitDict(), or self.visitLeaf() (any of which can be overridden
        in a subclass). Enum members are treated as leaves even though they
        have a __dict__."""

        visitorFunc = self._visitorsFor(obj).get(None, None)
        if visitorFunc is not None:
            ret = visitorFunc(self, obj, *args, **kwargs)
            # "== False" is deliberate: 0 also stops the descent.
            if ret == False or (ret is None and self.defaultStop):
                return
        if hasattr(obj, "__dict__") and not isinstance(obj, enum.Enum):
            self.visitObject(obj, *args, **kwargs)
        elif isinstance(obj, list):
            self.visitList(obj, *args, **kwargs)
        elif isinstance(obj, dict):
            self.visitDict(obj, *args, **kwargs)
        else:
            self.visitLeaf(obj, *args, **kwargs)