• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
18 
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 #include "tensorflow/core/distributed_runtime/call_options.h"
24 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
26 #include "tensorflow/core/framework/graph.pb.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/thread_annotations.h"
34 #include "tensorflow/core/protobuf/config.pb.h"
35 #include "tensorflow/core/protobuf/master.pb.h"
36 #include "tensorflow/core/public/session.h"
37 #include "tensorflow/core/public/session_options.h"
38 
39 namespace tensorflow {
40 
41 class MasterInterface;
42 
43 // A Session instance lets the caller drive a TensorFlow graph
44 // computation on potentially remote sets of devices. This is a thin
45 // wrapper around tensorflow::grpc::MasterService.
46 //
47 // Multiple threads must synchronize their accesses to a single
48 // session.
49 class GrpcSession : public Session {
50  protected:
51   explicit GrpcSession(const SessionOptions& options);
52 
53  public:
54   static Status Create(const SessionOptions& options,
55                        std::unique_ptr<GrpcSession>* out_session);
56   // Resets the resource containers.
57   static Status Reset(const SessionOptions& options,
58                       const std::vector<string>& containers);
59 
60   ~GrpcSession() override;
61 
62   // Creates a session with the "target". The session carries out
63   // the graph computation defined by "graph", and will have version
64   // number "initial_version".
65   Status Create(const GraphDef& graph) override;
66   Status Create(const RunOptions& run_options, const GraphDef& graph) override;
67   Status Create(GraphDef&& graph) override;
68   Status Create(const RunOptions& run_options, GraphDef&& graph) override;
69 
70   // Runs with and without RunOptions.
71   Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
72              const std::vector<string>& output_tensor_names,
73              const std::vector<string>& target_node_names,
74              std::vector<Tensor>* outputs) override;
75   Status Run(const RunOptions& run_options,
76              const std::vector<std::pair<string, Tensor> >& inputs,
77              const std::vector<string>& output_tensor_names,
78              const std::vector<string>& target_node_names,
79              std::vector<Tensor>* outputs, RunMetadata* run_metadata) override;
80 
81   Status Extend(const GraphDef& graph) override;
82   Status Extend(const RunOptions& run_options, const GraphDef& graph) override;
83   Status Extend(GraphDef&& graph) override;
84   Status Extend(const RunOptions& run_options, GraphDef&& graph) override;
85 
86   Status Close() override;
87 
88   // NOTE: This API is still experimental and may change.
89   Status PRunSetup(const std::vector<string>& input_names,
90                    const std::vector<string>& output_names,
91                    const std::vector<string>& target_nodes,
92                    string* handle) override;
93 
94   // NOTE: This API is still experimental and may change.
95   Status PRun(const string& handle,
96               const std::vector<std::pair<string, Tensor> >& inputs,
97               const std::vector<string>& output_names,
98               std::vector<Tensor>* outputs) override;
99 
100   Status ListDevices(std::vector<DeviceAttributes>* response) override;
101 
102   Status MakeCallable(const CallableOptions& callable_options,
103                       CallableHandle* out_handle) override;
104   Status RunCallable(CallableHandle handle,
105                      const std::vector<Tensor>& feed_tensors,
106                      std::vector<Tensor>* fetch_tensors,
107                      RunMetadata* run_metadata) override;
108   Status ReleaseCallable(CallableHandle handle) override;
109 
110  protected:
111   // Takes ownership of `*master`.
112   void SetRemoteMaster(std::unique_ptr<MasterInterface> master);
113   // Allows subclasses to customize Session creation.
114   void SetHandleAndGraphVersion(string handle, int64_t graph_version)
115       TF_LOCKS_EXCLUDED(mu_);
116 
117  private:
118   const SessionOptions options_;
119   std::unique_ptr<MasterInterface> master_;
120   mutex mu_;
121 
122   // handle_ returned by the master to identify this session.
123   string handle_ TF_GUARDED_BY(mu_);
124 
125   // The current version of the graph.
126   int64_t current_graph_version_ TF_GUARDED_BY(mu_);
127 
128   bool is_local_ = false;
129 
130   Status Handle(string* out_handle) TF_LOCKS_EXCLUDED(mu_);
131 
132   Status RunHelper(const RunOptions& run_options,
133                    const std::vector<std::pair<string, Tensor> >& inputs,
134                    const std::vector<string>& output_tensor_names,
135                    const std::vector<string>& target_node_names,
136                    std::vector<Tensor>* outputs, RunMetadata* run_metadata,
137                    const string& prun_handle);
138 
139   Status RunProto(CallOptions* call_options, MutableRunStepRequestWrapper* req,
140                   MutableRunStepResponseWrapper* resp);
141 
142   // Implementations for all the public interfaces.
143   Status CreateImpl(CallOptions* call_options, GraphDef graph);
144   Status ExtendImpl(CallOptions* call_options, GraphDef graph);
145 
146   TF_DISALLOW_COPY_AND_ASSIGN(GrpcSession);
147 };
148 
149 }  // namespace tensorflow
150 
151 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
152