1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_ 16 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_ 17 18 #include "tensorflow/core/common_runtime/collective_executor_mgr.h" 19 #include "tensorflow/core/framework/collective.h" 20 21 namespace tensorflow { 22 class CollectiveParamResolverDistributed; 23 class ConfigProto; 24 class DeviceMgr; 25 class DeviceResolverDistributed; 26 class WorkerCacheInterface; 27 class StepSequenceRequest; 28 class StepSequenceResponse; 29 30 // An implementation of CollectiveExecutorMgr for a distributed environment 31 // that uses WorkerInterface::RecvBufAsync to route data transfers over RPCs. 32 // 33 // In some execution environments it may be possible to implement a 34 // higher-performance solution and use it in place of this class. 35 class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr { 36 public: 37 RpcCollectiveExecutorMgr( 38 const ConfigProto& config, const DeviceMgr* dev_mgr, 39 std::unique_ptr<DeviceResolverDistributed> dev_resolver, 40 std::unique_ptr<CollectiveParamResolverDistributed> param_resolver, 41 WorkerCacheInterface* worker_cache, const string& task_name); 42 43 virtual ~RpcCollectiveExecutorMgr(); 44 45 // This function should only be called at the group_leader, by an RPC. 46 // Other needs for StepIds should be satisfied by NextStepId. 47 void GetStepSequenceAsync(const GetStepSequenceRequest* request, 48 GetStepSequenceResponse* response, 49 const StatusCallback& done) override; 50 51 void RefreshStepIdSequenceAsync(int64 graph_key, 52 const StatusCallback& done) override; 53 54 int64 NextStepId(int64 graph_key) override; 55 56 void RetireStepId(int64 graph_key, int64 step_id) override; 57 58 protected: 59 virtual CollectiveExecutor* Create(int64 step_id) override; 60 61 WorkerCacheInterface* const worker_cache_; // Not owned. 62 const string task_name_; 63 string group_leader_; 64 friend class RpcCollectiveExecutorMgrTest; 65 66 private: 67 Status UpdateStepSequences(const GetStepSequenceResponse& resp); 68 69 // This class maintains the step_id sequencing for a single 70 // collective_graph_key. 71 struct GraphKeySequence { GraphKeySequenceGraphKeySequence72 explicit GraphKeySequence(int64 k) 73 : graph_key_(k), next_step_id_(CollectiveExecutor::kInvalidId) {} 74 75 const int64 graph_key_; 76 int64 next_step_id_; 77 }; 78 79 mutex sequence_mu_; 80 gtl::FlatMap<int64, GraphKeySequence*> sequence_table_ 81 GUARDED_BY(sequence_mu_); 82 }; 83 84 } // namespace tensorflow 85 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_ 86