# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Writes a text-format GraphDef to a file (for testing ExternalDataset).

The graph feeds two string placeholders, ``token`` and ``selector``, into an
``ExternalDataset``, parses the int64 feature ``val`` out of every serialized
tf.Example it yields, and sums those values into an op named ``total``.

The destination is given by the required ``--output`` flag; because the flag
uses ``argparse.FileType("w")``, passing ``--output=-`` writes to stdout.
"""

import argparse
import numpy as np
import tensorflow.compat.v1 as tf

from fcp.tensorflow import external_dataset


def _ParseSingleExample(p):
  """Extracts the int64 ``val`` feature from one serialized tf.Example.

  Args:
    p: A string tensor holding one serialized tf.Example (presumably a scalar
      dataset element -- NOTE(review): confirm ExternalDataset's element
      shape).

  Returns:
    The int64 tensor for feature ``val``. Because ``p`` is reshaped to a
    vector before parsing, a scalar ``p`` yields a tensor of shape [1] rather
    than a scalar.
  """
  # parse_example doesn't like scalars, so we reshape with [-1] to turn the
  # single serialized proto into a batch of one.
  features = tf.parse_example(
      tf.reshape(p, [-1]), {"val": tf.FixedLenFeature([], dtype=tf.int64)})
  return features["val"]


def MakeGraph():
  """Builds the ExternalDataset test graph.

  The graph contains:
    * string placeholders named ``token`` and ``selector``, passed to
      ``external_dataset.ExternalDataset``;
    * a ``map`` of ``_ParseSingleExample`` over the dataset's elements;
    * a ``reduce`` that sums the parsed values, starting from int64 0;
    * an identity op named ``total`` wrapping that sum.

  Returns:
    The constructed ``tf.Graph``.
  """
  graph = tf.Graph()

  with graph.as_default():
    serialized_examples = external_dataset.ExternalDataset(
        token=tf.placeholder(name="token", dtype=tf.string),
        selector=tf.placeholder(name="selector", dtype=tf.string))

    examples = serialized_examples.map(_ParseSingleExample)

    total = examples.reduce(np.int64(0), lambda x, y: x + y)
    # Only the op's name matters here (presumably so consumers of the
    # serialized graph can fetch the result as "total"); the returned
    # tensor itself is not needed, so it is not rebound.
    tf.identity(total, name="total")

  return graph


def _ParseArgs():
  """Parses the command line.

  Returns:
    An ``argparse.Namespace`` whose ``output`` attribute is a file object
    already opened for writing by ``argparse.FileType("w")``.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("--output", required=True, type=argparse.FileType("w"))
  return parser.parse_args()


if __name__ == "__main__":
  args = _ParseArgs()
  # The file was opened by argparse; the with-block ensures it is closed.
  with args.output:
    graph_def = MakeGraph().as_graph_def()
    # str() of a protobuf message produces its text format.
    args.output.write(str(graph_def))