/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_

#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/gtl/flatmap.h"

namespace tensorflow {

// Mock objects that can't actually execute a Collective, but satisfy
// general infrastructure expectations within tests that don't require
// full functionality.

27 28 class TestCollectiveExecutor : public CollectiveExecutor { 29 public: TestCollectiveExecutor(CollectiveExecutorMgrInterface * cem)30 explicit TestCollectiveExecutor(CollectiveExecutorMgrInterface* cem) 31 : CollectiveExecutor(cem) {} 32 RunClosure(std::function<void ()>)33 void RunClosure(std::function<void()>) override { 34 LOG(FATAL) << "Unimplemented"; 35 } 36 }; 37 38 class TestParamResolver : public ParamResolverInterface { CompleteParamsAsync(const DeviceAttributes & device,CollectiveParams * cp,CancellationManager * cancel_mgr,const StatusCallback & done)39 void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp, 40 CancellationManager* cancel_mgr, 41 const StatusCallback& done) override { 42 done(errors::Internal("Unimplemented")); 43 } 44 CompleteGroupAsync(const CompleteGroupRequest * request,CompleteGroupResponse * response,CancellationManager * cancel_mgr,const StatusCallback & done)45 void CompleteGroupAsync(const CompleteGroupRequest* request, 46 CompleteGroupResponse* response, 47 CancellationManager* cancel_mgr, 48 const StatusCallback& done) override { 49 done(errors::Internal("Unimplemented")); 50 } 51 CompleteInstanceAsync(const CompleteInstanceRequest * request,CompleteInstanceResponse * response,CancellationManager * cancel_mgr,const StatusCallback & done)52 void CompleteInstanceAsync(const CompleteInstanceRequest* request, 53 CompleteInstanceResponse* response, 54 CancellationManager* cancel_mgr, 55 const StatusCallback& done) override { 56 done(errors::Internal("Unimplemented")); 57 } 58 StartAbort(const Status & s)59 void StartAbort(const Status& s) override { return; } 60 }; 61 62 class TestCollectiveExecutorMgr : public CollectiveExecutorMgrInterface { 63 public: TestCollectiveExecutorMgr()64 TestCollectiveExecutorMgr() {} 65 ~TestCollectiveExecutorMgr()66 ~TestCollectiveExecutorMgr() override { 67 for (auto& iter : table_) { 68 iter.second->Unref(); 69 } 70 } 71 FindOrCreate(int64 step_id)72 CollectiveExecutor* 
FindOrCreate(int64 step_id) override { 73 mutex_lock l(mu_); 74 CollectiveExecutor* ce = nullptr; 75 auto iter = table_.find(step_id); 76 if (iter != table_.end()) { 77 ce = iter->second; 78 } else { 79 ce = new TestCollectiveExecutor(this); 80 table_[step_id] = ce; 81 } 82 ce->Ref(); 83 return ce; 84 } 85 Cleanup(int64 step_id)86 void Cleanup(int64 step_id) override { 87 mutex_lock l(mu_); 88 auto iter = table_.find(step_id); 89 if (iter != table_.end()) { 90 iter->second->Unref(); 91 table_.erase(iter); 92 } 93 } 94 GetParamResolver()95 ParamResolverInterface* GetParamResolver() const override { 96 return ¶m_resolver_; 97 } 98 GetDeviceResolver()99 DeviceResolverInterface* GetDeviceResolver() const override { 100 LOG(FATAL); 101 return nullptr; 102 } 103 GetNcclCommunicator()104 NcclCommunicatorInterface* GetNcclCommunicator() const override { 105 return nullptr; 106 } 107 GetStepSequenceAsync(const GetStepSequenceRequest * request,GetStepSequenceResponse * response,const StatusCallback & done)108 void GetStepSequenceAsync(const GetStepSequenceRequest* request, 109 GetStepSequenceResponse* response, 110 const StatusCallback& done) override { 111 done(errors::Internal("unimplemented")); 112 } 113 RefreshStepIdSequenceAsync(int64 graph_key,const StatusCallback & done)114 void RefreshStepIdSequenceAsync(int64 graph_key, 115 const StatusCallback& done) override { 116 done(errors::Internal("unimplemented")); 117 } 118 NextStepId(int64 graph_key)119 int64 NextStepId(int64 graph_key) override { 120 return CollectiveExecutor::kInvalidId; 121 } 122 RetireStepId(int64 graph_key,int64 step_id)123 void RetireStepId(int64 graph_key, int64 step_id) override {} 124 125 mutex mu_; 126 gtl::FlatMap<int64, CollectiveExecutor*> table_ TF_GUARDED_BY(mu_); 127 mutable TestParamResolver param_resolver_; 128 }; 129 130 } // namespace tensorflow 131 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_ 132