1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ 18 19 #include <memory> 20 21 #include "tensorflow/core/lib/core/status.h" 22 #include "tensorflow/core/platform/macros.h" 23 #include "tensorflow/core/protobuf/tensorflow_server.pb.h" 24 25 namespace tensorflow { 26 27 class CoordinationServiceAgent; 28 class DeviceMgr; 29 class EagerContext; 30 class WorkerEnv; 31 class MasterEnv; 32 33 // This library supports a registration/factory-based mechanism for 34 // creating TensorFlow server objects. Each server implementation must 35 // have an accompanying implementation of ServerFactory, and create a 36 // static "registrar" object that calls `ServerFactory::Register()` 37 // with an instance of the factory class. See "rpc/grpc_server_lib.cc" 38 // for an example. 39 40 // Represents a single TensorFlow server that exports Master and Worker 41 // services. 42 class ServerInterface { 43 public: ServerInterface()44 ServerInterface() {} ~ServerInterface()45 virtual ~ServerInterface() {} 46 47 // Starts the server running asynchronously. Returns OK on success, otherwise 48 // returns an error. 49 virtual Status Start() = 0; 50 51 // Stops the server asynchronously. Returns OK on success, otherwise returns 52 // an error. 53 // 54 // After calling `Stop()`, the caller may call `Join()` to block until the 55 // server has stopped. 56 virtual Status Stop() = 0; 57 58 // Blocks until the server has stopped. Returns OK on success, otherwise 59 // returns an error. 60 virtual Status Join() = 0; 61 62 // Returns a target string that can be used to connect to this server using 63 // `tensorflow::NewSession()`. 64 virtual const string target() const = 0; 65 66 virtual WorkerEnv* worker_env() = 0; 67 virtual MasterEnv* master_env() = 0; 68 69 // Update the set of workers that can be reached by the server 70 virtual Status UpdateServerDef(const ServerDef& server_def) = 0; 71 72 // Functions to operate on service-specific properties. 73 // 74 // Add master eager context to local eager service in order to handle enqueue 75 // requests from remote workers. 76 virtual Status AddMasterEagerContextToEagerService( 77 const tensorflow::uint64 context_id, EagerContext* context) = 0; 78 // Set coordination service agent instance to coordination service RPC handler 79 virtual Status SetCoordinationServiceAgentInstance( 80 CoordinationServiceAgent* agent) = 0; 81 82 private: 83 TF_DISALLOW_COPY_AND_ASSIGN(ServerInterface); 84 }; 85 86 class ServerFactory { 87 public: 88 struct Options { 89 // Local DeviceMgr to use. 90 tensorflow::DeviceMgr* local_device_mgr; 91 }; 92 // Creates a new server based on the given `server_def`, and stores 93 // it in `*out_server`. Returns OK on success, otherwise returns an 94 // error. 95 virtual Status NewServer(const ServerDef& server_def, const Options& options, 96 std::unique_ptr<ServerInterface>* out_server) = 0; 97 98 // Returns true if and only if this factory can create a server 99 // based on the given `server_def`. 100 virtual bool AcceptsOptions(const ServerDef& server_def) = 0; 101 ~ServerFactory()102 virtual ~ServerFactory() {} 103 104 // For each `ServerFactory` subclass, an instance of that class must 105 // be registered by calling this method. 106 // 107 // The `server_type` must be unique to the server factory. 108 static void Register(const string& server_type, ServerFactory* factory); 109 110 // Looks up a factory that can create a server based on the given 111 // `server_def`, and stores it in `*out_factory`. Returns OK on 112 // success, otherwise returns an error. 113 static Status GetFactory(const ServerDef& server_def, 114 ServerFactory** out_factory); 115 }; 116 117 // Creates a server based on the given `server_def`, and stores it in 118 // `*out_server`. Returns OK on success, otherwise returns an error. 119 Status NewServer(const ServerDef& server_def, 120 std::unique_ptr<ServerInterface>* out_server); 121 Status NewServerWithOptions(const ServerDef& server_def, 122 const ServerFactory::Options& options, 123 std::unique_ptr<ServerInterface>* out_server); 124 125 } // namespace tensorflow 126 127 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ 128