• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Lint as: python2, python3
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Generate tensorflow graphs for testing tfcompile."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import argparse
23import os
24import sys
25
26import six
27from six.moves import range
28
29from tensorflow.core.protobuf import saver_pb2
30from tensorflow.python.client import session
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import function
34from tensorflow.python.framework import ops
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import control_flow_util
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops import nn_ops
40from tensorflow.python.ops import variables
41from tensorflow.python.platform import app
42from tensorflow.python.training import saver as saver_lib
43
# Parsed command-line flags (the namespace from argparse's parse_known_args);
# assigned in the __main__ block before app.run() invokes main().
FLAGS = None
45
46
def tfadd(_):
  """Builds `x_y_sum` = [1] + [2] from two named constants.

  Args:
    _: Unused; every builder takes the output directory for uniformity.
  """
  lhs = constant_op.constant([1], name='x_const')
  rhs = constant_op.constant([2], name='y_const')
  math_ops.add(lhs, rhs, name='x_y_sum')
51
52
def tfadd_with_ckpt(out_dir):
  """Builds `x_y_sum` = x_hold + y_saved and checkpoints y_saved as [42].

  The variable starts at [0]; a session bumps it by 42 and a V1 saver
  writes `test_graph_tfadd_with_ckpt.ckpt` into `out_dir`.

  Args:
    out_dir: Directory that receives the checkpoint.
  """
  x_hold = array_ops.placeholder(dtypes.int32, name='x_hold')
  y_saved = variables.VariableV1(constant_op.constant([0]), name='y_saved')
  math_ops.add(x_hold, y_saved, name='x_y_sum')

  init_op = variables.global_variables_initializer()
  saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
  ckpt_path = os.path.join(out_dir, 'test_graph_tfadd_with_ckpt.ckpt')
  with session.Session() as sess:
    sess.run(init_op)
    bumped = y_saved + 42
    sess.run(y_saved.assign(bumped))
    # The graph alone only knows the initial value; 42 lives in the checkpoint.
    saver.save(sess, ckpt_path)
66
67
def tfadd_with_ckpt_saver(out_dir):
  """Like tfadd_with_ckpt, but with a named saver whose SaverDef is saved too.

  Builds `x_y_sum` = x_hold + y_saved, sets y_saved to [42] in a session and
  writes into `out_dir`:
    * test_graph_tfadd_with_ckpt_saver.ckpt   (V1 checkpoint)
    * test_graph_tfadd_with_ckpt_saver.saver  (serialized SaverDef)

  Args:
    out_dir: Directory that receives both files.
  """
  x_hold = array_ops.placeholder(dtypes.int32, name='x_hold')
  y_saved = variables.VariableV1(constant_op.constant([0]), name='y_saved')
  math_ops.add(x_hold, y_saved, name='x_y_sum')

  init_op = variables.global_variables_initializer()
  saver = saver_lib.Saver(name='abcprefix',
                          write_version=saver_pb2.SaverDef.V1)
  ckpt_path = os.path.join(out_dir, 'test_graph_tfadd_with_ckpt_saver.ckpt')
  saver_def_path = os.path.join(out_dir,
                                'test_graph_tfadd_with_ckpt_saver.saver')
  with session.Session() as sess:
    sess.run(init_op)
    bumped = y_saved + 42
    sess.run(y_saved.assign(bumped))
    # The graph alone only knows the initial value; 42 lives in the checkpoint.
    saver.save(sess, ckpt_path)
    # Consumers need the SaverDef to locate the correctly named restore op.
    saver_def_bytes = saver.as_saver_def().SerializeToString()
    with open(saver_def_path, 'wb') as saver_file:
      saver_file.write(six.ensure_binary(saver_def_bytes))
85
86
def tfassert_eq(_):
  """Builds `x_y_diff` = x_hold - y_hold plus an `assert_eq` op on x == y.

  The Assert op is not made a control dependency of `x_y_diff`; it simply
  exists in the graph under the name `assert_eq`.

  Args:
    _: Unused; every builder takes the output directory for uniformity.
  """
  x_hold = array_ops.placeholder(dtypes.int32, name='x_hold')
  y_hold = array_ops.placeholder(dtypes.int32, name='y_hold')
  is_equal = math_ops.equal(x_hold, y_hold)
  control_flow_ops.Assert(is_equal, ['Expected x == y.'], name='assert_eq')
  neg_y = math_ops.negative(y_hold)
  math_ops.add(x_hold, neg_y, name='x_y_diff')
93
94
def tfcond(_):
  """Builds `result` = x_hold if p_hold else y_hold via a graph conditional.

  Args:
    _: Unused; every builder takes the output directory for uniformity.
  """
  predicate = array_ops.placeholder(dtypes.bool, name='p_hold')
  when_true = array_ops.placeholder(dtypes.int32, name='x_hold')
  when_false = array_ops.placeholder(dtypes.int32, name='y_hold')
  chosen = control_flow_ops.cond(predicate,
                                 lambda: when_true,
                                 lambda: when_false)
  array_ops.identity(chosen, name='result')
101
102
def tfgather(_):
  """Builds `gather_output` = gather(params, indices) from two placeholders.

  Args:
    _: Unused; every builder takes the output directory for uniformity.
  """
  source = array_ops.placeholder(dtypes.float32, name='params')
  positions = array_ops.placeholder(dtypes.int32, name='indices')
  array_ops.gather(source, positions, name='gather_output')
107
108
def tfmatmul(_):
  """Builds `x_y_prod` = matmul(x_hold, y_hold) over float32 placeholders.

  Args:
    _: Unused; every builder takes the output directory for uniformity.
  """
  left = array_ops.placeholder(dtypes.float32, name='x_hold')
  right = array_ops.placeholder(dtypes.float32, name='y_hold')
  math_ops.matmul(left, right, name='x_y_prod')
113
114
def tfmatmulandadd(_):
  """Builds two outputs from the same inputs: `x_y_prod` and `x_y_sum`.

  Exists to exercise graphs with more than one fetched output.

  Args:
    _: Unused; every builder takes the output directory for uniformity.
  """
  left = array_ops.placeholder(dtypes.float32, name='x_hold')
  right = array_ops.placeholder(dtypes.float32, name='y_hold')
  math_ops.matmul(left, right, name='x_y_prod')
  math_ops.add(left, right, name='x_y_sum')
121
122
def tffunction(_):
  """Builds a call to a Defun-defined int32 adder, named `func_call`.

  Args:
    _: Unused; every builder takes the output directory for uniformity.
  """
  # NOTE(review): `test_func` and its parameter names `a`/`b` are left
  # untouched on purpose; Defun presumably derives the emitted function's
  # name and argument names from them, so renaming could change the graph.
  @function.Defun(dtypes.int32, dtypes.int32)
  def test_func(a, b):
    return a + b

  first = constant_op.constant([1], name='x_const')
  second = constant_op.constant([2], name='y_const')
  test_func(first, second, name='func_call')  # pylint: disable=unexpected-keyword-arg
132
133
def tfsplits(_):
  """A more complex graph, including splits.

  Starting from two [2, 2] float32 placeholders, each of three rounds splits
  both operands along axis 0, offsets the first halves by 1, recombines the
  halves crosswise with concat, and multiplies/adds to form the next
  operands. The final right-hand operand is exposed as `result`.

  Args:
    _: Unused; every builder takes the output directory for uniformity.
  """
  lhs = array_ops.placeholder(dtypes.float32, shape=[2, 2], name='x')
  rhs = array_ops.placeholder(dtypes.float32, shape=[2, 2], name='y')
  for _round in range(3):
    lhs_top, lhs_bottom = array_ops.split(lhs, num_or_size_splits=2, axis=0)
    rhs_top, rhs_bottom = array_ops.split(rhs, num_or_size_splits=2, axis=0)
    lhs_top = lhs_top + 1
    rhs_top = rhs_top + 1
    product = math_ops.matmul(lhs, rhs, name='x_y_prod')
    mixed_a = array_ops.concat([lhs_top, rhs_bottom], axis=0,
                               name='concat_x0_y1')
    mixed_b = array_ops.concat([rhs_top, lhs_bottom], axis=0,
                               name='concat_y0_x1')
    lhs = math_ops.matmul(mixed_a, mixed_b, name='a_b')
    rhs = math_ops.add(lhs, product)
  array_ops.identity(rhs, name='result')
149
150
def tftop_k(_):
  """Builds top-2 over a length-5 int32 vector.

  The top_k op itself is named `values`; its indices output is re-exposed
  through an identity op named `indices`.

  Args:
    _: Unused; every builder takes the output directory for uniformity.
  """
  vector = array_ops.placeholder(dtypes.int32, shape=[5], name='x')
  _, top_indices = nn_ops.top_k(vector, k=2, name='values')
  array_ops.identity(top_indices, name='indices')
155
156
def tfvariable_readonly(_):
  """Builds `result` = x + 42.0 for a variable x (initially 1000.0).

  The variable is only read, never assigned.

  Args:
    _: Unused; every builder takes the output directory for uniformity.
  """
  var = variables.Variable(1000.0, name='x')
  snapshot = var.value()
  with ops.control_dependencies([snapshot]):
    incremented = math_ops.add(snapshot, 42.0)
  array_ops.identity(incremented, name='result')
163
164
# TODO(b/147908587): Change x and the two constants back to have a scalar shape
#                    when the bug is fixed.
def tfvariable(_):
  """Builds `result` = stack([x before, x after assign_add([42.0])]).

  x is a shape-[1] variable initialized to [1000.0]; the update is ordered
  after the initial read via a control dependency.

  Args:
    _: Unused; every builder takes the output directory for uniformity.
  """
  var = variables.Variable([1000.0], name='x', shape=[1])
  before = var.value()
  with ops.control_dependencies([before]):
    after = var.assign_add([42.0])
  array_ops.stack([before, after], name='result')
173
174
def tfvariable_sequential_updates(_):
  """Chains three updates x -= 0.1 * (x + y), each gated on the previous.

  Both variables start at 1.0. The last assign_sub is exposed as `result`.

  Args:
    _: Unused; every builder takes the output directory for uniformity.
  """
  x_var = variables.Variable(1.0, name='x')
  y_var = variables.Variable(1.0, name='y')
  last_update = control_flow_ops.no_op()
  for _step in range(3):
    # Each read must observe the previous step's write.
    with ops.control_dependencies([last_update]):
      current = x_var.read_value() + y_var
      last_update = x_var.assign_sub(0.1 * current)

  array_ops.identity(last_update, name='result')
185
186
def write_graph(build_graph, out_dir):
  """Runs `build_graph` in a fresh graph and writes the resulting GraphDef.

  Args:
    build_graph: Builder called with `out_dir` while a new graph is the
      default; its `__name__` determines the output file name.
    out_dir: Directory for `test_graph_<builder name>.pb` (and for anything
      the builder itself writes).
  """
  graph = ops.Graph()
  with graph.as_default():
    build_graph(out_dir)
    pb_path = os.path.join(out_dir,
                           'test_graph_%s.pb' % build_graph.__name__)
    with open(pb_path, 'wb') as pb_file:
      serialized = graph.as_graph_def().SerializeToString()
      pb_file.write(six.ensure_binary(serialized))
195
196
def main(_):
  """Enables control flow v2, then writes every test graph to FLAGS.out_dir.

  Args:
    _: Remaining argv from app.run; unused.
  """
  control_flow_util.enable_control_flow_v2()
  # Order matches the original explicit call sequence.
  builders = (
      tfadd,
      tfadd_with_ckpt,
      tfadd_with_ckpt_saver,
      tfassert_eq,
      tfcond,
      tffunction,
      tfgather,
      tfmatmul,
      tfmatmulandadd,
      tfsplits,
      tftop_k,
      tfvariable,
      tfvariable_readonly,
      tfvariable_sequential_updates,
  )
  for builder in builders:
    write_graph(builder, FLAGS.out_dir)
213
214
if __name__ == '__main__':
  arg_parser = argparse.ArgumentParser()
  # Registers a 'bool' argument type; no argument below currently uses it.
  arg_parser.register('type', 'bool', lambda v: v.lower() == 'true')
  arg_parser.add_argument(
      '--out_dir',
      type=str,
      default='',
      help='Output directory for graphs, checkpoints and savers.')
  # Unrecognized arguments are forwarded to app.run after the program name.
  FLAGS, remaining_args = arg_parser.parse_known_args()
  forwarded_argv = [sys.argv[0]] + remaining_args
  app.run(main=main, argv=forwarded_argv)
225