• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TEST_UTILS_H_
16 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TEST_UTILS_H_
17 
18 #include <unordered_map>
19 #include "tensorflow/core/distributed_runtime/worker_cache.h"
20 #include "tensorflow/core/distributed_runtime/worker_interface.h"
21 #include "tensorflow/core/util/device_name_utils.h"
22 
23 namespace tensorflow {
24 
25 // Some utilities for testing distributed-mode components in a single process
26 // without RPCs.
27 
28 // Implements the worker interface with methods that just respond with
29 // "unimplemented" status.  Override just the methods needed for
30 // testing.
31 class TestWorkerInterface : public WorkerInterface {
32  public:
GetStatusAsync(const GetStatusRequest * request,GetStatusResponse * response,StatusCallback done)33   void GetStatusAsync(const GetStatusRequest* request,
34                       GetStatusResponse* response,
35                       StatusCallback done) override {
36     done(errors::Unimplemented("GetStatusAsync"));
37   }
38 
CreateWorkerSessionAsync(const CreateWorkerSessionRequest * request,CreateWorkerSessionResponse * response,StatusCallback done)39   void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
40                                 CreateWorkerSessionResponse* response,
41                                 StatusCallback done) override {
42     done(errors::Unimplemented("CreateWorkerSessionAsync"));
43   }
44 
DeleteWorkerSessionAsync(CallOptions * opts,const DeleteWorkerSessionRequest * request,DeleteWorkerSessionResponse * response,StatusCallback done)45   void DeleteWorkerSessionAsync(CallOptions* opts,
46                                 const DeleteWorkerSessionRequest* request,
47                                 DeleteWorkerSessionResponse* response,
48                                 StatusCallback done) override {
49     done(errors::Unimplemented("DeleteWorkerSessionAsync"));
50   }
51 
RegisterGraphAsync(const RegisterGraphRequest * request,RegisterGraphResponse * response,StatusCallback done)52   void RegisterGraphAsync(const RegisterGraphRequest* request,
53                           RegisterGraphResponse* response,
54                           StatusCallback done) override {
55     done(errors::Unimplemented("RegisterGraphAsync"));
56   }
57 
DeregisterGraphAsync(const DeregisterGraphRequest * request,DeregisterGraphResponse * response,StatusCallback done)58   void DeregisterGraphAsync(const DeregisterGraphRequest* request,
59                             DeregisterGraphResponse* response,
60                             StatusCallback done) override {
61     done(errors::Unimplemented("DeregisterGraphAsync"));
62   }
63 
RunGraphAsync(CallOptions * opts,RunGraphRequestWrapper * request,MutableRunGraphResponseWrapper * repsonse,StatusCallback done)64   void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
65                      MutableRunGraphResponseWrapper* repsonse,
66                      StatusCallback done) override {
67     done(errors::Unimplemented("RunGraphAsync"));
68   }
69 
CleanupGraphAsync(const CleanupGraphRequest * request,CleanupGraphResponse * response,StatusCallback done)70   void CleanupGraphAsync(const CleanupGraphRequest* request,
71                          CleanupGraphResponse* response,
72                          StatusCallback done) override {
73     done(errors::Unimplemented("RunGraphAsync"));
74   }
75 
CleanupAllAsync(const CleanupAllRequest * request,CleanupAllResponse * response,StatusCallback done)76   void CleanupAllAsync(const CleanupAllRequest* request,
77                        CleanupAllResponse* response,
78                        StatusCallback done) override {
79     done(errors::Unimplemented("RunGraphAsync"));
80   }
81 
RecvTensorAsync(CallOptions * opts,const RecvTensorRequest * request,TensorResponse * response,StatusCallback done)82   void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request,
83                        TensorResponse* response, StatusCallback done) override {
84     done(errors::Unimplemented("RunGraphAsync"));
85   }
86 
LoggingAsync(const LoggingRequest * request,LoggingResponse * response,StatusCallback done)87   void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
88                     StatusCallback done) override {
89     done(errors::Unimplemented("RunGraphAsync"));
90   }
91 
TracingAsync(const TracingRequest * request,TracingResponse * response,StatusCallback done)92   void TracingAsync(const TracingRequest* request, TracingResponse* response,
93                     StatusCallback done) override {
94     done(errors::Unimplemented("RunGraphAsync"));
95   }
96 
RecvBufAsync(CallOptions * opts,const RecvBufRequest * request,RecvBufResponse * response,StatusCallback done)97   void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
98                     RecvBufResponse* response, StatusCallback done) override {
99     done(errors::Unimplemented("RecvBufAsync"));
100   }
101 
CompleteGroupAsync(CallOptions * opts,const CompleteGroupRequest * request,CompleteGroupResponse * response,StatusCallback done)102   void CompleteGroupAsync(CallOptions* opts,
103                           const CompleteGroupRequest* request,
104                           CompleteGroupResponse* response,
105                           StatusCallback done) override {
106     done(errors::Unimplemented("RunGraphAsync"));
107   }
108 
CompleteInstanceAsync(CallOptions * ops,const CompleteInstanceRequest * request,CompleteInstanceResponse * response,StatusCallback done)109   void CompleteInstanceAsync(CallOptions* ops,
110                              const CompleteInstanceRequest* request,
111                              CompleteInstanceResponse* response,
112                              StatusCallback done) override {
113     done(errors::Unimplemented("RunGraphAsync"));
114   }
115 
GetStepSequenceAsync(const GetStepSequenceRequest * request,GetStepSequenceResponse * response,StatusCallback done)116   void GetStepSequenceAsync(const GetStepSequenceRequest* request,
117                             GetStepSequenceResponse* response,
118                             StatusCallback done) override {
119     done(errors::Unimplemented("RunGraphAsync"));
120   }
121 };
122 
123 class TestWorkerCache : public WorkerCacheInterface {
124  public:
~TestWorkerCache()125   virtual ~TestWorkerCache() {}
126 
AddWorker(const string & target,WorkerInterface * wi)127   void AddWorker(const string& target, WorkerInterface* wi) {
128     workers_[target] = wi;
129   }
130 
AddDevice(const string & device_name,const DeviceLocality & dev_loc)131   void AddDevice(const string& device_name, const DeviceLocality& dev_loc) {
132     localities_[device_name] = dev_loc;
133   }
134 
ListWorkers(std::vector<string> * workers)135   void ListWorkers(std::vector<string>* workers) const override {
136     workers->clear();
137     for (auto it : workers_) {
138       workers->push_back(it.first);
139     }
140   }
141 
ListWorkersInJob(const string & job_name,std::vector<string> * workers)142   void ListWorkersInJob(const string& job_name,
143                         std::vector<string>* workers) const override {
144     workers->clear();
145     for (auto it : workers_) {
146       DeviceNameUtils::ParsedName device_name;
147       CHECK(DeviceNameUtils::ParseFullName(it.first, &device_name));
148       CHECK(device_name.has_job);
149       if (job_name == device_name.job) {
150         workers->push_back(it.first);
151       }
152     }
153   }
154 
CreateWorker(const string & target)155   WorkerInterface* CreateWorker(const string& target) override {
156     auto it = workers_.find(target);
157     if (it != workers_.end()) {
158       return it->second;
159     }
160     return nullptr;
161   }
162 
ReleaseWorker(const string & target,WorkerInterface * worker)163   void ReleaseWorker(const string& target, WorkerInterface* worker) override {}
164 
GetDeviceLocalityNonBlocking(const string & device,DeviceLocality * locality)165   bool GetDeviceLocalityNonBlocking(const string& device,
166                                     DeviceLocality* locality) override {
167     auto it = localities_.find(device);
168     if (it != localities_.end()) {
169       *locality = it->second;
170       return true;
171     }
172     return false;
173   }
174 
GetDeviceLocalityAsync(const string & device,DeviceLocality * locality,StatusCallback done)175   void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
176                               StatusCallback done) override {
177     auto it = localities_.find(device);
178     if (it != localities_.end()) {
179       *locality = it->second;
180       done(Status::OK());
181       return;
182     }
183     done(errors::Internal("Device not found: ", device));
184   }
185 
186  protected:
187   std::unordered_map<string, WorkerInterface*> workers_;
188   std::unordered_map<string, DeviceLocality> localities_;
189 };
190 
191 }  // namespace tensorflow
192 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TEST_UTILS_H_
193