• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/client/client_session.h"
17 
18 #include <unordered_map>
19 #include <utility>
20 #include <vector>
21 
22 #include "tensorflow/core/platform/env.h"
23 #include "tensorflow/core/platform/mutex.h"
24 #include "tensorflow/core/protobuf/config.pb.h"
25 #include "tensorflow/core/public/session.h"
26 #include "tensorflow/core/public/session_options.h"
27 
28 namespace tensorflow {
29 
30 class ClientSession::Impl {
31  private:
32   friend class ClientSession;
33 
Impl(Session * session,std::shared_ptr<Graph> graph)34   Impl(Session* session, std::shared_ptr<Graph> graph)
35       : session_(session), graph_(std::move(graph)) {}
36 
37   static SessionOptions MakeDefaultSessionOptions(const string& target);
38   Status MaybeExtendGraph() const;
39 
40   std::unique_ptr<Session> session_;
41   std::shared_ptr<Graph> graph_;
42 
43   mutable mutex mu_;
44   mutable int last_num_graph_nodes_ TF_GUARDED_BY(mu_) = 0;
45 };
46 
ClientSession(const Scope & scope,const string & target)47 ClientSession::ClientSession(const Scope& scope, const string& target)
48     : ClientSession(scope, Impl::MakeDefaultSessionOptions(target)) {}
49 
ClientSession(const Scope & scope)50 ClientSession::ClientSession(const Scope& scope) : ClientSession(scope, "") {}
51 
ClientSession(const Scope & scope,const SessionOptions & session_options)52 ClientSession::ClientSession(const Scope& scope,
53                              const SessionOptions& session_options) {
54   Session* new_session;
55   Status status = NewSession(session_options, &new_session);
56   TF_CHECK_OK(status) << status;
57   impl_.reset(new Impl(new_session, scope.graph_as_shared_ptr()));
58   CHECK_NOTNULL(impl()->session_.get());
59 }
60 
61 // Define destructor here so we can forward declare `Impl` in client_session.h.
62 // If we define a dtor in the header file or use the default dtor,
63 // unique_ptr<Impl> needs the complete type.
~ClientSession()64 ClientSession::~ClientSession() {}
65 
MakeDefaultSessionOptions(const string & target)66 SessionOptions ClientSession::Impl::MakeDefaultSessionOptions(
67     const string& target) {
68   SessionOptions options;
69   options.env = Env::Default();
70   options.target = target;
71   return options;
72 }
73 
Run(const std::vector<Output> & fetch_outputs,std::vector<Tensor> * outputs) const74 Status ClientSession::Run(const std::vector<Output>& fetch_outputs,
75                           std::vector<Tensor>* outputs) const {
76   return Run(FeedType{}, fetch_outputs, {}, outputs);
77 }
78 
Run(const FeedType & inputs,const std::vector<Output> & fetch_outputs,std::vector<Tensor> * outputs) const79 Status ClientSession::Run(const FeedType& inputs,
80                           const std::vector<Output>& fetch_outputs,
81                           std::vector<Tensor>* outputs) const {
82   return Run(inputs, fetch_outputs, {}, outputs);
83 }
84 
Run(const FeedType & inputs,const std::vector<Output> & fetch_outputs,const std::vector<Operation> & run_outputs,std::vector<Tensor> * outputs) const85 Status ClientSession::Run(const FeedType& inputs,
86                           const std::vector<Output>& fetch_outputs,
87                           const std::vector<Operation>& run_outputs,
88                           std::vector<Tensor>* outputs) const {
89   return Run(RunOptions(), inputs, fetch_outputs, run_outputs, outputs,
90              nullptr);
91 }
92 
MaybeExtendGraph() const93 Status ClientSession::Impl::MaybeExtendGraph() const {
94   mutex_lock l(mu_);
95   int num_nodes = graph_->num_node_ids();
96   if (num_nodes > last_num_graph_nodes_) {
97     GraphDef graph_def;
98     graph_->ToGraphDefSubRange(&graph_def, last_num_graph_nodes_);
99     last_num_graph_nodes_ = num_nodes;
100     return session_->Extend(graph_def);
101   }
102   return OkStatus();
103 }
104 
Run(const RunOptions & run_options,const FeedType & inputs,const std::vector<Output> & fetch_outputs,const std::vector<Operation> & run_outputs,std::vector<Tensor> * outputs,RunMetadata * run_metadata) const105 Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
106                           const std::vector<Output>& fetch_outputs,
107                           const std::vector<Operation>& run_outputs,
108                           std::vector<Tensor>* outputs,
109                           RunMetadata* run_metadata) const {
110   std::vector<std::pair<string, Tensor>> feeds;
111   for (auto const& feed : inputs) {
112     TF_RETURN_IF_ERROR(feed.second.status);
113     feeds.emplace_back(feed.first.name(), feed.second.tensor);
114   }
115   std::vector<string> output_tensor_names;
116   output_tensor_names.reserve(fetch_outputs.size());
117   for (auto const& output : fetch_outputs) {
118     output_tensor_names.push_back(output.name());
119   }
120   std::vector<string> target_node_names;
121   target_node_names.reserve(run_outputs.size());
122   for (auto const& output : run_outputs) {
123     target_node_names.push_back(output.node()->name());
124   }
125   TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
126   return impl()->session_->Run(run_options, feeds, output_tensor_names,
127                                target_node_names, outputs, run_metadata);
128 }
129 
Run(const RunOptions & run_options,const FeedType & inputs,const std::vector<Output> & fetch_outputs,const std::vector<Operation> & run_outputs,std::vector<Tensor> * outputs,RunMetadata * run_metadata,const thread::ThreadPoolOptions & threadpool_options) const130 Status ClientSession::Run(
131     const RunOptions& run_options, const FeedType& inputs,
132     const std::vector<Output>& fetch_outputs,
133     const std::vector<Operation>& run_outputs, std::vector<Tensor>* outputs,
134     RunMetadata* run_metadata,
135     const thread::ThreadPoolOptions& threadpool_options) const {
136   std::vector<std::pair<string, Tensor>> feeds;
137   for (auto const& feed : inputs) {
138     TF_RETURN_IF_ERROR(feed.second.status);
139     feeds.emplace_back(feed.first.name(), feed.second.tensor);
140   }
141   std::vector<string> output_tensor_names;
142   output_tensor_names.reserve(fetch_outputs.size());
143   for (auto const& output : fetch_outputs) {
144     output_tensor_names.push_back(output.name());
145   }
146   std::vector<string> target_node_names;
147   target_node_names.reserve(run_outputs.size());
148   for (auto const& output : run_outputs) {
149     target_node_names.push_back(output.node()->name());
150   }
151   TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
152   return impl()->session_->Run(run_options, feeds, output_tensor_names,
153                                target_node_names, outputs, run_metadata,
154                                threadpool_options);
155 }
156 
MakeCallable(const CallableOptions & callable_options,CallableHandle * out_handle)157 Status ClientSession::MakeCallable(const CallableOptions& callable_options,
158                                    CallableHandle* out_handle) {
159   TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
160   return impl()->session_->MakeCallable(callable_options, out_handle);
161 }
162 
RunCallable(CallableHandle handle,const std::vector<Tensor> & feed_tensors,std::vector<Tensor> * fetch_tensors,RunMetadata * run_metadata)163 Status ClientSession::RunCallable(CallableHandle handle,
164                                   const std::vector<Tensor>& feed_tensors,
165                                   std::vector<Tensor>* fetch_tensors,
166                                   RunMetadata* run_metadata) {
167   return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors,
168                                        run_metadata);
169 }
170 
RunCallable(CallableHandle handle,const std::vector<Tensor> & feed_tensors,std::vector<Tensor> * fetch_tensors,RunMetadata * run_metadata,const thread::ThreadPoolOptions & options)171 Status ClientSession::RunCallable(CallableHandle handle,
172                                   const std::vector<Tensor>& feed_tensors,
173                                   std::vector<Tensor>* fetch_tensors,
174                                   RunMetadata* run_metadata,
175                                   const thread::ThreadPoolOptions& options) {
176   return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors,
177                                        run_metadata, options);
178 }
179 
ReleaseCallable(CallableHandle handle)180 Status ClientSession::ReleaseCallable(CallableHandle handle) {
181   return impl()->session_->ReleaseCallable(handle);
182 }
183 
184 }  // end namespace tensorflow
185