• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <string>
18 
19 #include "tensorflow/core/framework/graph.pb.h"
20 #include "tensorflow/core/framework/node_def.pb.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22 #include "tensorflow/core/platform/protobuf.h"
23 #include "tensorflow/core/platform/test.h"
24 #include "tensorflow/tools/graph_transforms/transform_utils.h"
25 
26 namespace tensorflow {
27 namespace graph_transforms {
28 
// Text-format GraphDef used as the input fixture for the inlining test.
//
// Top-level graph: placeholder "y" and constant "sub/y" feed a
// PartitionedCall whose "f" attr names "__inference_simple_add_14". The call's
// result is added to constant "add/y" by "add", which is exposed via
// "Identity".
//
// Library: "__inference_simple_add_14" takes float args x and y, computes
// Mul(x, y), adds a Const ("add/y") with AddV2, and returns the sum through an
// Identity node bound to the "identity" output arg.
//
// Written as one raw string so the proto reads as-is. The literal starts
// directly with "node {" and ends with "}\n", so there is no extra leading
// newline.
constexpr char kGraphDefWithPartitionedCall[] = R"pb(node {
  name: "y"
  op: "Placeholder"
}
node {
  name: "sub/y"
  op: "Const"
}
node {
  name: "PartitionedCall"
  op: "PartitionedCall"
  input: "y"
  input: "sub/y"
  attr {
    key: "f"
    value {
      func {
        name: "__inference_simple_add_14"
      }
    }
  }
}
node {
  name: "add/y"
  op: "Const"
}
node {
  name: "add"
  op: "AddV2"
  input: "PartitionedCall"
  input: "add/y"
}
node {
  name: "Identity"
  op: "Identity"
  input: "add"
}
library {
  function {
    signature {
      name: "__inference_simple_add_14"
      input_arg {
        name: "x"
        type: DT_FLOAT
      }
      input_arg {
        name: "y"
        type: DT_FLOAT
      }
      output_arg {
        name: "identity"
        type: DT_FLOAT
      }
    }
    node_def {
      name: "mul"
      op: "Mul"
      input: "x"
      input: "y"
    }
    node_def {
      name: "add/y"
      op: "Const"
    }
    node_def {
      name: "add"
      op: "AddV2"
      input: "mul:z:0"
      input: "add/y:output:0"
    }
    node_def {
      name: "Identity"
      op: "Identity"
      input: "add:z:0"
    }
    ret {
      key: "identity"
      value: "Identity:output:0"
    }
  }
}
)pb";
111 
112 // Declare here, so we don't need a public header.
113 Status InlinePartitionedCall(const GraphDef& input_graph_def,
114                              const TransformFuncContext& context,
115                              GraphDef* output_graph_def);
116 
// Runs InlinePartitionedCall on kGraphDefWithPartitionedCall (input "y",
// output "Identity") and checks that:
//   * no node with op "PartitionedCall" remains in the result,
//   * the result has exactly 9 nodes, and
//   * IsGraphValid accepts the resulting graph.
TEST(InlinePartitionedCallTest, Inlining) {
  GraphDef in_graph;
  // Fatal check: every assertion below is meaningless if the fixture did not
  // parse. With EXPECT_TRUE the test would continue on an empty GraphDef and
  // report confusing downstream failures instead of the real cause.
  ASSERT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString(
      kGraphDefWithPartitionedCall, &in_graph));

  GraphDef result;
  TransformFuncContext context;
  context.input_names = {"y"};
  context.output_names = {"Identity"};
  TF_ASSERT_OK(InlinePartitionedCall(in_graph, context, &result));

  EXPECT_TRUE(std::none_of(
      result.node().cbegin(), result.node().cend(),
      [](const NodeDef& node) { return node.op() == "PartitionedCall"; }));
  // The fixture has 6 graph nodes and 4 function-body node_defs. 9 matches
  // "drop the PartitionedCall, splice in the 4 body nodes". This is presumed
  // accounting; confirm against InlinePartitionedCall's implementation.
  EXPECT_EQ(9, result.node().size());
  TF_EXPECT_OK(IsGraphValid(result));
}
134 
135 }  // namespace graph_transforms
136 }  // namespace tensorflow
137