• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/graph/node_builder.h"
16 
17 #include <string>
18 
19 #include "tensorflow/core/framework/full_type.pb.h"
20 #include "tensorflow/core/framework/op.h"
21 #include "tensorflow/core/framework/op_def_builder.h"
22 #include "tensorflow/core/graph/graph.h"
23 #include "tensorflow/core/kernels/ops_util.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 
27 namespace tensorflow {
28 namespace {
29 
30 REGISTER_OP("Source").Output("o: out_types").Attr("out_types: list(type)");
31 REGISTER_OP("Sink").Input("i: T").Attr("T: type");
32 
// Checks NodeBuilder::Input() wiring: valid output indices of an existing node
// succeed, while out-of-range indices and null input nodes are reported as
// errors from Finalize().
TEST(NodeBuilderTest, Simple) {
  Graph graph(OpRegistry::Global());
  // Initialized so that a failed Finalize() can never leave an indeterminate
  // pointer behind for the null check below.
  Node* source_node = nullptr;
  // Every later expectation wires into source_node, so a failure here must
  // stop the test instead of cascading into misleading follow-on failures.
  TF_ASSERT_OK(NodeBuilder("source_op", "Source")
                   .Attr("out_types", {DT_INT32, DT_STRING})
                   .Finalize(&graph, &source_node));
  ASSERT_NE(source_node, nullptr);

  // Try connecting to each of source_node's outputs.
  TF_EXPECT_OK(NodeBuilder("sink1", "Sink")
                   .Input(source_node)
                   .Finalize(&graph, nullptr));
  TF_EXPECT_OK(NodeBuilder("sink2", "Sink")
                   .Input(source_node, 1)
                   .Finalize(&graph, nullptr));

  // Generate an error if the index is out of range. source_node was built
  // with two out_types, so only indices 0 and 1 are valid.
  EXPECT_FALSE(NodeBuilder("sink3", "Sink")
                   .Input(source_node, 2)
                   .Finalize(&graph, nullptr)
                   .ok());
  EXPECT_FALSE(NodeBuilder("sink4", "Sink")
                   .Input(source_node, -1)
                   .Finalize(&graph, nullptr)
                   .ok());
  EXPECT_FALSE(NodeBuilder("sink5", "Sink")
                   .Input({source_node, -1})
                   .Finalize(&graph, nullptr)
                   .ok());

  // Generate an error if the node is nullptr.  This can happen when using
  // GraphDefBuilder if there was an error creating the input node.
  EXPECT_FALSE(NodeBuilder("sink6", "Sink")
                   .Input(nullptr)
                   .Finalize(&graph, nullptr)
                   .ok());
  EXPECT_FALSE(NodeBuilder("sink7", "Sink")
                   .Input(NodeBuilder::NodeOut(nullptr, 0))
                   .Finalize(&graph, nullptr)
                   .ok());
}
74 
75 REGISTER_OP("FullTypeOpBasicType")
76     .Output("o1: out_type")
77     .Attr("out_type: type")
__anonb398ef520202(OpDef* op_def) 78     .SetTypeConstructor([](OpDef* op_def) {
79       FullTypeDef* tdef =
80           op_def->mutable_output_arg(0)->mutable_experimental_full_type();
81       tdef->set_type_id(TFT_ARRAY);
82 
83       FullTypeDef* arg = tdef->add_args();
84       arg->set_type_id(TFT_VAR);
85       arg->set_s("out_type");
86 
87       return OkStatus();
88     });
89 
// Builds a FullTypeOpBasicType node with out_type=DT_FLOAT and checks that the
// node's experimental full type is TFT_PRODUCT[TFT_ARRAY[TFT_FLOAT]], i.e. the
// TFT_VAR placeholder from the type constructor was replaced by the attr value.
TEST(NodeBuilderTest, TypeConstructorBasicType) {
  Graph graph(OpRegistry::Global());
  Node* node = nullptr;
  // Fatal: every check below dereferences `node`.
  TF_ASSERT_OK(NodeBuilder("op", "FullTypeOpBasicType")
                   .Attr("out_type", DT_FLOAT)
                   .Finalize(&graph, &node));
  ASSERT_NE(node, nullptr);
  ASSERT_TRUE(node->def().has_experimental_type());
  const FullTypeDef& ft = node->def().experimental_type();
  // One product argument per output; this op has exactly one output.
  ASSERT_EQ(ft.type_id(), TFT_PRODUCT);
  ASSERT_EQ(ft.args_size(), 1);
  // Bind by reference to avoid copying the proto message.
  const FullTypeDef& ot = ft.args(0);
  ASSERT_EQ(ot.type_id(), TFT_ARRAY);
  // Guard the index below so a malformed type fails cleanly.
  ASSERT_EQ(ot.args_size(), 1);
  ASSERT_EQ(ot.args(0).type_id(), TFT_FLOAT);
  ASSERT_EQ(ot.args(0).args_size(), 0);
}
105 
106 REGISTER_OP("FullTypeOpListType")
107     .Output("o1: out_types")
108     .Attr("out_types: list(type)")
__anonb398ef520302(OpDef* op_def) 109     .SetTypeConstructor([](OpDef* op_def) {
110       FullTypeDef* tdef =
111           op_def->mutable_output_arg(0)->mutable_experimental_full_type();
112       tdef->set_type_id(TFT_ARRAY);
113 
114       FullTypeDef* arg = tdef->add_args();
115       arg->set_type_id(TFT_VAR);
116       arg->set_s("out_types");
117 
118       return OkStatus();
119     });
120 
// The test expects Finalize() to fail when the type constructor's TFT_VAR
// placeholder names a list(type) attr (here set to two types).
TEST(NodeBuilderTest, TypeConstructorListType) {
  Graph graph(OpRegistry::Global());
  Node* created_node;
  const auto finalize_status = NodeBuilder("op", "FullTypeOpListType")
                                   .Attr("out_types", {DT_FLOAT, DT_INT32})
                                   .Finalize(&graph, &created_node);
  ASSERT_FALSE(finalize_status.ok());
}
129 
130 }  // namespace
131 }  // namespace tensorflow
132