"""Code generator for Code Completion Model Inference.

Tool runs on the Decision Forest model defined in {model} directory.
It generates two files: {output_dir}/{filename}.h and {output_dir}/{filename}.cpp
The generated files define the Example class named {cpp_class} having all the features as class members.
The generated runtime provides an `Evaluate` function which can be used to score a code completion candidate.
"""

import argparse
import json
import re
import struct


class CppClass:
    """Holds class name and names of the enclosing namespaces."""

    def __init__(self, cpp_class):
        """Splits a possibly namespace-qualified name such as `a::b::C`.

        Empty namespace components (e.g. produced by a leading `::`) are
        dropped. Raises ValueError if the class name itself is empty.
        """
        ns_and_class = cpp_class.split("::")
        self.ns = [ns for ns in ns_and_class[:-1] if ns]
        self.name = ns_and_class[-1]
        if not self.name:
            raise ValueError("Empty class name.")

    def ns_begin(self):
        """Returns snippet for opening namespace declarations."""
        return "\n".join("namespace %s {" % ns for ns in self.ns)

    def ns_end(self):
        """Returns snippet for closing namespace declarations."""
        # Namespaces are closed innermost-first.
        return "\n".join("} // namespace %s" % ns for ns in reversed(self.ns))


def header_guard(filename):
    """Returns the header guard for the generated header.

    The filename is upper-cased and every character that is not valid in a
    macro name (anything outside A-Z, 0-9 and `_`) is replaced with `_`, so
    that filenames such as `my-model` still yield a usable guard.
    """
    macro_suffix = re.sub(r"[^A-Z0-9_]", "_", filename.upper())
    return "GENERATED_DECISION_FOREST_MODEL_%s_H" % macro_suffix


def boost_node(n, label, next_label):
    """Returns code snippet for a leaf/boost node.

    The snippet returns the node's score from the enclosing function;
    `next_label` is accepted for signature parity with the other node kinds
    and is not used.
    """
    # JSON may carry the score as an integer (e.g. 1). Formatting it directly
    # would emit `1f`, which is not a valid C++ floating literal, so coerce to
    # float first (1 -> `1.0f`). Float scores are formatted unchanged.
    return "%s: return %sf;" % (label, float(n['score']))


def if_greater_node(n, label, next_label):
    """Returns code snippet for a if_greater node.
    Jumps to next_label if the Example feature (NUMBER) is greater than or
    equal to the threshold (the emitted comparison is `>=`).
    Comparing integers is much faster than comparing floats. Assuming floating points
    are represented as IEEE 754, it order-encodes the floats to integers before comparing them.
    Control falls through if condition is evaluated to false."""
    threshold = n["threshold"]
    # The raw threshold is kept in a C++ comment next to its encoded value.
    return "%s: if (E.get%s() >= %s /*%s*/) goto %s;" % (
        label, n['feature'], order_encode(threshold), threshold, next_label)


def if_member_node(n, label, next_label):
    """Returns code snippet for a if_member node.
    Jumps to next_label if the Example feature (ENUM) is present in the set of enum values
    described in the node.
    Control falls through if condition is evaluated to false."""
    feature = n['feature']
    members = '|'.join(
        "BIT(%s_type::%s)" % (feature, member) for member in n["set"])
    return "%s: if (E.get%s() & (%s)) goto %s;" % (
        label, feature, members, next_label)


def node(n, label, next_label):
    """Returns code snippet for the node.

    Dispatches on n['operation']; an unknown operation raises KeyError.
    """
    generators = {
        'boost': boost_node,
        'if_greater': if_greater_node,
        'if_member': if_member_node,
    }
    return generators[n['operation']](n, label, next_label)


def tree(t, tree_num, node_num):
    """Returns code for inferencing a Decision Tree.
    Also returns the size of the decision tree.

    A node of the tree starts with label `t{tree#}_n{node#}`.

    The tree contains two types of node: Conditional node and Leaf node.
    -  Conditional node evaluates a condition. If true, it jumps to the true node/child.
        Code is generated using pre-order traversal of the tree considering
        false node as the first child. Therefore the false node is always the
        immediately next label.
    -  Leaf node returns its score (see boost_node).

    Returns a tuple (list of code lines, number of nodes in the subtree).
    """
    label = "t%d_n%d" % (tree_num, node_num)

    if t["operation"] == "boost":
        leaf = node(t, label=label, next_label="t%d" % (tree_num + 1))
        return [leaf], 1

    # The false branch is emitted directly after this node so that control
    # can simply fall through to it.
    false_code, false_size = tree(
        t['else'], tree_num=tree_num, node_num=node_num + 1)

    # The true branch is numbered after every node of the false subtree.
    true_node_num = node_num + false_size + 1
    true_label = "t%d_n%d" % (tree_num, true_node_num)
    true_code, true_size = tree(
        t['then'], tree_num=tree_num, node_num=true_node_num)

    code = [node(t, label=label, next_label=true_label)]
    return code + false_code + true_code, 1 + false_size + true_size


def gen_header_code(features_json, cpp_class, filename):
    """Returns code for header declaring the inference runtime.

    Declares the Example class named {cpp_class} inside relevant namespaces.
    The Example class contains all the features as class members. This
    class can be used to represent a code completion candidate.
    Provides `float Evaluate()` function which can be used to score the Example.

    Raises ValueError if a feature kind is neither NUMBER nor ENUM.
    """
    setters = []
    for f in features_json:
        feature = f["name"]

        if f["kind"] == "NUMBER":
            # Floats are order-encoded to integers for faster comparison.
            setters.append(
                "void set%s(float V) { %s = OrderEncode(V); }" % (
                    feature, feature))
        elif f["kind"] == "ENUM":
            # Enum features are stored as a one-hot bit so that set
            # membership can be tested with a single bitwise AND.
            # NOTE(review): the member is uint32_t, so enum values >= 31
            # would not fit this `1 << V` encoding - confirm the value range.
            setters.append(
                "void set%s(unsigned V) { %s = 1 << V; }" % (feature, feature))
        else:
            raise ValueError("Unhandled feature type.", f["kind"])

    # Class members represent all the features of the Example.
    class_members = ["uint32_t %s = 0;" % f['name'] for f in features_json]
    getters = [
        "LLVM_ATTRIBUTE_ALWAYS_INLINE uint32_t get%s() const { return %s; }"
        % (f['name'], f['name'])
        for f in features_json
    ]
    nline = "\n "
    guard = header_guard(filename)
    return """#ifndef %s
#define %s
#include <cstdint>
#include "llvm/Support/Compiler.h"

%s
class %s {
public:
 // Setters.
 %s

 // Getters.
 %s

private:
 %s

 // Produces an integer that sorts in the same order as F.
 // That is: a < b <==> orderEncode(a) < orderEncode(b).
 static uint32_t OrderEncode(float F);
};

float Evaluate(const %s&);
%s
#endif // %s
""" % (guard, guard, cpp_class.ns_begin(), cpp_class.name,
       nline.join(setters),
       nline.join(getters),
       nline.join(class_members),
       cpp_class.name, cpp_class.ns_end(), guard)


def order_encode(v):
    """Returns an unsigned 32-bit integer that sorts in the same order as v.

    v is packed as a little-endian IEEE 754 single-precision float and its
    bit pattern is remapped so that integer comparison matches float
    comparison. This mirrors the OrderEncode function emitted into the
    generated C++.
    """
    bits = struct.unpack('<I', struct.pack('<f', v))[0]
    top_bit = 1 << 31
    # IEEE 754 floats compare like sign-magnitude integers.
    if bits & top_bit:  # Negative float
        return (1 << 32) - bits  # low half of integers, order reversed.
    return top_bit + bits  # top half of integers


def evaluate_func(forest_json, cpp_class):
    """Generates evaluation functions for each tree and combines them in
    `float Evaluate(const {Example}&)` function. This function can be
    used to score an Example."""
    parts = []

    # Generate evaluation function of each tree, inside an anonymous
    # namespace.
    parts.append("namespace {\n")
    for tree_num, tree_json in enumerate(forest_json):
        tree_code, _ = tree(tree_json, tree_num=tree_num, node_num=0)
        # MSAN will timeout if these functions are inlined.
        parts.append(
            "LLVM_ATTRIBUTE_NOINLINE float EvaluateTree%d(const %s& E) {\n"
            % (tree_num, cpp_class.name))
        parts.append(" " + "\n ".join(tree_code) + "\n")
        parts.append("}\n\n")
    parts.append("} // namespace\n\n")

    # Combine the scores of all trees in the final function.
    parts.append("float Evaluate(const %s& E) {\n" % cpp_class.name)
    parts.append(" float Score = 0;\n")
    parts.extend(
        " Score += EvaluateTree%d(E);\n" % tree_num
        for tree_num in range(len(forest_json)))
    parts.append(" return Score;\n")
    parts.append("}\n")

    return "".join(parts)


def gen_cpp_code(forest_json, features_json, filename, cpp_class):
    """Generates code for the .cpp file.

    The result contains the includes, the BIT macro, using-declarations for
    ENUM feature types, the OrderEncode definition and the Evaluate functions
    produced by evaluate_func.
    """
    # Headers
    # Required by OrderEncode(float F).
    angled_include = ['#include <%s>' % h for h in ["cstring", "limits"]]

    # Include generated header.
    quoted_headers = {filename + '.h', 'llvm/ADT/bit.h'}
    # Headers required by ENUM features used by the model.
    quoted_headers |= {f["header"]
                       for f in features_json if f["kind"] == "ENUM"}
    # Sorted so that the generated file is deterministic.
    quoted_include = ['#include "%s"' % h for h in sorted(quoted_headers)]

    # using-decl for ENUM features.
    using_decls = "\n".join(
        "using %s_type = %s;" % (feature['name'], feature['type'])
        for feature in features_json
        if feature["kind"] == "ENUM")
    nl = "\n"
    # NOTE(review): in the template below the memcpy after llvm::bit_cast
    # looks redundant (both store the bits of F in U) - confirm before
    # changing the generated code.
    return """%s

%s

#define BIT(X) (1 << X)

%s

%s

uint32_t %s::OrderEncode(float F) {
 static_assert(std::numeric_limits<float>::is_iec559, "");
 constexpr uint32_t TopBit = ~(~uint32_t{0} >> 1);

 // Get the bits of the float. Endianness is the same as for integers.
 uint32_t U = llvm::bit_cast<uint32_t>(F);
 std::memcpy(&U, &F, sizeof(U));
 // IEEE 754 floats compare like sign-magnitude integers.
 if (U & TopBit) // Negative float.
 return 0 - U; // Map onto the low half of integers, order reversed.
 return U + TopBit; // Positive floats map onto the high half of integers.
}

%s
%s
""" % (nl.join(angled_include), nl.join(quoted_include), cpp_class.ns_begin(),
       using_decls, cpp_class.name, evaluate_func(forest_json, cpp_class),
       cpp_class.ns_end())


def main():
    """Parses the command line and writes the generated .h and .cpp files."""
    parser = argparse.ArgumentParser('DecisionForestCodegen')
    # Every option is needed to build the input/output paths; without
    # `required=True` a missing option silently becomes the string "None"
    # inside those paths.
    parser.add_argument('--filename', required=True, help='output file name.')
    parser.add_argument('--output_dir', required=True,
                        help='output directory.')
    parser.add_argument('--model', required=True,
                        help='path to model directory.')
    parser.add_argument(
        '--cpp_class',
        required=True,
        help='The name of the class (which may be a namespace-qualified) created in generated header.'
    )
    ns = parser.parse_args()

    output_dir = ns.output_dir
    filename = ns.filename
    header_file = "%s/%s.h" % (output_dir, filename)
    cpp_file = "%s/%s.cpp" % (output_dir, filename)
    cpp_class = CppClass(cpp_class=ns.cpp_class)

    model_file = "%s/forest.json" % ns.model
    features_file = "%s/features.json" % ns.model

    with open(features_file) as f:
        features_json = json.load(f)

    with open(model_file) as m:
        forest_json = json.load(m)

    with open(cpp_file, 'w+t') as output_cc:
        output_cc.write(
            gen_cpp_code(forest_json=forest_json,
                         features_json=features_json,
                         filename=filename,
                         cpp_class=cpp_class))

    with open(header_file, 'w+t') as output_h:
        output_h.write(gen_header_code(
            features_json=features_json,
            cpp_class=cpp_class,
            filename=filename))


if __name__ == '__main__':
    main()