1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Random code generation for testing/fuzzing.""" 16# pylint: disable=invalid-name 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import random 22import string 23 24import gast 25import numpy as np 26 27from tensorflow.python.autograph.pyct import templates 28 29 30class NodeSampler(object): 31 sample_map = None 32 33 def sample(self): 34 nodes, magnitudes = zip(*self.sample_map.items()) 35 return np.random.choice( 36 nodes, p=np.array(magnitudes, dtype='float32') / np.sum(magnitudes)) 37 38 39class StatementSampler(NodeSampler): 40 sample_map = dict(( 41 (gast.Assign, 10), 42 (gast.Print, 1), 43 (gast.If, 2), 44 (gast.While, 2), 45 (gast.For, 0), 46 )) 47 48 49class ExpressionSampler(NodeSampler): 50 sample_map = dict(( 51 (gast.UnaryOp, 1), 52 (gast.BinOp, 8), 53 (gast.Name, 1), 54 (gast.Call, 0), 55 )) 56 57 58class CompareSampler(NodeSampler): 59 sample_map = dict(( 60 (gast.Eq, 1), 61 (gast.NotEq, 1), 62 (gast.Lt, 1), 63 (gast.LtE, 1), 64 (gast.Gt, 1), 65 (gast.GtE, 1), 66 (gast.Is, 1), 67 (gast.IsNot, 1), 68 )) 69 70 71class BinaryOpSampler(NodeSampler): 72 sample_map = dict(( 73 (gast.Add, 1), 74 (gast.Sub, 1), 75 (gast.Mult, 1), 76 (gast.Div, 1), 77 (gast.FloorDiv, 1), 78 (gast.Mod, 1), 79 (gast.Pow, 1), 80 )) 81 82 83class UnaryOpSampler(NodeSampler): 84 sample_map = dict(((gast.USub, 1), (gast.UAdd, 0))) 85 86 87class NameSampler(NodeSampler): 88 sample_map = dict(( 89 ('new', 1), 90 ('existing', 1), 91 )) 92 93 94N_CONTROLFLOW_STATEMENTS = 10 95N_FUNCTIONDEF_STATEMENTS = 10 96 97 98class CodeGenerator(object): 99 """Generate random syntactically-valid Python ASTs.""" 100 101 def __init__(self, max_depth=3, depth=0): 102 self.max_depth = max_depth 103 self.depth = depth 104 105 def generate_statement(self): 106 """Generate a statement node, dispatching to the correct class method.""" 107 desired_node = StatementSampler().sample() 108 self.depth += 1 109 110 # Enforce some constraints on generating statements. 111 # E.g., if statements need at least 3 readable variables. 112 # If we fail to satisfy our constraints, draw another sample. 113 if desired_node in (gast.While, gast.For, gast.If): 114 if self.depth > self.max_depth: 115 return self.generate_statement() 116 117 # Go get the generator method and run it 118 method = 'generate_' + desired_node.__name__ 119 visitor = getattr(self, method) 120 node = visitor() 121 self.depth -= 1 122 return node 123 124 def sample_node_list(self, low, high, generator): 125 """Generate a list of statements of random length. 126 127 Args: 128 low: Fewest number of statements to generate. 129 high: Highest number of statements to generate. 130 generator: Function to call to generate nodes. 131 132 Returns: 133 A list of statements. 134 """ 135 statements = [] 136 for _ in range(np.random.randint(low, high)): 137 statements.append(generator()) 138 return statements 139 140 def generate_Name(self, ctx=gast.Load()): 141 variable_name = '_' + ''.join( 142 random.choice(string.ascii_lowercase) for _ in range(4)) 143 return gast.Name(variable_name, ctx=ctx, annotation=None) 144 145 def generate_BinOp(self): 146 # TODO(alexbw): convert to generate_expression when we get to limit 147 # expression depth. 148 op = BinaryOpSampler().sample()() 149 return gast.BinOp(self.generate_Name(), op, self.generate_Name()) 150 151 def generate_Compare(self): 152 op = CompareSampler().sample()() 153 return gast.Compare(self.generate_Name(), [op], [self.generate_Name()]) 154 155 def generate_UnaryOp(self): 156 operand = self.generate_Name() 157 op = UnaryOpSampler().sample()() 158 return gast.UnaryOp(op, operand) 159 160 def generate_expression(self): 161 desired_node = ExpressionSampler().sample() 162 # Go get the generator method and run it 163 method = 'generate_' + desired_node.__name__ 164 generator = getattr(self, method) 165 return generator() 166 167 def generate_Assign(self): 168 """Generate an Assign node.""" 169 # Generate left-hand side 170 target_node = self.generate_Name(gast.Store()) 171 # Generate right-hand side 172 value_node = self.generate_expression() 173 # Put it all together 174 node = gast.Assign(targets=[target_node], value=value_node) 175 return node 176 177 def generate_If(self): 178 """Generate an If node.""" 179 test = self.generate_Compare() 180 181 # Generate true branch statements 182 body = self.sample_node_list( 183 low=1, 184 high=N_CONTROLFLOW_STATEMENTS // 2, 185 generator=self.generate_statement) 186 187 # Generate false branch statements 188 orelse = self.sample_node_list( 189 low=1, 190 high=N_CONTROLFLOW_STATEMENTS // 2, 191 generator=self.generate_statement) 192 193 node = gast.If(test, body, orelse) 194 return node 195 196 def generate_While(self): 197 """Generate a While node.""" 198 199 test = self.generate_Compare() 200 body = self.sample_node_list( 201 low=1, high=N_CONTROLFLOW_STATEMENTS, generator=self.generate_statement) 202 orelse = [] # not generating else statements 203 204 node = gast.While(test, body, orelse) 205 return node 206 207 def generate_Call(self): 208 raise NotImplementedError 209 210 def generate_Return(self): 211 return gast.Return(self.generate_expression()) 212 213 def generate_Print(self): 214 return templates.replace('print(x)', x=self.generate_expression())[0] 215 216 def generate_FunctionDef(self): 217 """Generate a FunctionDef node.""" 218 219 # Generate the arguments, register them as available 220 arg_vars = self.sample_node_list( 221 low=2, high=10, generator=lambda: self.generate_Name(gast.Param())) 222 args = gast.arguments(arg_vars, None, [], [], None, []) 223 224 # Generate the function body 225 body = self.sample_node_list( 226 low=1, high=N_FUNCTIONDEF_STATEMENTS, generator=self.generate_statement) 227 body.append(self.generate_Return()) 228 fn_name = self.generate_Name().id 229 node = gast.FunctionDef(fn_name, args, body, (), None) 230 return node 231 232 233def generate_random_functiondef(): 234 return CodeGenerator().generate_FunctionDef() 235