• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_
17 #define TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_
18 
19 #include <memory>
20 #include <string>
21 #include <unordered_map>
22 #include <vector>
23 
24 #include "tensorflow/cc/framework/ops.h"
25 #include "tensorflow/cc/framework/scope.h"
26 #include "tensorflow/core/public/session_options.h"
27 
28 namespace tensorflow {
29 
// Forward declaration only; callers that use the ThreadPoolOptions overloads
// must include the full definition themselves.
namespace thread {
struct ThreadPoolOptions;
}  // namespace thread
35 
36 /// @addtogroup core
37 /// @{
38 
39 /// A `ClientSession` object lets the caller drive the evaluation of the
40 /// TensorFlow graph constructed with the C++ API.
41 ///
42 /// Example:
43 ///
44 ///     Scope root = Scope::NewRootScope();
45 ///     auto a = Placeholder(root, DT_INT32);
46 ///     auto c = Add(root, a, {41});
47 ///
48 ///     ClientSession session(root);
49 ///     std::vector<Tensor> outputs;
50 ///
51 ///     Status s = session.Run({ {a, {1}} }, {c}, &outputs);
52 ///     if (!s.ok()) { ... }
53 class ClientSession {
54  public:
55   /// A data type to represent feeds to a Run call.
56   ///
57   /// This is a map of `Output` objects returned by op-constructors to the value
58   /// to feed them with. See `Input::Initializer` for details on what can be
59   /// used as feed values.
60   typedef std::unordered_map<Output, Input::Initializer, OutputHash> FeedType;
61 
62   /// Create a new session to evaluate the graph contained in `scope` by
63   /// connecting to the TensorFlow runtime specified by `target`.
64   ClientSession(const Scope& scope, const string& target);
65 
66   /// Same as above, but use the empty string ("") as the target specification.
67   explicit ClientSession(const Scope& scope);
68 
69   /// Create a new session, configuring it with `session_options`.
70   ClientSession(const Scope& scope, const SessionOptions& session_options);
71 
72   ~ClientSession();
73 
74   /// Evaluate the tensors in `fetch_outputs`. The values are returned as
75   /// `Tensor` objects in `outputs`. The number and order of `outputs` will
76   /// match `fetch_outputs`.
77   Status Run(const std::vector<Output>& fetch_outputs,
78              std::vector<Tensor>* outputs) const;
79 
80   /// Same as above, but use the mapping in `inputs` as feeds.
81   Status Run(const FeedType& inputs, const std::vector<Output>& fetch_outputs,
82              std::vector<Tensor>* outputs) const;
83 
84   /// Same as above. Additionally runs the operations ins `run_outputs`.
85   Status Run(const FeedType& inputs, const std::vector<Output>& fetch_outputs,
86              const std::vector<Operation>& run_outputs,
87              std::vector<Tensor>* outputs) const;
88 
89   /// Use `run_options` to turn on performance profiling. `run_metadata`, if not
90   /// null, is filled in with the profiling results.
91   Status Run(const RunOptions& run_options, const FeedType& inputs,
92              const std::vector<Output>& fetch_outputs,
93              const std::vector<Operation>& run_outputs,
94              std::vector<Tensor>* outputs, RunMetadata* run_metadata) const;
95 
96   /// Same as above. Additionally allows user to provide custom threadpool
97   /// implementation via ThreadPoolOptions.
98   Status Run(const RunOptions& run_options, const FeedType& inputs,
99              const std::vector<Output>& fetch_outputs,
100              const std::vector<Operation>& run_outputs,
101              std::vector<Tensor>* outputs, RunMetadata* run_metadata,
102              const thread::ThreadPoolOptions& threadpool_options) const;
103 
104   /// \brief A handle to a subgraph, created with
105   /// `ClientSession::MakeCallable()`.
106   typedef int64 CallableHandle;
107 
108   /// \brief Creates a `handle` for invoking the subgraph defined by
109   /// `callable_options`.
110   /// NOTE: This API is still experimental and may change.
111   Status MakeCallable(const CallableOptions& callable_options,
112                       CallableHandle* out_handle);
113 
114   /// \brief Invokes the subgraph named by `handle` with the given options and
115   /// input tensors.
116   ///
117   /// The order of tensors in `feed_tensors` must match the order of names in
118   /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will
119   /// match the order of names in `CallableOptions::fetch()` when this subgraph
120   /// was created.
121   /// NOTE: This API is still experimental and may change.
122   Status RunCallable(CallableHandle handle,
123                      const std::vector<Tensor>& feed_tensors,
124                      std::vector<Tensor>* fetch_tensors,
125                      RunMetadata* run_metadata);
126 
127   /// \brief Invokes the subgraph named by `handle` with the given options and
128   /// input tensors.
129   ///
130   /// The order of tensors in `feed_tensors` must match the order of names in
131   /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will
132   /// match the order of names in `CallableOptions::fetch()` when this subgraph
133   /// was created.
134   /// NOTE: This API is still experimental and may change.
135   Status RunCallable(CallableHandle handle,
136                      const std::vector<Tensor>& feed_tensors,
137                      std::vector<Tensor>* fetch_tensors,
138                      RunMetadata* run_metadata,
139                      const thread::ThreadPoolOptions& options);
140 
141   /// \brief Releases resources associated with the given `handle` in this
142   /// session.
143   /// NOTE: This API is still experimental and may change.
144   Status ReleaseCallable(CallableHandle handle);
145 
146  private:
147   class Impl;
148   std::unique_ptr<Impl> impl_;
impl()149   Impl* impl() { return impl_.get(); }
impl()150   const Impl* impl() const { return impl_.get(); }
151 };
152 
153 /// @}
154 
155 }  // end namespace tensorflow
156 
157 #endif  // TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_
158