• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
"""Writes a text-format GraphDef to the --output file (for testing ExternalDataset)."""
16
17import argparse
18import numpy as np
19import tensorflow.compat.v1 as tf
20
21from  fcp.tensorflow import external_dataset
22
23
def _ParseSingleExample(p):
  """Decodes a serialized tf.Example and extracts its int64 "val" feature.

  Args:
    p: A string tensor holding a serialized tf.Example. This is presumably a
      scalar element of the ExternalDataset; TODO confirm.

  Returns:
    The int64 tensor parsed from the "val" feature. The input is first reshaped
    to rank 1 (a batch), so the result carries that leading batch dimension
    rather than being a scalar.
  """
  # tf.parse_example does not accept scalar inputs, so present the record as
  # a rank-1 batch via reshape(-1).
  as_batch = tf.reshape(p, [-1])
  spec = {"val": tf.FixedLenFeature([], dtype=tf.int64)}
  return tf.parse_example(as_batch, spec)["val"]
29
30
def MakeGraph():
  """Builds a graph that sums the "val" feature of an ExternalDataset.

  The graph has two string placeholders, named "token" and "selector", which
  are passed to external_dataset.ExternalDataset. The dataset elements are
  mapped through _ParseSingleExample. The results are folded with addition,
  starting from an int64 zero, and the sum is exposed as the node "total".

  Returns:
    The tf.Graph holding these ops.
  """
  g = tf.Graph()

  with g.as_default():
    # Placeholders are created in the same order (token, then selector) as
    # the keyword arguments the dataset is built with.
    token_input = tf.placeholder(name="token", dtype=tf.string)
    selector_input = tf.placeholder(name="selector", dtype=tf.string)
    dataset = external_dataset.ExternalDataset(
        token=token_input, selector=selector_input)

    summed = dataset.map(_ParseSingleExample).reduce(
        np.int64(0), lambda x, y: x + y)
    # Only the node name matters here: it gives the result a fixed, known
    # name ("total"), so the returned tensor is not kept.
    tf.identity(summed, name="total")

  return g
47
48
def _ParseArgs():
  """Parses the command line.

  Returns:
    An argparse.Namespace with a required `output` attribute. It is a text
    file object opened for writing via argparse.FileType("w"); passing "-"
    selects stdout.
  """
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
      "--output",
      type=argparse.FileType("w"),
      required=True,
  )
  parsed = arg_parser.parse_args()
  return parsed
53
if __name__ == "__main__":
  # Build the graph and write its text (str) form to the --output file.
  # Leaving the `with` closes the file object.
  cli = _ParseArgs()
  out_file = cli.output
  with out_file:
    out_file.write(str(MakeGraph().as_graph_def()))
59