1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_ 17 #define TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_ 18 19 #include <functional> 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "absl/strings/string_view.h" 25 #include "absl/types/optional.h" 26 #include "tensorflow/core/data/dataset.pb.h" 27 #include "tensorflow/core/data/service/worker.pb.h" 28 #include "tensorflow/core/framework/dataset.h" 29 #include "tensorflow/core/framework/tensor.h" 30 #include "tensorflow/core/platform/macros.h" 31 #include "tensorflow/core/platform/status.h" 32 33 namespace tensorflow { 34 namespace data { 35 36 // The result of a GetElement request. Exactly one of the following will be 37 // true: (1) `components` is nonempty (2) `end_of_sequence` is true (3) `skip` 38 // is true. 39 struct GetElementResult { 40 GetElementResult() = default; 41 GetElementResult(const GetElementResult&) = delete; 42 GetElementResult& operator=(const GetElementResult&) = delete; 43 GetElementResult(GetElementResult&&) = default; 44 GetElementResult& operator=(GetElementResult&&) = default; 45 46 // Creates a copy of this result. This is used to create multiple copies of 47 // the same cached value. 48 GetElementResult Copy() const; 49 50 // Estimated memory used by this object, measured in bytes. 51 size_t EstimatedMemoryUsageBytes() const; 52 53 // A dataset element produced by a GetElement request. 54 std::vector<Tensor> components; 55 // The element's index within the task it came from. 56 int64_t element_index = 0; 57 // If true, indicates that there is no more data to read. 58 bool end_of_sequence = false; 59 // If true, indicates that there is still data, but the caller should skip 60 // reading from the worker. This is used for load balancing when doing round 61 // robin reads. 62 bool skip = false; 63 }; 64 65 // Client for communicating with the tf.data service transfer server. 66 class DataTransferClient { 67 public: 68 struct Config { 69 absl::string_view protocol; 70 std::string address; 71 }; 72 using FactoryT = 73 std::function<Status(Config, std::unique_ptr<DataTransferClient>*)>; 74 virtual ~DataTransferClient() = default; 75 76 // Fetches the next element. 77 virtual Status GetElement(const GetElementRequest& req, 78 GetElementResult& result) = 0; 79 80 // Makes a best effort to cancel all outstanding calls in progress for the 81 // client, and causes further calls to return Cancelled status. 82 virtual void TryCancel() = 0; 83 84 // Registers a DataTransferClient factory under `name`. 85 static void Register(std::string name, FactoryT factory); 86 87 // Builds a DataTransferClient from the factory registered under `name`. 88 static Status Build(std::string name, Config config, 89 std::unique_ptr<DataTransferClient>* out); 90 }; 91 92 // Server for communicating with the tf.data service transfer client. 93 class DataTransferServer { 94 public: 95 using GetElementT = 96 std::function<Status(const GetElementRequest*, GetElementResult*)>; 97 virtual ~DataTransferServer() = default; 98 99 // Starts DataTransferServer, it should be available for requests afterwards. 100 virtual Status Start() = 0; 101 102 // Return the port that this server is listening on. 103 virtual int get_port() = 0; 104 105 // Register a DataTransferServer factory under `name`. 106 static void Register( 107 std::string name, 108 std::function<std::shared_ptr<DataTransferServer>(GetElementT)> factory); 109 110 // Builds a DataTransferServer from the factory registered with `name`. 111 static Status Build(std::string name, GetElementT get_element, 112 std::shared_ptr<DataTransferServer>* out); 113 }; 114 115 } // namespace data 116 } // namespace tensorflow 117 118 #endif // TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_ 119