1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_PUBLIC_SESSION_H_ 17 #define TENSORFLOW_CORE_PUBLIC_SESSION_H_ 18 19 #include <string> 20 #include <vector> 21 22 #include "tensorflow/core/framework/device_attributes.pb.h" 23 #include "tensorflow/core/framework/graph.pb.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/lib/core/errors.h" 26 #include "tensorflow/core/lib/core/status.h" 27 #include "tensorflow/core/platform/env.h" 28 #include "tensorflow/core/protobuf/config.pb.h" 29 #include "tensorflow/core/public/session_options.h" 30 31 namespace tensorflow { 32 class DeviceMgr; 33 34 namespace thread { 35 36 struct ThreadPoolOptions; 37 38 } 39 40 /// \brief A Session instance lets a caller drive a TensorFlow graph 41 /// computation. 42 /// 43 /// When a Session is created with a given target, a new Session object 44 /// is bound to the universe of resources specified by that target. 45 /// Those resources are available to this session to perform 46 /// computation described in the GraphDef. After extending the session 47 /// with a graph, the caller uses the Run() API to perform the 48 /// computation and potentially fetch outputs as Tensors. 49 /// 50 /// Example: 51 /// 52 /// ```c++ 53 /// 54 /// tensorflow::GraphDef graph; 55 /// // ... Create or load graph into "graph". 56 /// 57 /// // This example uses the default options which connects 58 /// // to a local runtime. 59 /// tensorflow::SessionOptions options; 60 /// std::unique_ptr<tensorflow::Session> 61 /// session(tensorflow::NewSession(options)); 62 /// 63 /// // Create the session with this graph. 64 /// tensorflow::Status s = session->Create(graph); 65 /// if (!s.ok()) { ... } 66 /// 67 /// // Run the graph and fetch the first output of the "output" 68 /// // operation, and also run to but do not return anything 69 /// // for the "update_state" operation. 70 /// std::vector<tensorflow::Tensor> outputs; 71 /// s = session->Run({}, {"output:0"}, {"update_state"}, &outputs); 72 /// if (!s.ok()) { ... } 73 /// 74 /// // Map the output as a flattened float tensor, and do something 75 /// // with it. 76 /// auto output_tensor = outputs[0].flat<float>(); 77 /// if (output_tensor(0) > 0.5) { ... } 78 /// 79 /// // Close the session to release the resources associated with 80 /// // this session. 81 /// session->Close(); 82 /// 83 /// ``` 84 /// 85 /// A Session allows concurrent calls to Run(), though a Session must 86 /// be created / extended by a single thread. 87 /// 88 /// Only one thread must call Close(), and Close() must only be called 89 /// after all other calls to Run() have returned. 90 class Session { 91 public: 92 Session(); 93 virtual ~Session(); 94 95 /// \brief Create the graph to be used for the session. 96 /// 97 /// Returns an error if this session has already been created with a 98 /// graph. To re-use the session with a different graph, the caller 99 /// must Close() the session first. 100 virtual Status Create(const GraphDef& graph) = 0; 101 #ifndef SWIG Create(GraphDef && graph)102 virtual Status Create(GraphDef&& graph) { return Create(graph); } 103 #endif 104 105 /// \brief Adds operations to the graph that is already registered with the 106 /// Session. 107 /// 108 /// The names of new operations in "graph" must not exist in the 109 /// graph that is already registered. 110 virtual Status Extend(const GraphDef& graph) = 0; 111 #ifndef SWIG Extend(GraphDef && graph)112 virtual Status Extend(GraphDef&& graph) { return Extend(graph); } 113 #endif 114 115 /// \brief Runs the graph with the provided input tensors and fills 116 /// `outputs` for the endpoints specified in `output_tensor_names`. 117 /// Runs to but does not return Tensors for the nodes in 118 /// `target_node_names`. 119 /// 120 /// The order of tensors in `outputs` will match the order provided 121 /// by `output_tensor_names`. 122 /// 123 /// If `Run` returns `OK()`, then `outputs->size()` will be equal to 124 /// `output_tensor_names.size()`. If `Run` does not return `OK()`, the 125 /// state of `outputs` is undefined. 126 /// 127 /// REQUIRES: The name of each Tensor of the input or output must 128 /// match a "Tensor endpoint" in the `GraphDef` passed to `Create()`. 129 /// 130 /// REQUIRES: At least one of `output_tensor_names` and 131 /// `target_node_names` must be non-empty. 132 /// 133 /// REQUIRES: outputs is not nullptr if `output_tensor_names` is non-empty. 134 virtual Status Run(const std::vector<std::pair<std::string, Tensor> >& inputs, 135 const std::vector<std::string>& output_tensor_names, 136 const std::vector<std::string>& target_node_names, 137 std::vector<Tensor>* outputs) = 0; 138 139 /// \brief Implementations which support `RunOptions`. 140 // 141 /// NOTE: This API is still experimental and may change. Create(const RunOptions & run_options,const GraphDef & graph)142 virtual Status Create(const RunOptions& run_options, const GraphDef& graph) { 143 return errors::Unimplemented( 144 "Create(const RunOptions& run_options, const GraphDef& graph) is not " 145 "supported for this session."); 146 } Extend(const RunOptions & run_options,const GraphDef & graph)147 virtual Status Extend(const RunOptions& run_options, const GraphDef& graph) { 148 return errors::Unimplemented( 149 "Extend(const RunOptions& run_options, const GraphDef& graph) is not " 150 "supported for this session."); 151 } 152 #ifndef SWIG Create(const RunOptions & run_options,GraphDef && graph)153 virtual Status Create(const RunOptions& run_options, GraphDef&& graph) { 154 return Create(run_options, graph); 155 } Extend(const RunOptions & run_options,GraphDef && graph)156 virtual Status Extend(const RunOptions& run_options, GraphDef&& graph) { 157 return Extend(run_options, graph); 158 } 159 #endif Close(const RunOptions & run_options)160 virtual Status Close(const RunOptions& run_options) { 161 return errors::Unimplemented( 162 "Close(const RunOptions& run_options) is not supported for this " 163 "session."); 164 } 165 166 /// \brief Like `Run`, but allows users to pass in a `RunOptions` proto and 167 /// to retrieve non-Tensor metadata output via a `RunMetadata` proto for this 168 /// step. `run_metadata` may be nullptr, in which case any metadata output is 169 /// discarded. 170 /// NOTE: This API is still experimental and may change. 171 virtual Status Run(const RunOptions& run_options, 172 const std::vector<std::pair<std::string, Tensor> >& inputs, 173 const std::vector<std::string>& output_tensor_names, 174 const std::vector<std::string>& target_node_names, 175 std::vector<Tensor>* outputs, RunMetadata* run_metadata); 176 177 /// \brief Like `Run` with `RunOptions` proto, but allows user to provide 178 /// custom threadpool implementation via ThreadPoolOptions. 179 /// NOTE: This API is still experimental and may change. Run(const RunOptions & run_options,const std::vector<std::pair<std::string,Tensor>> & inputs,const std::vector<std::string> & output_tensor_names,const std::vector<std::string> & target_node_names,std::vector<Tensor> * outputs,RunMetadata * run_metadata,const thread::ThreadPoolOptions & threadpool_options)180 virtual Status Run(const RunOptions& run_options, 181 const std::vector<std::pair<std::string, Tensor> >& inputs, 182 const std::vector<std::string>& output_tensor_names, 183 const std::vector<std::string>& target_node_names, 184 std::vector<Tensor>* outputs, RunMetadata* run_metadata, 185 const thread::ThreadPoolOptions& threadpool_options) { 186 return errors::Unimplemented( 187 "Run with threadpool is not supported for this session."); 188 } 189 190 /// \brief Sets up a graph for partial execution. All future feeds and 191 /// fetches are specified by `input_names` and `output_names`. Returns 192 /// `handle` that can be used to perform a sequence of partial feeds and 193 /// fetches. 194 /// NOTE: This API is still experimental and may change. 195 virtual Status PRunSetup(const std::vector<std::string>& input_names, 196 const std::vector<std::string>& output_names, 197 const std::vector<std::string>& target_nodes, 198 std::string* handle); 199 200 /// \brief Continues the pending execution specified by `handle` with the 201 /// provided input tensors and fills `outputs` for the endpoints specified 202 /// in `output_names`. 203 /// NOTE: This API is still experimental and may change. 204 virtual Status PRun( 205 const std::string& handle, 206 const std::vector<std::pair<std::string, Tensor> >& inputs, 207 const std::vector<std::string>& output_names, 208 std::vector<Tensor>* outputs); 209 210 /// \brief List devices in the session. 211 /// 212 /// Retrieves the list of available devices within the session, and populates 213 /// *response. This API is optional. If it is unimplemented, Status will 214 /// return a corresponding error message, and *response will be unmodified. 215 virtual Status ListDevices(std::vector<DeviceAttributes>* response) = 0; 216 217 /// \brief Closes this session. 218 /// 219 /// Closing a session releases the resources used by this session 220 /// on the TensorFlow runtime (specified during session creation by 221 /// the `SessionOptions::target` field). 222 virtual Status Close() = 0; 223 224 // NOTE(ashankar): As of July 2017, this method was added to facilitate some 225 // experimentation. Reconsider/re-evaluate after September 2017. 226 // 227 // Sets `*output` to the `DeviceMgr` that owns accessible devices in the 228 // address-space of the caller. LocalDeviceManager(const DeviceMgr ** output)229 virtual Status LocalDeviceManager(const DeviceMgr** output) { 230 return errors::Unimplemented( 231 "LocalDeviceManager is not supported for this session."); 232 } 233 234 /// \brief A handle to a subgraph, created with `Session::MakeCallable()`. 235 typedef int64 CallableHandle; 236 237 /// \brief Creates a `handle` for invoking the subgraph defined by 238 /// `callable_options`. 239 /// NOTE: This API is still experimental and may change. MakeCallable(const CallableOptions & callable_options,CallableHandle * out_handle)240 virtual Status MakeCallable(const CallableOptions& callable_options, 241 CallableHandle* out_handle) { 242 return errors::Unimplemented( 243 "MakeCallable is not supported for this session."); 244 } 245 246 /// \brief Invokes the subgraph named by `handle` with the given options and 247 /// input tensors. 248 /// 249 /// The order of tensors in `feed_tensors` must and `fetch_tensors` will 250 /// match the order of names in `CallableOptions::feed()` and 251 /// `CallableOptions::fetch()` when this subgraph was created. 252 /// NOTE: This API is still experimental and may change. RunCallable(CallableHandle handle,const std::vector<Tensor> & feed_tensors,std::vector<Tensor> * fetch_tensors,RunMetadata * run_metadata)253 virtual Status RunCallable(CallableHandle handle, 254 const std::vector<Tensor>& feed_tensors, 255 std::vector<Tensor>* fetch_tensors, 256 RunMetadata* run_metadata) { 257 return errors::Unimplemented( 258 "RunCallable is not supported for this session."); 259 } 260 261 /// \brief Invokes the subgraph named by `handle` with the given options and 262 /// input tensors. User can provide custom threadpool implementation via 263 /// threadpool_options. 264 /// 265 /// The order of tensors in `feed_tensors` must and `fetch_tensors` will 266 /// match the order of names in `CallableOptions::feed()` and 267 /// `CallableOptions::fetch()` when this subgraph was created. 268 /// NOTE: This API is still experimental and may change. RunCallable(CallableHandle handle,const std::vector<Tensor> & feed_tensors,std::vector<Tensor> * fetch_tensors,RunMetadata * run_metadata,const thread::ThreadPoolOptions & threadpool_options)269 virtual Status RunCallable( 270 CallableHandle handle, const std::vector<Tensor>& feed_tensors, 271 std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata, 272 const thread::ThreadPoolOptions& threadpool_options) { 273 return errors::Unimplemented( 274 "RunCallable with threadpool is not supported for this session."); 275 } 276 277 /// \brief Releases resources associated with the given `handle` in this 278 /// session. 279 /// NOTE: This API is still experimental and may change. ReleaseCallable(CallableHandle handle)280 virtual Status ReleaseCallable(CallableHandle handle) { 281 return errors::Unimplemented( 282 "ReleaseCallable is not supported for this session."); 283 } 284 285 /// \brief Release global graph-related state in this session. 286 /// 287 /// After calling `this->Finalize()`, calls to `this->Run()` with previously 288 /// unseen feeds and fetches, and calls to `this->MakeCallable()` will fail. 289 /// Using `MakeCallable()` and `RunCallable()` is recommended, because 290 /// explicit callable creation makes it clearer where the `Finalize()` call 291 /// should be placed. 292 /// 293 /// This API can be used in conjunction with a "warmup" phase to reduce the 294 /// memory consumed by the session: 295 /// 296 /// 1. Call `Session::Create()`. 297 /// 2. Call `Session::MakeCallable()` for all subgraphs that you will execute 298 /// in the session. 299 /// 3. Call `Session::Finalize()` to release global graph-related state. 300 /// 4. Call `Session::RunCallable()` with the handle(s) created in step 2. 301 /// 302 /// NOTE: This API is still experimental and may change. Finalize()303 virtual Status Finalize() { 304 return errors::Unimplemented("Finalize is not supported for this session."); 305 } 306 }; 307 308 /// \brief Create a new session with the given options. 309 /// 310 /// If session creation succeeds, the new `Session` will be stored in 311 /// `*out_session`, the caller will take ownership of the returned 312 /// `*out_session`, and this function will return `OK()`. Otherwise, this 313 /// function will return an error status and set *out_session to nullptr. 314 Status NewSession(const SessionOptions& options, Session** out_session); 315 316 /// \brief Resets resource containers associated with a target. 317 /// 318 /// Reset() allows misbehaving or slow sessions to be aborted and closed, and 319 /// causes their resources eventually to be released. Reset() does not wait 320 /// for the computations in old sessions to cease; it merely starts the 321 /// process of tearing them down. However, if a new session is started after 322 /// a Reset(), the new session is isolated from changes that old sessions 323 /// (started prior to the Reset()) may continue to make to resources, provided 324 /// all those resources are in containers listed in "containers". 325 /// 326 /// Old sessions may continue to have side-effects on resources not in 327 /// containers listed in "containers", and thus may affect future 328 /// sessions' results in ways that are hard to predict. Thus, if well-defined 329 /// behavior is desired, it is recommended that all containers be listed in 330 /// "containers". 331 /// 332 /// `containers` is a vector of string representation of resource container 333 /// names. When a resource container is reset, the resources held by the 334 /// container will be released. In particular, all Variables in the container 335 /// will become undefined. If the "containers" vector is empty, the default 336 /// container is assumed. If the "containers" vector is non-empty, the 337 /// default container should be listed explicitly. 338 /// 339 /// If Reset succeeds, this function will return `OK()`. Otherwise, this 340 /// function will return an error status. 341 Status Reset(const SessionOptions& options, 342 const std::vector<std::string>& containers); 343 344 /// \brief Create a new session with the given options. 345 /// 346 /// If a new `Session` object could not be created, this function will 347 /// return nullptr. 348 /// 349 /// *Strongly prefer* the version of NewSession that returns Status, 350 /// which contains more helpful error information. 351 Session* NewSession(const SessionOptions& options); 352 353 } // end namespace tensorflow 354 355 #endif // TENSORFLOW_CORE_PUBLIC_SESSION_H_ 356