• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_
17 
18 #include "tensorflow/core/framework/collective.h"
19 #include "tensorflow/core/framework/device_attributes.pb.h"
20 #include "tensorflow/core/lib/gtl/flatmap.h"
21 
22 namespace tensorflow {
23 
24 // Mock objects that can't actually execute a Collective, but satisfy
25 // general infrastructure expectations within tests that don't require
26 // full functionality.
27 
28 class TestCollectiveExecutor : public CollectiveExecutor {
29  public:
TestCollectiveExecutor(CollectiveExecutorMgrInterface * cem)30   explicit TestCollectiveExecutor(CollectiveExecutorMgrInterface* cem)
31       : CollectiveExecutor(cem) {}
32 
RunClosure(std::function<void ()>)33   void RunClosure(std::function<void()>) override {
34     LOG(FATAL) << "Unimplemented";
35   }
36 };
37 
38 class TestParamResolver : public ParamResolverInterface {
CompleteParamsAsync(const DeviceAttributes & device,CollectiveParams * cp,CancellationManager * cancel_mgr,const StatusCallback & done)39   void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
40                            CancellationManager* cancel_mgr,
41                            const StatusCallback& done) override {
42     done(errors::Internal("Unimplemented"));
43   }
44 
CompleteGroupAsync(const CompleteGroupRequest * request,CompleteGroupResponse * response,CancellationManager * cancel_mgr,const StatusCallback & done)45   void CompleteGroupAsync(const CompleteGroupRequest* request,
46                           CompleteGroupResponse* response,
47                           CancellationManager* cancel_mgr,
48                           const StatusCallback& done) override {
49     done(errors::Internal("Unimplemented"));
50   }
51 
CompleteInstanceAsync(const CompleteInstanceRequest * request,CompleteInstanceResponse * response,CancellationManager * cancel_mgr,const StatusCallback & done)52   void CompleteInstanceAsync(const CompleteInstanceRequest* request,
53                              CompleteInstanceResponse* response,
54                              CancellationManager* cancel_mgr,
55                              const StatusCallback& done) override {
56     done(errors::Internal("Unimplemented"));
57   }
58 
StartAbort(const Status & s)59   void StartAbort(const Status& s) override { return; }
60 };
61 
62 class TestCollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
63  public:
TestCollectiveExecutorMgr()64   TestCollectiveExecutorMgr() {}
65 
~TestCollectiveExecutorMgr()66   ~TestCollectiveExecutorMgr() override {
67     for (auto& iter : table_) {
68       iter.second->Unref();
69     }
70   }
71 
FindOrCreate(int64 step_id)72   CollectiveExecutor* FindOrCreate(int64 step_id) override {
73     mutex_lock l(mu_);
74     CollectiveExecutor* ce = nullptr;
75     auto iter = table_.find(step_id);
76     if (iter != table_.end()) {
77       ce = iter->second;
78     } else {
79       ce = new TestCollectiveExecutor(this);
80       table_[step_id] = ce;
81     }
82     ce->Ref();
83     return ce;
84   }
85 
Cleanup(int64 step_id)86   void Cleanup(int64 step_id) override {
87     mutex_lock l(mu_);
88     auto iter = table_.find(step_id);
89     if (iter != table_.end()) {
90       iter->second->Unref();
91       table_.erase(iter);
92     }
93   }
94 
GetParamResolver()95   ParamResolverInterface* GetParamResolver() const override {
96     return &param_resolver_;
97   }
98 
GetDeviceResolver()99   DeviceResolverInterface* GetDeviceResolver() const override {
100     LOG(FATAL);
101     return nullptr;
102   }
103 
GetNcclCommunicator()104   NcclCommunicatorInterface* GetNcclCommunicator() const override {
105     return nullptr;
106   }
107 
GetStepSequenceAsync(const GetStepSequenceRequest * request,GetStepSequenceResponse * response,const StatusCallback & done)108   void GetStepSequenceAsync(const GetStepSequenceRequest* request,
109                             GetStepSequenceResponse* response,
110                             const StatusCallback& done) override {
111     done(errors::Internal("unimplemented"));
112   }
113 
RefreshStepIdSequenceAsync(int64 graph_key,const StatusCallback & done)114   void RefreshStepIdSequenceAsync(int64 graph_key,
115                                   const StatusCallback& done) override {
116     done(errors::Internal("unimplemented"));
117   }
118 
NextStepId(int64 graph_key)119   int64 NextStepId(int64 graph_key) override {
120     return CollectiveExecutor::kInvalidId;
121   }
122 
RetireStepId(int64 graph_key,int64 step_id)123   void RetireStepId(int64 graph_key, int64 step_id) override {}
124 
125   mutex mu_;
126   gtl::FlatMap<int64, CollectiveExecutor*> table_ TF_GUARDED_BY(mu_);
127   mutable TestParamResolver param_resolver_;
128 };
129 
130 }  // namespace tensorflow
131 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_
132