1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_ 17 #define TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_ 18 19 #include <functional> 20 21 #include "absl/strings/string_view.h" 22 #include "absl/types/optional.h" 23 #include "tensorflow/core/data/dataset.pb.h" 24 #include "tensorflow/core/data/service/worker.pb.h" 25 #include "tensorflow/core/framework/dataset.h" 26 #include "tensorflow/core/platform/macros.h" 27 #include "tensorflow/core/platform/status.h" 28 29 namespace tensorflow { 30 namespace data { 31 32 // The result of a GetElement request. Exactly one of the following will be 33 // true: (1) `components` is nonempty (2) `end_of_sequence` is true (3) `skip` 34 // is true. 35 struct GetElementResult { 36 GetElementResult() = default; 37 GetElementResult(GetElementResult&&) = default; 38 GetElementResult& operator=(GetElementResult&&) = default; 39 40 // A dataset element produced by a GetElement request. 41 std::vector<Tensor> components; 42 // The element's index within the task it came from. 43 int64 element_index; 44 // If true, indicates that there is no more data to read. 45 bool end_of_sequence; 46 // If true, indicates that there is still data, but the caller should skip 47 // reading from the worker. This is used for load balancing when doing round 48 // robin reads. 49 bool skip; 50 51 TF_DISALLOW_COPY_AND_ASSIGN(GetElementResult); 52 }; 53 54 // Client for communicating with the tf.data service transfer server. 55 class DataTransferClient { 56 public: 57 struct Config { 58 absl::string_view protocol; 59 std::string address; 60 }; 61 using FactoryT = 62 std::function<Status(Config, std::unique_ptr<DataTransferClient>*)>; 63 virtual ~DataTransferClient() = default; 64 65 // Fetches the next element. 66 virtual Status GetElement(const GetElementRequest& req, 67 GetElementResult& result) = 0; 68 69 // Makes a best effort to cancel all outstanding calls in progress for the 70 // client, and causes further calls to return Cancelled status. 71 virtual void TryCancel() = 0; 72 73 // Registers a DataTransferClient factory under `name`. 74 static void Register(std::string name, FactoryT factory); 75 76 // Builds a DataTransferClient from the factory registered under `name`. 77 static Status Build(std::string name, Config config, 78 std::unique_ptr<DataTransferClient>* out); 79 }; 80 81 // Server for communicating with the tf.data service transfer client. 82 class DataTransferServer { 83 public: 84 using GetElementT = 85 std::function<Status(const GetElementRequest*, GetElementResult*)>; 86 virtual ~DataTransferServer() = default; 87 88 // Starts DataTransferServer, it should be available for requests afterwards. 89 virtual Status Start() = 0; 90 91 // Return the port that this server is listening on. 92 virtual int get_port() = 0; 93 94 // Register a DataTransferServer factory under `name`. 95 static void Register( 96 std::string name, 97 std::function<std::shared_ptr<DataTransferServer>(GetElementT)> factory); 98 99 // Builds a DataTransferServer from the factory registered with `name`. 100 static Status Build(std::string name, GetElementT get_element, 101 std::shared_ptr<DataTransferServer>* out); 102 }; 103 104 } // namespace data 105 } // namespace tensorflow 106 107 #endif // TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_ 108