1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 23 #include "tensorflow/core/distributed_runtime/call_options.h" 24 #include "tensorflow/core/distributed_runtime/message_wrappers.h" 25 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" 26 #include "tensorflow/core/framework/graph.pb.h" 27 #include "tensorflow/core/framework/tensor.h" 28 #include "tensorflow/core/lib/core/errors.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/mutex.h" 33 #include "tensorflow/core/platform/thread_annotations.h" 34 #include "tensorflow/core/protobuf/config.pb.h" 35 #include "tensorflow/core/protobuf/master.pb.h" 36 #include "tensorflow/core/public/session.h" 37 #include "tensorflow/core/public/session_options.h" 38 39 namespace tensorflow { 40 41 class MasterInterface; 42 43 // A Session instance lets the caller drive a TensorFlow graph 44 // computation on potentially remote sets of devices. This is a thin 45 // wrapper around tensorflow::grpc::MasterService. 46 // 47 // Multiple threads must synchronize their accesses to a single 48 // session. 49 class GrpcSession : public Session { 50 protected: 51 explicit GrpcSession(const SessionOptions& options); 52 53 public: 54 static Status Create(const SessionOptions& options, 55 std::unique_ptr<GrpcSession>* out_session); 56 // Resets the resource containers. 57 static Status Reset(const SessionOptions& options, 58 const std::vector<string>& containers); 59 60 ~GrpcSession() override; 61 62 // Creates a session with the "target". The session carries out 63 // the graph computation defined by "graph", and will have version 64 // number "initial_version". 65 Status Create(const GraphDef& graph) override; 66 Status Create(const RunOptions& run_options, const GraphDef& graph) override; 67 Status Create(GraphDef&& graph) override; 68 Status Create(const RunOptions& run_options, GraphDef&& graph) override; 69 70 // Runs with and without RunOptions. 71 Status Run(const std::vector<std::pair<string, Tensor> >& inputs, 72 const std::vector<string>& output_tensor_names, 73 const std::vector<string>& target_node_names, 74 std::vector<Tensor>* outputs) override; 75 Status Run(const RunOptions& run_options, 76 const std::vector<std::pair<string, Tensor> >& inputs, 77 const std::vector<string>& output_tensor_names, 78 const std::vector<string>& target_node_names, 79 std::vector<Tensor>* outputs, RunMetadata* run_metadata) override; 80 81 Status Extend(const GraphDef& graph) override; 82 Status Extend(const RunOptions& run_options, const GraphDef& graph) override; 83 Status Extend(GraphDef&& graph) override; 84 Status Extend(const RunOptions& run_options, GraphDef&& graph) override; 85 86 Status Close() override; 87 88 // NOTE: This API is still experimental and may change. 89 Status PRunSetup(const std::vector<string>& input_names, 90 const std::vector<string>& output_names, 91 const std::vector<string>& target_nodes, 92 string* handle) override; 93 94 // NOTE: This API is still experimental and may change. 95 Status PRun(const string& handle, 96 const std::vector<std::pair<string, Tensor> >& inputs, 97 const std::vector<string>& output_names, 98 std::vector<Tensor>* outputs) override; 99 100 Status ListDevices(std::vector<DeviceAttributes>* response) override; 101 102 Status MakeCallable(const CallableOptions& callable_options, 103 CallableHandle* out_handle) override; 104 Status RunCallable(CallableHandle handle, 105 const std::vector<Tensor>& feed_tensors, 106 std::vector<Tensor>* fetch_tensors, 107 RunMetadata* run_metadata) override; 108 Status ReleaseCallable(CallableHandle handle) override; 109 110 protected: 111 // Takes ownership of `*master`. 112 void SetRemoteMaster(std::unique_ptr<MasterInterface> master); 113 // Allows subclasses to customize Session creation. 114 void SetHandleAndGraphVersion(string handle, int64_t graph_version) 115 TF_LOCKS_EXCLUDED(mu_); 116 117 private: 118 const SessionOptions options_; 119 std::unique_ptr<MasterInterface> master_; 120 mutex mu_; 121 122 // handle_ returned by the master to identify this session. 123 string handle_ TF_GUARDED_BY(mu_); 124 125 // The current version of the graph. 126 int64_t current_graph_version_ TF_GUARDED_BY(mu_); 127 128 bool is_local_ = false; 129 130 Status Handle(string* out_handle) TF_LOCKS_EXCLUDED(mu_); 131 132 Status RunHelper(const RunOptions& run_options, 133 const std::vector<std::pair<string, Tensor> >& inputs, 134 const std::vector<string>& output_tensor_names, 135 const std::vector<string>& target_node_names, 136 std::vector<Tensor>* outputs, RunMetadata* run_metadata, 137 const string& prun_handle); 138 139 Status RunProto(CallOptions* call_options, MutableRunStepRequestWrapper* req, 140 MutableRunStepResponseWrapper* resp); 141 142 // Implementations for all the public interfaces. 143 Status CreateImpl(CallOptions* call_options, GraphDef graph); 144 Status ExtendImpl(CallOptions* call_options, GraphDef graph); 145 146 TF_DISALLOW_COPY_AND_ASSIGN(GrpcSession); 147 }; 148 149 } // namespace tensorflow 150 151 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ 152