1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TEST_UTILS_H_ 16 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TEST_UTILS_H_ 17 18 #include <unordered_map> 19 #include "tensorflow/core/distributed_runtime/worker_cache.h" 20 #include "tensorflow/core/distributed_runtime/worker_interface.h" 21 #include "tensorflow/core/util/device_name_utils.h" 22 23 namespace tensorflow { 24 25 // Some utilities for testing distributed-mode components in a single process 26 // without RPCs. 27 28 // Implements the worker interface with methods that just respond with 29 // "unimplemented" status. Override just the methods needed for 30 // testing. 31 class TestWorkerInterface : public WorkerInterface { 32 public: GetStatusAsync(const GetStatusRequest * request,GetStatusResponse * response,StatusCallback done)33 void GetStatusAsync(const GetStatusRequest* request, 34 GetStatusResponse* response, 35 StatusCallback done) override { 36 done(errors::Unimplemented("GetStatusAsync")); 37 } 38 CreateWorkerSessionAsync(const CreateWorkerSessionRequest * request,CreateWorkerSessionResponse * response,StatusCallback done)39 void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request, 40 CreateWorkerSessionResponse* response, 41 StatusCallback done) override { 42 done(errors::Unimplemented("CreateWorkerSessionAsync")); 43 } 44 DeleteWorkerSessionAsync(CallOptions * opts,const DeleteWorkerSessionRequest * request,DeleteWorkerSessionResponse * response,StatusCallback done)45 void DeleteWorkerSessionAsync(CallOptions* opts, 46 const DeleteWorkerSessionRequest* request, 47 DeleteWorkerSessionResponse* response, 48 StatusCallback done) override { 49 done(errors::Unimplemented("DeleteWorkerSessionAsync")); 50 } 51 RegisterGraphAsync(const RegisterGraphRequest * request,RegisterGraphResponse * response,StatusCallback done)52 void RegisterGraphAsync(const RegisterGraphRequest* request, 53 RegisterGraphResponse* response, 54 StatusCallback done) override { 55 done(errors::Unimplemented("RegisterGraphAsync")); 56 } 57 DeregisterGraphAsync(const DeregisterGraphRequest * request,DeregisterGraphResponse * response,StatusCallback done)58 void DeregisterGraphAsync(const DeregisterGraphRequest* request, 59 DeregisterGraphResponse* response, 60 StatusCallback done) override { 61 done(errors::Unimplemented("DeregisterGraphAsync")); 62 } 63 RunGraphAsync(CallOptions * opts,RunGraphRequestWrapper * request,MutableRunGraphResponseWrapper * repsonse,StatusCallback done)64 void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, 65 MutableRunGraphResponseWrapper* repsonse, 66 StatusCallback done) override { 67 done(errors::Unimplemented("RunGraphAsync")); 68 } 69 CleanupGraphAsync(const CleanupGraphRequest * request,CleanupGraphResponse * response,StatusCallback done)70 void CleanupGraphAsync(const CleanupGraphRequest* request, 71 CleanupGraphResponse* response, 72 StatusCallback done) override { 73 done(errors::Unimplemented("RunGraphAsync")); 74 } 75 CleanupAllAsync(const CleanupAllRequest * request,CleanupAllResponse * response,StatusCallback done)76 void CleanupAllAsync(const CleanupAllRequest* request, 77 CleanupAllResponse* response, 78 StatusCallback done) override { 79 done(errors::Unimplemented("RunGraphAsync")); 80 } 81 RecvTensorAsync(CallOptions * opts,const RecvTensorRequest * request,TensorResponse * response,StatusCallback done)82 void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, 83 TensorResponse* response, StatusCallback done) override { 84 done(errors::Unimplemented("RunGraphAsync")); 85 } 86 LoggingAsync(const LoggingRequest * request,LoggingResponse * response,StatusCallback done)87 void LoggingAsync(const LoggingRequest* request, LoggingResponse* response, 88 StatusCallback done) override { 89 done(errors::Unimplemented("RunGraphAsync")); 90 } 91 TracingAsync(const TracingRequest * request,TracingResponse * response,StatusCallback done)92 void TracingAsync(const TracingRequest* request, TracingResponse* response, 93 StatusCallback done) override { 94 done(errors::Unimplemented("RunGraphAsync")); 95 } 96 RecvBufAsync(CallOptions * opts,const RecvBufRequest * request,RecvBufResponse * response,StatusCallback done)97 void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, 98 RecvBufResponse* response, StatusCallback done) override { 99 done(errors::Unimplemented("RecvBufAsync")); 100 } 101 CompleteGroupAsync(CallOptions * opts,const CompleteGroupRequest * request,CompleteGroupResponse * response,StatusCallback done)102 void CompleteGroupAsync(CallOptions* opts, 103 const CompleteGroupRequest* request, 104 CompleteGroupResponse* response, 105 StatusCallback done) override { 106 done(errors::Unimplemented("RunGraphAsync")); 107 } 108 CompleteInstanceAsync(CallOptions * ops,const CompleteInstanceRequest * request,CompleteInstanceResponse * response,StatusCallback done)109 void CompleteInstanceAsync(CallOptions* ops, 110 const CompleteInstanceRequest* request, 111 CompleteInstanceResponse* response, 112 StatusCallback done) override { 113 done(errors::Unimplemented("RunGraphAsync")); 114 } 115 GetStepSequenceAsync(const GetStepSequenceRequest * request,GetStepSequenceResponse * response,StatusCallback done)116 void GetStepSequenceAsync(const GetStepSequenceRequest* request, 117 GetStepSequenceResponse* response, 118 StatusCallback done) override { 119 done(errors::Unimplemented("RunGraphAsync")); 120 } 121 }; 122 123 class TestWorkerCache : public WorkerCacheInterface { 124 public: ~TestWorkerCache()125 virtual ~TestWorkerCache() {} 126 AddWorker(const string & target,WorkerInterface * wi)127 void AddWorker(const string& target, WorkerInterface* wi) { 128 workers_[target] = wi; 129 } 130 AddDevice(const string & device_name,const DeviceLocality & dev_loc)131 void AddDevice(const string& device_name, const DeviceLocality& dev_loc) { 132 localities_[device_name] = dev_loc; 133 } 134 ListWorkers(std::vector<string> * workers)135 void ListWorkers(std::vector<string>* workers) const override { 136 workers->clear(); 137 for (auto it : workers_) { 138 workers->push_back(it.first); 139 } 140 } 141 ListWorkersInJob(const string & job_name,std::vector<string> * workers)142 void ListWorkersInJob(const string& job_name, 143 std::vector<string>* workers) const override { 144 workers->clear(); 145 for (auto it : workers_) { 146 DeviceNameUtils::ParsedName device_name; 147 CHECK(DeviceNameUtils::ParseFullName(it.first, &device_name)); 148 CHECK(device_name.has_job); 149 if (job_name == device_name.job) { 150 workers->push_back(it.first); 151 } 152 } 153 } 154 CreateWorker(const string & target)155 WorkerInterface* CreateWorker(const string& target) override { 156 auto it = workers_.find(target); 157 if (it != workers_.end()) { 158 return it->second; 159 } 160 return nullptr; 161 } 162 ReleaseWorker(const string & target,WorkerInterface * worker)163 void ReleaseWorker(const string& target, WorkerInterface* worker) override {} 164 GetDeviceLocalityNonBlocking(const string & device,DeviceLocality * locality)165 bool GetDeviceLocalityNonBlocking(const string& device, 166 DeviceLocality* locality) override { 167 auto it = localities_.find(device); 168 if (it != localities_.end()) { 169 *locality = it->second; 170 return true; 171 } 172 return false; 173 } 174 GetDeviceLocalityAsync(const string & device,DeviceLocality * locality,StatusCallback done)175 void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality, 176 StatusCallback done) override { 177 auto it = localities_.find(device); 178 if (it != localities_.end()) { 179 *locality = it->second; 180 done(Status::OK()); 181 return; 182 } 183 done(errors::Internal("Device not found: ", device)); 184 } 185 186 protected: 187 std::unordered_map<string, WorkerInterface*> workers_; 188 std::unordered_map<string, DeviceLocality> localities_; 189 }; 190 191 } // namespace tensorflow 192 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TEST_UTILS_H_ 193