1# Lint as: python2, python3 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Generate tensorflow graphs for testing tfcompile.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import argparse 23import os 24import sys 25 26import six 27from six.moves import range 28 29from tensorflow.core.protobuf import saver_pb2 30from tensorflow.python.client import session 31from tensorflow.python.framework import constant_op 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import function 34from tensorflow.python.framework import ops 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import control_flow_ops 37from tensorflow.python.ops import control_flow_util 38from tensorflow.python.ops import math_ops 39from tensorflow.python.ops import nn_ops 40from tensorflow.python.ops import variables 41from tensorflow.python.platform import app 42from tensorflow.python.training import saver as saver_lib 43 44FLAGS = None 45 46 47def tfadd(_): 48 x = constant_op.constant([1], name='x_const') 49 y = constant_op.constant([2], name='y_const') 50 math_ops.add(x, y, name='x_y_sum') 51 52 53def tfadd_with_ckpt(out_dir): 54 x = array_ops.placeholder(dtypes.int32, name='x_hold') 55 y = variables.VariableV1(constant_op.constant([0]), name='y_saved') 56 math_ops.add(x, y, name='x_y_sum') 57 58 init_op = variables.global_variables_initializer() 59 saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1) 60 with session.Session() as sess: 61 sess.run(init_op) 62 sess.run(y.assign(y + 42)) 63 # Without the checkpoint, the variable won't be set to 42. 64 ckpt = os.path.join(out_dir, 'test_graph_tfadd_with_ckpt.ckpt') 65 saver.save(sess, ckpt) 66 67 68def tfadd_with_ckpt_saver(out_dir): 69 x = array_ops.placeholder(dtypes.int32, name='x_hold') 70 y = variables.VariableV1(constant_op.constant([0]), name='y_saved') 71 math_ops.add(x, y, name='x_y_sum') 72 73 init_op = variables.global_variables_initializer() 74 saver = saver_lib.Saver(name='abcprefix', write_version=saver_pb2.SaverDef.V1) 75 with session.Session() as sess: 76 sess.run(init_op) 77 sess.run(y.assign(y + 42)) 78 # Without the checkpoint, the variable won't be set to 42. 79 ckpt_file = os.path.join(out_dir, 'test_graph_tfadd_with_ckpt_saver.ckpt') 80 saver.save(sess, ckpt_file) 81 # Without the SaverDef, the restore op won't be named correctly. 82 saver_file = os.path.join(out_dir, 'test_graph_tfadd_with_ckpt_saver.saver') 83 with open(saver_file, 'wb') as f: 84 f.write(six.ensure_binary(saver.as_saver_def().SerializeToString())) 85 86 87def tfassert_eq(_): 88 x = array_ops.placeholder(dtypes.int32, name='x_hold') 89 y = array_ops.placeholder(dtypes.int32, name='y_hold') 90 control_flow_ops.Assert( 91 math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq') 92 math_ops.add(x, math_ops.negative(y), name='x_y_diff') 93 94 95def tfcond(_): 96 p = array_ops.placeholder(dtypes.bool, name='p_hold') 97 x = array_ops.placeholder(dtypes.int32, name='x_hold') 98 y = array_ops.placeholder(dtypes.int32, name='y_hold') 99 z = control_flow_ops.cond(p, lambda: x, lambda: y) 100 array_ops.identity(z, name='result') 101 102 103def tfgather(_): 104 params = array_ops.placeholder(dtypes.float32, name='params') 105 indices = array_ops.placeholder(dtypes.int32, name='indices') 106 array_ops.gather(params, indices, name='gather_output') 107 108 109def tfmatmul(_): 110 x = array_ops.placeholder(dtypes.float32, name='x_hold') 111 y = array_ops.placeholder(dtypes.float32, name='y_hold') 112 math_ops.matmul(x, y, name='x_y_prod') 113 114 115def tfmatmulandadd(_): 116 # This tests multiple outputs. 117 x = array_ops.placeholder(dtypes.float32, name='x_hold') 118 y = array_ops.placeholder(dtypes.float32, name='y_hold') 119 math_ops.matmul(x, y, name='x_y_prod') 120 math_ops.add(x, y, name='x_y_sum') 121 122 123def tffunction(_): 124 125 @function.Defun(dtypes.int32, dtypes.int32) 126 def test_func(a, b): 127 return a + b 128 129 x = constant_op.constant([1], name='x_const') 130 y = constant_op.constant([2], name='y_const') 131 test_func(x, y, name='func_call') # pylint: disable=unexpected-keyword-arg 132 133 134def tfsplits(_): 135 """A more complex graph, including splits.""" 136 x = array_ops.placeholder(dtypes.float32, shape=[2, 2], name='x') 137 y = array_ops.placeholder(dtypes.float32, shape=[2, 2], name='y') 138 for _ in range(3): 139 x0, x1 = array_ops.split(x, 2, 0) 140 y0, y1 = array_ops.split(y, 2, 0) 141 x0 += 1 142 y0 += 1 143 z = math_ops.matmul(x, y, name='x_y_prod') 144 a = array_ops.concat([x0, y1], axis=0, name='concat_x0_y1') 145 b = array_ops.concat([y0, x1], axis=0, name='concat_y0_x1') 146 x = math_ops.matmul(a, b, name='a_b') 147 y = math_ops.add(x, z) 148 array_ops.identity(y, name='result') 149 150 151def tftop_k(_): 152 x = array_ops.placeholder(dtypes.int32, shape=[5], name='x') 153 output = nn_ops.top_k(x, 2, name='values') 154 array_ops.identity(output[1], name='indices') 155 156 157def tfvariable_readonly(_): 158 x = variables.Variable(1000.0, name='x') 159 old_x = x.value() 160 with ops.control_dependencies([old_x]): 161 new_value = math_ops.add(old_x, 42.0) 162 array_ops.identity(new_value, name='result') 163 164 165# TODO(b/147908587): Change x and the two constants back to have a scalar shape 166# when the bug is fixed. 167def tfvariable(_): 168 x = variables.Variable([1000.0], name='x', shape=[1]) 169 old_x = x.value() 170 with ops.control_dependencies([old_x]): 171 new_x = x.assign_add([42.0]) 172 array_ops.stack([old_x, new_x], name='result') 173 174 175def tfvariable_sequential_updates(_): 176 x = variables.Variable(1.0, name='x') 177 y = variables.Variable(1.0, name='y') 178 updates = control_flow_ops.no_op() 179 for _ in range(3): 180 with ops.control_dependencies([updates]): 181 x_val = x.read_value() + y 182 updates = x.assign_sub(0.1 * x_val) 183 184 array_ops.identity(updates, name='result') 185 186 187def write_graph(build_graph, out_dir): 188 """Build a graph using build_graph and write it out.""" 189 g = ops.Graph() 190 with g.as_default(): 191 build_graph(out_dir) 192 filename = os.path.join(out_dir, 'test_graph_%s.pb' % build_graph.__name__) 193 with open(filename, 'wb') as f: 194 f.write(six.ensure_binary(g.as_graph_def().SerializeToString())) 195 196 197def main(_): 198 control_flow_util.enable_control_flow_v2() 199 write_graph(tfadd, FLAGS.out_dir) 200 write_graph(tfadd_with_ckpt, FLAGS.out_dir) 201 write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir) 202 write_graph(tfassert_eq, FLAGS.out_dir) 203 write_graph(tfcond, FLAGS.out_dir) 204 write_graph(tffunction, FLAGS.out_dir) 205 write_graph(tfgather, FLAGS.out_dir) 206 write_graph(tfmatmul, FLAGS.out_dir) 207 write_graph(tfmatmulandadd, FLAGS.out_dir) 208 write_graph(tfsplits, FLAGS.out_dir) 209 write_graph(tftop_k, FLAGS.out_dir) 210 write_graph(tfvariable, FLAGS.out_dir) 211 write_graph(tfvariable_readonly, FLAGS.out_dir) 212 write_graph(tfvariable_sequential_updates, FLAGS.out_dir) 213 214 215if __name__ == '__main__': 216 parser = argparse.ArgumentParser() 217 parser.register('type', 'bool', lambda v: v.lower() == 'true') 218 parser.add_argument( 219 '--out_dir', 220 type=str, 221 default='', 222 help='Output directory for graphs, checkpoints and savers.') 223 FLAGS, unparsed = parser.parse_known_args() 224 app.run(main=main, argv=[sys.argv[0]] + unparsed) 225