1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/graph/node_builder.h"
16
17 #include <string>
18
19 #include "tensorflow/core/framework/full_type.pb.h"
20 #include "tensorflow/core/framework/op.h"
21 #include "tensorflow/core/framework/op_def_builder.h"
22 #include "tensorflow/core/graph/graph.h"
23 #include "tensorflow/core/kernels/ops_util.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26
27 namespace tensorflow {
28 namespace {
29
30 REGISTER_OP("Source").Output("o: out_types").Attr("out_types: list(type)");
31 REGISTER_OP("Sink").Input("i: T").Attr("T: type");
32
// Builds a two-output "Source" node and checks that NodeBuilder accepts valid
// output indices and rejects out-of-range indices and null input nodes.
TEST(NodeBuilderTest, Simple) {
  Graph graph(OpRegistry::Global());
  // Initialized so that a failed Finalize can never leave an indeterminate
  // pointer behind for the checks below.
  Node* source_node = nullptr;
  // Every later step wires inputs from source_node, so stop on failure rather
  // than continuing with an unusable node.
  TF_ASSERT_OK(NodeBuilder("source_op", "Source")
                   .Attr("out_types", {DT_INT32, DT_STRING})
                   .Finalize(&graph, &source_node));
  ASSERT_NE(source_node, nullptr);

  // Try connecting to each of source_node's outputs.
  TF_EXPECT_OK(NodeBuilder("sink1", "Sink")
                   .Input(source_node)
                   .Finalize(&graph, nullptr));
  TF_EXPECT_OK(NodeBuilder("sink2", "Sink")
                   .Input(source_node, 1)
                   .Finalize(&graph, nullptr));

  // Generate an error if the index is out of range.
  EXPECT_FALSE(NodeBuilder("sink3", "Sink")
                   .Input(source_node, 2)
                   .Finalize(&graph, nullptr)
                   .ok());
  EXPECT_FALSE(NodeBuilder("sink4", "Sink")
                   .Input(source_node, -1)
                   .Finalize(&graph, nullptr)
                   .ok());
  EXPECT_FALSE(NodeBuilder("sink5", "Sink")
                   .Input({source_node, -1})
                   .Finalize(&graph, nullptr)
                   .ok());

  // Generate an error if the node is nullptr. This can happen when using
  // GraphDefBuilder if there was an error creating the input node.
  EXPECT_FALSE(NodeBuilder("sink6", "Sink")
                   .Input(nullptr)
                   .Finalize(&graph, nullptr)
                   .ok());
  EXPECT_FALSE(NodeBuilder("sink7", "Sink")
                   .Input(NodeBuilder::NodeOut(nullptr, 0))
                   .Finalize(&graph, nullptr)
                   .ok());
}
74
75 REGISTER_OP("FullTypeOpBasicType")
76 .Output("o1: out_type")
77 .Attr("out_type: type")
__anonb398ef520202(OpDef* op_def) 78 .SetTypeConstructor([](OpDef* op_def) {
79 FullTypeDef* tdef =
80 op_def->mutable_output_arg(0)->mutable_experimental_full_type();
81 tdef->set_type_id(TFT_ARRAY);
82
83 FullTypeDef* arg = tdef->add_args();
84 arg->set_type_id(TFT_VAR);
85 arg->set_s("out_type");
86
87 return OkStatus();
88 });
89
// Checks that the type constructor of FullTypeOpBasicType is applied when the
// node is built, with the TFT_VAR bound to `out_type` resolved to TFT_FLOAT.
TEST(NodeBuilderTest, TypeConstructorBasicType) {
  Graph graph(OpRegistry::Global());
  Node* node = nullptr;
  // Fatal: `node` is only meaningful if Finalize succeeded, and every check
  // below dereferences it.
  TF_ASSERT_OK(NodeBuilder("op", "FullTypeOpBasicType")
                   .Attr("out_type", DT_FLOAT)
                   .Finalize(&graph, &node));
  ASSERT_NE(node, nullptr);
  ASSERT_TRUE(node->def().has_experimental_type());
  const FullTypeDef& ft = node->def().experimental_type();
  // Expected: a TFT_PRODUCT wrapping the type of the op's single output.
  ASSERT_EQ(ft.type_id(), TFT_PRODUCT);
  ASSERT_EQ(ft.args_size(), 1);
  // Bind by reference; there is no need to copy the proto.
  const FullTypeDef& ot = ft.args(0);
  ASSERT_EQ(ot.type_id(), TFT_ARRAY);
  // Check the size before indexing: args(0) on an empty repeated field is an
  // out-of-range access.
  ASSERT_EQ(ot.args_size(), 1);
  ASSERT_EQ(ot.args(0).type_id(), TFT_FLOAT);
  ASSERT_EQ(ot.args(0).args_size(), 0);
}
105
106 REGISTER_OP("FullTypeOpListType")
107 .Output("o1: out_types")
108 .Attr("out_types: list(type)")
__anonb398ef520302(OpDef* op_def) 109 .SetTypeConstructor([](OpDef* op_def) {
110 FullTypeDef* tdef =
111 op_def->mutable_output_arg(0)->mutable_experimental_full_type();
112 tdef->set_type_id(TFT_ARRAY);
113
114 FullTypeDef* arg = tdef->add_args();
115 arg->set_type_id(TFT_VAR);
116 arg->set_s("out_types");
117
118 return OkStatus();
119 });
120
// Building FullTypeOpListType is expected to fail. Its type constructor binds
// a TFT_VAR to `out_types`, which is a list(type) attr rather than a single
// type; presumably that binding is what Finalize rejects.
TEST(NodeBuilderTest, TypeConstructorListType) {
  Graph graph(OpRegistry::Global());
  Node* created = nullptr;
  const Status finalize_status =
      NodeBuilder("op", "FullTypeOpListType")
          .Attr("out_types", {DT_FLOAT, DT_INT32})
          .Finalize(&graph, &created);
  ASSERT_FALSE(finalize_status.ok());
}
129
130 } // namespace
131 } // namespace tensorflow
132