1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_ 17 #define TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_ 18 19 #include <memory> 20 #include <string> 21 #include <unordered_map> 22 #include <vector> 23 24 #include "tensorflow/cc/framework/ops.h" 25 #include "tensorflow/cc/framework/scope.h" 26 #include "tensorflow/core/public/session_options.h" 27 28 namespace tensorflow { 29 30 namespace thread { 31 32 struct ThreadPoolOptions; 33 34 } 35 36 /// @addtogroup core 37 /// @{ 38 39 /// A `ClientSession` object lets the caller drive the evaluation of the 40 /// TensorFlow graph constructed with the C++ API. 41 /// 42 /// Example: 43 /// 44 /// Scope root = Scope::NewRootScope(); 45 /// auto a = Placeholder(root, DT_INT32); 46 /// auto c = Add(root, a, {41}); 47 /// 48 /// ClientSession session(root); 49 /// std::vector<Tensor> outputs; 50 /// 51 /// Status s = session.Run({ {a, {1}} }, {c}, &outputs); 52 /// if (!s.ok()) { ... } 53 class ClientSession { 54 public: 55 /// A data type to represent feeds to a Run call. 56 /// 57 /// This is a map of `Output` objects returned by op-constructors to the value 58 /// to feed them with. See `Input::Initializer` for details on what can be 59 /// used as feed values. 60 typedef std::unordered_map<Output, Input::Initializer, OutputHash> FeedType; 61 62 /// Create a new session to evaluate the graph contained in `scope` by 63 /// connecting to the TensorFlow runtime specified by `target`. 64 ClientSession(const Scope& scope, const string& target); 65 66 /// Same as above, but use the empty string ("") as the target specification. 67 explicit ClientSession(const Scope& scope); 68 69 /// Create a new session, configuring it with `session_options`. 70 ClientSession(const Scope& scope, const SessionOptions& session_options); 71 72 ~ClientSession(); 73 74 /// Evaluate the tensors in `fetch_outputs`. The values are returned as 75 /// `Tensor` objects in `outputs`. The number and order of `outputs` will 76 /// match `fetch_outputs`. 77 Status Run(const std::vector<Output>& fetch_outputs, 78 std::vector<Tensor>* outputs) const; 79 80 /// Same as above, but use the mapping in `inputs` as feeds. 81 Status Run(const FeedType& inputs, const std::vector<Output>& fetch_outputs, 82 std::vector<Tensor>* outputs) const; 83 84 /// Same as above. Additionally runs the operations ins `run_outputs`. 85 Status Run(const FeedType& inputs, const std::vector<Output>& fetch_outputs, 86 const std::vector<Operation>& run_outputs, 87 std::vector<Tensor>* outputs) const; 88 89 /// Use `run_options` to turn on performance profiling. `run_metadata`, if not 90 /// null, is filled in with the profiling results. 91 Status Run(const RunOptions& run_options, const FeedType& inputs, 92 const std::vector<Output>& fetch_outputs, 93 const std::vector<Operation>& run_outputs, 94 std::vector<Tensor>* outputs, RunMetadata* run_metadata) const; 95 96 /// Same as above. Additionally allows user to provide custom threadpool 97 /// implementation via ThreadPoolOptions. 98 Status Run(const RunOptions& run_options, const FeedType& inputs, 99 const std::vector<Output>& fetch_outputs, 100 const std::vector<Operation>& run_outputs, 101 std::vector<Tensor>* outputs, RunMetadata* run_metadata, 102 const thread::ThreadPoolOptions& threadpool_options) const; 103 104 /// \brief A handle to a subgraph, created with 105 /// `ClientSession::MakeCallable()`. 106 typedef int64_t CallableHandle; 107 108 /// \brief Creates a `handle` for invoking the subgraph defined by 109 /// `callable_options`. 110 /// NOTE: This API is still experimental and may change. 111 Status MakeCallable(const CallableOptions& callable_options, 112 CallableHandle* out_handle); 113 114 /// \brief Invokes the subgraph named by `handle` with the given options and 115 /// input tensors. 116 /// 117 /// The order of tensors in `feed_tensors` must match the order of names in 118 /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will 119 /// match the order of names in `CallableOptions::fetch()` when this subgraph 120 /// was created. 121 /// NOTE: This API is still experimental and may change. 122 Status RunCallable(CallableHandle handle, 123 const std::vector<Tensor>& feed_tensors, 124 std::vector<Tensor>* fetch_tensors, 125 RunMetadata* run_metadata); 126 127 /// \brief Invokes the subgraph named by `handle` with the given options and 128 /// input tensors. 129 /// 130 /// The order of tensors in `feed_tensors` must match the order of names in 131 /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will 132 /// match the order of names in `CallableOptions::fetch()` when this subgraph 133 /// was created. 134 /// NOTE: This API is still experimental and may change. 135 Status RunCallable(CallableHandle handle, 136 const std::vector<Tensor>& feed_tensors, 137 std::vector<Tensor>* fetch_tensors, 138 RunMetadata* run_metadata, 139 const thread::ThreadPoolOptions& options); 140 141 /// \brief Releases resources associated with the given `handle` in this 142 /// session. 143 /// NOTE: This API is still experimental and may change. 144 Status ReleaseCallable(CallableHandle handle); 145 146 private: 147 class Impl; 148 std::unique_ptr<Impl> impl_; impl()149 Impl* impl() { return impl_.get(); } impl()150 const Impl* impl() const { return impl_.get(); } 151 }; 152 153 /// @} 154 155 } // end namespace tensorflow 156 157 #endif // TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_ 158