• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Random code generation for testing/fuzzing."""
16# pylint: disable=invalid-name
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import random
22import string
23
24import gast
25import numpy as np
26
27from tensorflow.python.autograph.pyct import templates
28
29
class NodeSampler(object):
  """Base class that draws a random key from a weighted `sample_map`.

  Subclasses set `sample_map` to a dict mapping each candidate to a
  non-negative relative weight.
  """
  sample_map = None

  def sample(self):
    """Return one key of `sample_map`, chosen in proportion to its weight."""
    candidates, weights = zip(*self.sample_map.items())
    # Normalize the relative weights into a probability vector.
    total_weight = np.sum(weights)
    probabilities = np.array(weights, dtype='float32') / total_weight
    return np.random.choice(candidates, p=probabilities)
37
38
class StatementSampler(NodeSampler):
  """Relative sampling weights for statement node types."""
  sample_map = {
      gast.Assign: 10,
      gast.Print: 1,
      gast.If: 2,
      gast.While: 2,
      gast.For: 0,  # Zero weight: never drawn.
  }
47
48
class ExpressionSampler(NodeSampler):
  """Relative sampling weights for expression node types."""
  sample_map = {
      gast.UnaryOp: 1,
      gast.BinOp: 8,
      gast.Name: 1,
      gast.Call: 0,  # Zero weight: never drawn.
  }
56
57
class CompareSampler(NodeSampler):
  """Equal sampling weights for the comparison operator node types."""
  sample_map = {
      gast.Eq: 1,
      gast.NotEq: 1,
      gast.Lt: 1,
      gast.LtE: 1,
      gast.Gt: 1,
      gast.GtE: 1,
      gast.Is: 1,
      gast.IsNot: 1,
  }
69
70
class BinaryOpSampler(NodeSampler):
  """Equal sampling weights for the binary arithmetic operator node types."""
  sample_map = {
      gast.Add: 1,
      gast.Sub: 1,
      gast.Mult: 1,
      gast.Div: 1,
      gast.FloorDiv: 1,
      gast.Mod: 1,
      gast.Pow: 1,
  }
81
82
class UnaryOpSampler(NodeSampler):
  """Sampling weights for unary operator node types."""
  sample_map = {
      gast.USub: 1,
      gast.UAdd: 0,  # Zero weight: never drawn.
  }
85
86
class NameSampler(NodeSampler):
  """Equal sampling weights for the 'new' and 'existing' name kinds."""
  sample_map = {
      'new': 1,
      'existing': 1,
  }
92
93
# Exclusive upper bound on the number of statements sampled for a While body;
# each branch of an If uses half of this value as its bound.
N_CONTROLFLOW_STATEMENTS = 10
# Exclusive upper bound on the number of statements sampled for a function
# body, not counting the Return statement appended at the end.
N_FUNCTIONDEF_STATEMENTS = 10
96
97
class CodeGenerator(object):
  """Generate random syntactically-valid Python ASTs.

  Attributes:
    max_depth: Deepest statement nesting level at which control-flow
      statements (If/While/For) may still be generated.
    depth: Current statement nesting level. It is incremented while a
      statement is being generated and restored afterwards.
  """

  def __init__(self, max_depth=3, depth=0):
    self.max_depth = max_depth
    self.depth = depth

  def generate_statement(self):
    """Generate a statement node, dispatching to the correct class method.

    Control-flow statements are rejected once the nesting depth exceeds
    `max_depth`; in that case another statement type is drawn.

    Returns:
      The node built by the `generate_<NodeType>` method that matches the
      sampled statement type.
    """
    self.depth += 1
    try:
      desired_node = StatementSampler().sample()

      # If the sample violates the depth constraint, draw another one. The
      # redraw happens at the same depth, so a rejected sample does not
      # permanently inflate `self.depth`.
      while (self.depth > self.max_depth and
             desired_node in (gast.While, gast.For, gast.If)):
        desired_node = StatementSampler().sample()

      # Go get the generator method and run it.
      method = 'generate_' + desired_node.__name__
      visitor = getattr(self, method)
      return visitor()
    finally:
      # Always restore the depth, even when generation raises.
      self.depth -= 1

  def sample_node_list(self, low, high, generator):
    """Generate a list of nodes of random length.

    Args:
      low: Fewest number of nodes to generate (inclusive).
      high: Exclusive upper bound on the number of nodes to generate; at most
        `high - 1` nodes are produced.
      generator: Zero-argument callable invoked once per node.

    Returns:
      A list with the values returned by `generator`.

    Raises:
      ValueError: If `low >= high` (raised by `np.random.randint`).
    """
    return [generator() for _ in range(np.random.randint(low, high))]

  def generate_Name(self, ctx=None):
    """Generate a Name node with a random identifier.

    Args:
      ctx: Expression context node for the name. When None, a fresh
        `gast.Load()` is created for this node.

    Returns:
      A `gast.Name` whose id is an underscore followed by four random
      lowercase ASCII letters.
    """
    if ctx is None:
      # Build the default per call so nodes do not share one context object.
      ctx = gast.Load()
    variable_name = '_' + ''.join(
        random.choice(string.ascii_lowercase) for _ in range(4))
    return gast.Name(variable_name, ctx=ctx, annotation=None)

  def generate_BinOp(self):
    """Generate a BinOp node applying a random operator to two random names."""
    # TODO(alexbw): convert to generate_expression when we get to limit
    # expression depth.
    op = BinaryOpSampler().sample()()
    return gast.BinOp(self.generate_Name(), op, self.generate_Name())

  def generate_Compare(self):
    """Generate a Compare node between two random names."""
    op = CompareSampler().sample()()
    return gast.Compare(self.generate_Name(), [op], [self.generate_Name()])

  def generate_UnaryOp(self):
    """Generate a UnaryOp node applying a random operator to a random name."""
    operand = self.generate_Name()
    op = UnaryOpSampler().sample()()
    return gast.UnaryOp(op, operand)

  def generate_expression(self):
    """Generate an expression node of a randomly sampled type."""
    desired_node = ExpressionSampler().sample()
    # Go get the generator method and run it.
    method = 'generate_' + desired_node.__name__
    generator = getattr(self, method)
    return generator()

  def generate_Assign(self):
    """Generate an Assign node."""
    # Generate left-hand side.
    target_node = self.generate_Name(gast.Store())
    # Generate right-hand side.
    value_node = self.generate_expression()
    # Put it all together.
    return gast.Assign(targets=[target_node], value=value_node)

  def generate_If(self):
    """Generate an If node with non-empty true and false branches."""
    test = self.generate_Compare()

    # Generate true branch statements.
    body = self.sample_node_list(
        low=1,
        high=N_CONTROLFLOW_STATEMENTS // 2,
        generator=self.generate_statement)

    # Generate false branch statements.
    orelse = self.sample_node_list(
        low=1,
        high=N_CONTROLFLOW_STATEMENTS // 2,
        generator=self.generate_statement)

    return gast.If(test, body, orelse)

  def generate_While(self):
    """Generate a While node without an else clause."""
    test = self.generate_Compare()
    body = self.sample_node_list(
        low=1, high=N_CONTROLFLOW_STATEMENTS, generator=self.generate_statement)
    orelse = []  # not generating else statements

    return gast.While(test, body, orelse)

  def generate_Call(self):
    """Not implemented; always raises NotImplementedError."""
    raise NotImplementedError

  def generate_Return(self):
    """Generate a Return node that returns a random expression."""
    return gast.Return(self.generate_expression())

  def generate_Print(self):
    """Generate a `print(x)` statement where x is a random expression."""
    return templates.replace('print(x)', x=self.generate_expression())[0]

  def generate_FunctionDef(self):
    """Generate a FunctionDef node.

    Returns:
      A `gast.FunctionDef` with a random name, randomly named parameters and
      a random body that ends with a Return statement.
    """
    # Generate the arguments.
    arg_vars = self.sample_node_list(
        low=2, high=10, generator=lambda: self.generate_Name(gast.Param()))
    args = gast.arguments(arg_vars, None, [], [], None, [])

    # Generate the function body and finish it with a return.
    body = self.sample_node_list(
        low=1, high=N_FUNCTIONDEF_STATEMENTS, generator=self.generate_statement)
    body.append(self.generate_Return())
    fn_name = self.generate_Name().id
    return gast.FunctionDef(fn_name, args, body, (), None)
231
232
def generate_random_functiondef():
  """Return a random FunctionDef node from a default-configured generator."""
  generator = CodeGenerator()
  return generator.generate_FunctionDef()
235