/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/client/client_session.h"

#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
class ClientSession::Impl {
 private:
  friend class ClientSession;

  Impl(Session* session, std::shared_ptr<Graph> graph)
      : session_(session), graph_(std::move(graph)) {}

  static SessionOptions MakeDefaultSessionOptions(const string& target);
  Status MaybeExtendGraph() const;

  std::unique_ptr<Session> session_;
  std::shared_ptr<Graph> graph_;

  mutable mutex mu_;
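  // Number of graph nodes already handed to the session. MaybeExtendGraph()
  // uses this to send only the nodes added since the previous call.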
  mutable int last_num_graph_nodes_ TF_GUARDED_BY(mu_) = 0;
};

ClientSession::ClientSession(const Scope& scope, const string& target)
    : ClientSession(scope, Impl::MakeDefaultSessionOptions(target)) {}

ClientSession::ClientSession(const Scope& scope) : ClientSession(scope, "") {}

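// Creates the underlying Session with the given options and shares the
// scope's graph. Session creation failure is fatal (TF_CHECK_OK).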
ClientSession::ClientSession(const Scope& scope,
                             const SessionOptions& session_options) {
  Session* new_session;
  Status status = NewSession(session_options, &new_session);
  TF_CHECK_OK(status) << status;
  impl_.reset(new Impl(new_session, scope.graph_as_shared_ptr()));
  CHECK_NOTNULL(impl()->session_.get());
}

// Define destructor here so we can forward declare `Impl` in client_session.h.
// If we define a dtor in the header file or use the default dtor,
// unique_ptr<Impl> needs the complete type.
ClientSession::~ClientSession() {}

SessionOptions ClientSession::Impl::MakeDefaultSessionOptions(
    const string& target) {
  SessionOptions options;
  options.env = Env::Default();
  options.target = target;
  return options;
}

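// Convenience overloads: forward to the full Run() below with default
// RunOptions, empty feeds, and/or no target operations.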
Status ClientSession::Run(const std::vector<Output>& fetch_outputs,
                          std::vector<Tensor>* outputs) const {
  return Run(FeedType{}, fetch_outputs, {}, outputs);
}

Status ClientSession::Run(const FeedType& inputs,
                          const std::vector<Output>& fetch_outputs,
                          std::vector<Tensor>* outputs) const {
  return Run(inputs, fetch_outputs, {}, outputs);
}

Status ClientSession::Run(const FeedType& inputs,
                          const std::vector<Output>& fetch_outputs,
                          const std::vector<Operation>& run_outputs,
                          std::vector<Tensor>* outputs) const {
  return Run(RunOptions(), inputs, fetch_outputs, run_outputs, outputs,
             nullptr);
}

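// Sends any nodes added to the graph since the last extension to the
// session, so subsequent Run()/MakeCallable() calls can reference them.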
Status ClientSession::Impl::MaybeExtendGraph() const {
  mutex_lock l(mu_);
  int num_nodes = graph_->num_node_ids();
  if (num_nodes > last_num_graph_nodes_) {
    GraphDef graph_def;
    graph_->ToGraphDefSubRange(&graph_def, last_num_graph_nodes_);
    last_num_graph_nodes_ = num_nodes;
    return session_->Extend(graph_def);
  }
  return OkStatus();
}

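// Full Run(): resolves Output/Operation handles to node names, extends the
// session's graph if needed, and delegates to Session::Run().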
Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
                          const std::vector<Output>& fetch_outputs,
                          const std::vector<Operation>& run_outputs,
                          std::vector<Tensor>* outputs,
                          RunMetadata* run_metadata) const {
  std::vector<std::pair<string, Tensor>> feeds;
  for (auto const& feed : inputs) {
    TF_RETURN_IF_ERROR(feed.second.status);
    feeds.emplace_back(feed.first.name(), feed.second.tensor);
  }
  std::vector<string> output_tensor_names;
  output_tensor_names.reserve(fetch_outputs.size());
  for (auto const& output : fetch_outputs) {
    output_tensor_names.push_back(output.name());
  }
  std::vector<string> target_node_names;
  target_node_names.reserve(run_outputs.size());
  for (auto const& output : run_outputs) {
    target_node_names.push_back(output.node()->name());
  }
  TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
  return impl()->session_->Run(run_options, feeds, output_tensor_names,
                               target_node_names, outputs, run_metadata);
}

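// Same as the Run() above, but lets the caller supply custom thread pools
// via ThreadPoolOptions.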
Status ClientSession::Run(
    const RunOptions& run_options, const FeedType& inputs,
    const std::vector<Output>& fetch_outputs,
    const std::vector<Operation>& run_outputs, std::vector<Tensor>* outputs,
    RunMetadata* run_metadata,
    const thread::ThreadPoolOptions& threadpool_options) const {
  std::vector<std::pair<string, Tensor>> feeds;
  for (auto const& feed : inputs) {
    TF_RETURN_IF_ERROR(feed.second.status);
    feeds.emplace_back(feed.first.name(), feed.second.tensor);
  }
  std::vector<string> output_tensor_names;
  output_tensor_names.reserve(fetch_outputs.size());
  for (auto const& output : fetch_outputs) {
    output_tensor_names.push_back(output.name());
  }
  std::vector<string> target_node_names;
  target_node_names.reserve(run_outputs.size());
  for (auto const& output : run_outputs) {
    target_node_names.push_back(output.node()->name());
  }
  TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
  return impl()->session_->Run(run_options, feeds, output_tensor_names,
                               target_node_names, outputs, run_metadata,
                               threadpool_options);
}

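// Extends the graph before creating the callable so that nodes added after
// the last Run() are visible to it.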
Status ClientSession::MakeCallable(const CallableOptions& callable_options,
                                   CallableHandle* out_handle) {
  TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
  return impl()->session_->MakeCallable(callable_options, out_handle);
}

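// RunCallable()/ReleaseCallable() forward directly to the underlying Session.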
Status ClientSession::RunCallable(CallableHandle handle,
                                  const std::vector<Tensor>& feed_tensors,
                                  std::vector<Tensor>* fetch_tensors,
                                  RunMetadata* run_metadata) {
  return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors,
                                       run_metadata);
}

Status ClientSession::RunCallable(CallableHandle handle,
                                  const std::vector<Tensor>& feed_tensors,
                                  std::vector<Tensor>* fetch_tensors,
                                  RunMetadata* run_metadata,
                                  const thread::ThreadPoolOptions& options) {
  return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors,
                                       run_metadata, options);
}

Status ClientSession::ReleaseCallable(CallableHandle handle) {
  return impl()->session_->ReleaseCallable(handle);
}

}  // end namespace tensorflow
185