from __future__ import absolute_import, division, unicode_literals
from six import text_type

from lxml import etree
from ..treebuilders.etree import tag_regexp

from . import _base

from .. import ihatexml


def ensure_str(s):
    """Return ``s`` as a text (unicode) string.

    ``None`` is passed through unchanged and text strings are returned
    as-is; any other value is decoded with its ``decode`` method as strict
    UTF-8 (so invalid byte sequences raise ``UnicodeDecodeError``).
    """
    if s is None:
        return None
    elif isinstance(s, text_type):
        return s
    else:
        return s.decode("utf-8", "strict")


class Root(object):
    """Synthetic document node placed above an lxml element tree.

    ``children`` holds, in order: a ``Doctype`` (only when the tree has an
    internal DTD), then every top-level node of the tree -- the siblings
    that precede the root element, the root element itself, and the
    siblings that follow it.
    """

    def __init__(self, et):
        self.elementtree = et
        self.children = []
        if et.docinfo.internalDTD:
            self.children.append(Doctype(self,
                                         ensure_str(et.docinfo.root_name),
                                         ensure_str(et.docinfo.public_id),
                                         ensure_str(et.docinfo.system_url)))

        # Rewind to the first top-level node, then collect every top-level
        # node going forwards.
        node = et.getroot()
        while node.getprevious() is not None:
            node = node.getprevious()
        while node is not None:
            self.children.append(node)
            node = node.getnext()

        # The document node itself never carries text.
        self.text = None
        self.tail = None

    def __getitem__(self, key):
        return self.children[key]

    def getnext(self):
        # The document node has no siblings.
        return None

    def __len__(self):
        # Constant, regardless of len(self.children); within this module it
        # is only consulted as a "has children" truth test.
        return 1


class Doctype(object):
    """Stand-in node for the document type declaration.

    Stores the name, public id and system id handed to it, plus a
    reference to the owning ``Root`` so the following sibling can be found.
    """

    def __init__(self, root_node, name, public_id, system_id):
        self.root_node = root_node
        self.name = name
        self.public_id = public_id
        self.system_id = system_id

        self.text = None
        self.tail = None

    def getnext(self):
        # Root appends the Doctype before anything else, so it sits at
        # index 0 and its next sibling is index 1.
        return self.root_node.children[1]


class FragmentRoot(Root):
    """Root node for a fragment given as a list of top-level items.

    Each item is wrapped in a ``FragmentWrapper``.  ``Root.__init__`` is
    deliberately not called: there is no element tree to inspect.
    """

    def __init__(self, children):
        self.children = [FragmentWrapper(self, child) for child in children]
        self.text = self.tail = None

    def getnext(self):
        return None


class FragmentWrapper(object):
    """Proxy around one top-level item of a fragment list.

    Sibling navigation (``getnext``) follows the fragment list held by the
    ``FragmentRoot`` and ``getparent`` always yields ``None``.  Attributes
    not defined here are forwarded to the wrapped object.
    """

    def __init__(self, fragment_root, obj):
        self.root_node = fragment_root
        self.obj = obj
        # Snapshot text/tail as text strings; wrapped objects lacking these
        # attributes get None.
        if hasattr(self.obj, 'text'):
            self.text = ensure_str(self.obj.text)
        else:
            self.text = None
        if hasattr(self.obj, 'tail'):
            self.tail = ensure_str(self.obj.tail)
        else:
            self.tail = None

    def __getattr__(self, name):
        # Only reached for attributes not found on the wrapper itself.
        return getattr(self.obj, name)

    def getnext(self):
        """Return the next wrapper in the fragment list, or None at the end."""
        siblings = self.root_node.children
        idx = siblings.index(self)
        if idx < len(siblings) - 1:
            return siblings[idx + 1]
        else:
            return None

    def __getitem__(self, key):
        return self.obj[key]

    def __bool__(self):
        return bool(self.obj)

    def getparent(self):
        return None

    def __str__(self):
        return str(self.obj)

    def __unicode__(self):
        return str(self.obj)

    def __len__(self):
        return len(self.obj)


class TreeWalker(_base.NonRecursiveTreeWalker):
    """Tree walker over lxml trees and fragment lists.

    Text is not a node in its own right here, so text positions are
    represented as ``(node, "text")`` / ``(node, "tail")`` tuples.
    """

    def __init__(self, tree):
        """Wrap ``tree`` as needed and initialise the base walker.

        Objects exposing ``getroot`` are wrapped in ``Root``; lists are
        wrapped in ``FragmentRoot``; anything else is passed on unchanged.
        """
        if hasattr(tree, "getroot"):
            tree = Root(tree)
        elif isinstance(tree, list):
            tree = FragmentRoot(tree)
        _base.NonRecursiveTreeWalker.__init__(self, tree)
        self.filter = ihatexml.InfosetFilter()

    def getNodeDetails(self, node):
        """Return a tuple describing ``node``.

        The first item is one of the ``_base`` node-type constants; the
        remaining items depend on the type (text data, doctype ids, or
        namespace / name / attributes / has-children for elements).
        """
        if isinstance(node, tuple):  # Text node
            node, key = node
            assert key in ("text", "tail"), "Text nodes are text or tail, found %s" % key
            return _base.TEXT, ensure_str(getattr(node, key))

        elif isinstance(node, Root):
            return (_base.DOCUMENT,)

        elif isinstance(node, Doctype):
            return _base.DOCTYPE, node.name, node.public_id, node.system_id

        elif isinstance(node, FragmentWrapper) and not hasattr(node, "tag"):
            # A fragment item without a tag is treated as text.  Normalise
            # it like every other text value so that a UTF-8 byte string in
            # the fragment list does not leak out as bytes.
            return _base.TEXT, ensure_str(node.obj)

        elif node.tag == etree.Comment:
            return _base.COMMENT, ensure_str(node.text)

        elif node.tag == etree.Entity:
            return _base.ENTITY, ensure_str(node.text)[1:-1]  # strip &;

        else:
            # This is assumed to be an ordinary element
            # NOTE(review): tag_regexp presumably splits "{namespace}name"
            # into its two parts -- confirm against treebuilders.etree.
            match = tag_regexp.match(ensure_str(node.tag))
            if match:
                namespace, tag = match.groups()
            else:
                namespace = None
                tag = ensure_str(node.tag)
            attrs = {}
            for name, value in node.attrib.items():
                name = ensure_str(name)
                value = ensure_str(value)
                match = tag_regexp.match(name)
                if match:
                    attrs[(match.group(1), match.group(2))] = value
                else:
                    attrs[(None, name)] = value
            return (_base.ELEMENT, namespace, self.filter.fromXmlName(tag),
                    attrs, len(node) > 0 or node.text)

    def getFirstChild(self, node):
        """Return the first child position of ``node``.

        Leading text comes first as ``(node, "text")``; otherwise the first
        child node is returned.
        """
        assert not isinstance(node, tuple), "Text nodes have no children"

        assert len(node) or node.text, "Node has no children"
        if node.text:
            return (node, "text")
        else:
            return node[0]

    def getNextSibling(self, node):
        """Return the position following ``node`` at the same level, or None."""
        if isinstance(node, tuple):  # Text node
            node, key = node
            assert key in ("text", "tail"), "Text nodes are text or tail, found %s" % key
            if key == "text":
                # XXX: we cannot use a "bool(node) and node[0] or None" construct here
                # because node[0] might evaluate to False if it has no child element
                if len(node):
                    return node[0]
                else:
                    return None
            else:  # tail
                return node.getnext()

        # A node's tail text, when present, comes before its next sibling.
        return (node, "tail") if node.tail else node.getnext()

    def getParentNode(self, node):
        """Return the parent of ``node``.

        A ``(node, "text")`` position is a child of ``node`` itself; a
        ``(node, "tail")`` position shares ``node``'s parent.
        """
        if isinstance(node, tuple):  # Text node
            node, key = node
            assert key in ("text", "tail"), "Text nodes are text or tail, found %s" % key
            if key == "text":
                return node
            # else: fallback to "normal" processing

        return node.getparent()