• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_
16 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_
17 
18 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
19 #include "tensorflow/core/framework/collective.h"
20 
21 namespace tensorflow {
22 class CollectiveParamResolverDistributed;
23 class ConfigProto;
24 class DeviceMgr;
25 class DeviceResolverDistributed;
26 class WorkerCacheInterface;
27 class StepSequenceRequest;
28 class StepSequenceResponse;
29 
30 // An implementation of CollectiveExecutorMgr for a distributed environment
31 // that uses WorkerInterface::RecvBufAsync to route data transfers over RPCs.
32 //
33 // In some execution environments it may be possible to implement a
34 // higher-performance solution and use it in place of this class.
35 class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr {
36  public:
37   RpcCollectiveExecutorMgr(
38       const ConfigProto& config, const DeviceMgr* dev_mgr,
39       std::unique_ptr<DeviceResolverDistributed> dev_resolver,
40       std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
41       WorkerCacheInterface* worker_cache, const string& task_name);
42 
43   virtual ~RpcCollectiveExecutorMgr();
44 
45   // This function should only be called at the group_leader, by an RPC.
46   // Other needs for StepIds should be satisfied by NextStepId.
47   void GetStepSequenceAsync(const GetStepSequenceRequest* request,
48                             GetStepSequenceResponse* response,
49                             const StatusCallback& done) override;
50 
51   void RefreshStepIdSequenceAsync(int64 graph_key,
52                                   const StatusCallback& done) override;
53 
54   int64 NextStepId(int64 graph_key) override;
55 
56   void RetireStepId(int64 graph_key, int64 step_id) override;
57 
58  protected:
59   virtual CollectiveExecutor* Create(int64 step_id) override;
60 
61   WorkerCacheInterface* const worker_cache_;  // Not owned.
62   const string task_name_;
63   string group_leader_;
64   friend class RpcCollectiveExecutorMgrTest;
65 
66  private:
67   Status UpdateStepSequences(const GetStepSequenceResponse& resp);
68 
69   // This class maintains the step_id sequencing for a single
70   // collective_graph_key.
71   struct GraphKeySequence {
GraphKeySequenceGraphKeySequence72     explicit GraphKeySequence(int64 k)
73         : graph_key_(k), next_step_id_(CollectiveExecutor::kInvalidId) {}
74 
75     const int64 graph_key_;
76     int64 next_step_id_;
77   };
78 
79   mutex sequence_mu_;
80   gtl::FlatMap<int64, GraphKeySequence*> sequence_table_
81       GUARDED_BY(sequence_mu_);
82 };
83 
84 }  // namespace tensorflow
85 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_
86