1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/graph_def_builder.h"
17
18 #include <utility>
19
20 #include "tensorflow/core/graph/tensor_id.h"
21 #include "tensorflow/core/lib/core/errors.h"
22
23 namespace tensorflow {
24
Options(Graph * graph,Status * status)25 GraphDefBuilder::Options::Options(Graph* graph, Status* status)
26 : graph_(graph), status_(status) {}
~Options()27 GraphDefBuilder::Options::~Options() {}
28
WithName(StringPiece name) const29 GraphDefBuilder::Options GraphDefBuilder::Options::WithName(
30 StringPiece name) const {
31 return Options(*this).WithNameImpl(name);
32 }
WithDevice(StringPiece device) const33 GraphDefBuilder::Options GraphDefBuilder::Options::WithDevice(
34 StringPiece device) const {
35 return Options(*this).WithDeviceImpl(device);
36 }
WithControlInput(Node * control_input) const37 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInput(
38 Node* control_input) const {
39 return Options(*this).WithControlInputImpl(control_input);
40 }
WithControlInputs(gtl::ArraySlice<Node * > control_inputs) const41 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputs(
42 gtl::ArraySlice<Node*> control_inputs) const {
43 return Options(*this).WithControlInputsImpl(control_inputs);
44 }
WithNameImpl(StringPiece name)45 GraphDefBuilder::Options GraphDefBuilder::Options::WithNameImpl(
46 StringPiece name) {
47 name_ = string(name);
48 return *this;
49 }
WithDeviceImpl(StringPiece device)50 GraphDefBuilder::Options GraphDefBuilder::Options::WithDeviceImpl(
51 StringPiece device) {
52 device_ = string(device);
53 return *this;
54 }
WithControlInputImpl(Node * control_input)55 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputImpl(
56 Node* control_input) {
57 control_inputs_.push_back(control_input);
58 return *this;
59 }
WithControlInputsImpl(gtl::ArraySlice<Node * > control_inputs)60 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputsImpl(
61 gtl::ArraySlice<Node*> control_inputs) {
62 control_inputs_.insert(control_inputs_.end(), control_inputs.begin(),
63 control_inputs.end());
64 return *this;
65 }
66
ToGraphDef(GraphDef * graph_def) const67 Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const {
68 if (status_.ok()) {
69 graph_.ToGraphDef(graph_def);
70 }
71 return status_;
72 }
73
GetNameForOp(StringPiece op) const74 string GraphDefBuilder::Options::GetNameForOp(StringPiece op) const {
75 if (name_.empty()) return graph_->NewName(op);
76 return name_;
77 }
78
FinalizeBuilder(NodeBuilder * builder) const79 Node* GraphDefBuilder::Options::FinalizeBuilder(NodeBuilder* builder) const {
80 builder->ControlInputs(control_inputs_);
81 if (!device_.empty()) builder->Device(device_);
82 for (const auto& attr : attrs_) {
83 builder->Attr(attr.first, attr.second);
84 }
85
86 Node* returned_node;
87 UpdateStatus(builder->Finalize(graph_, &returned_node));
88 return returned_node;
89 }
90
UpdateStatus(const Status & status) const91 void GraphDefBuilder::Options::UpdateStatus(const Status& status) const {
92 if (status_ == nullptr) {
93 TF_CHECK_OK(status);
94 } else {
95 status_->Update(status);
96 }
97 }
98
99 namespace ops {
100
SourceOp(const string & op_name,const GraphDefBuilder::Options & opts)101 Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts) {
102 if (opts.HaveError()) return nullptr;
103 NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
104 opts.op_registry());
105 return opts.FinalizeBuilder(&node_builder);
106 }
107
UnaryOp(const string & op_name,NodeOut input,const GraphDefBuilder::Options & opts)108 Node* UnaryOp(const string& op_name, NodeOut input,
109 const GraphDefBuilder::Options& opts) {
110 if (opts.HaveError()) return nullptr;
111 NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
112 opts.op_registry());
113 node_builder.Input(std::move(input));
114 return opts.FinalizeBuilder(&node_builder);
115 }
116
BinaryOp(const string & op_name,NodeOut a,NodeOut b,const GraphDefBuilder::Options & opts)117 Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
118 const GraphDefBuilder::Options& opts) {
119 if (opts.HaveError()) return nullptr;
120 NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
121 opts.op_registry());
122 node_builder.Input(std::move(a)).Input(std::move(b));
123 return opts.FinalizeBuilder(&node_builder);
124 }
125
126 } // end namespace ops
127 } // end namespace tensorflow
128