• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_
17 #define TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_
18 
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "absl/strings/string_view.h"
25 #include "absl/types/optional.h"
26 #include "tensorflow/core/data/dataset.pb.h"
27 #include "tensorflow/core/data/service/worker.pb.h"
28 #include "tensorflow/core/framework/dataset.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/platform/macros.h"
31 #include "tensorflow/core/platform/status.h"
32 
33 namespace tensorflow {
34 namespace data {
35 
36 // The result of a GetElement request. Exactly one of the following will be
37 // true: (1) `components` is nonempty (2) `end_of_sequence` is true (3) `skip`
38 // is true.
39 struct GetElementResult {
40   GetElementResult() = default;
41   GetElementResult(const GetElementResult&) = delete;
42   GetElementResult& operator=(const GetElementResult&) = delete;
43   GetElementResult(GetElementResult&&) = default;
44   GetElementResult& operator=(GetElementResult&&) = default;
45 
46   // Creates a copy of this result. This is used to create multiple copies of
47   // the same cached value.
48   GetElementResult Copy() const;
49 
50   // Estimated memory used by this object, measured in bytes.
51   size_t EstimatedMemoryUsageBytes() const;
52 
53   // A dataset element produced by a GetElement request.
54   std::vector<Tensor> components;
55   // The element's index within the task it came from.
56   int64_t element_index = 0;
57   // If true, indicates that there is no more data to read.
58   bool end_of_sequence = false;
59   // If true, indicates that there is still data, but the caller should skip
60   // reading from the worker. This is used for load balancing when doing round
61   // robin reads.
62   bool skip = false;
63 };
64 
65 // Client for communicating with the tf.data service transfer server.
66 class DataTransferClient {
67  public:
68   struct Config {
69     absl::string_view protocol;
70     std::string address;
71   };
72   using FactoryT =
73       std::function<Status(Config, std::unique_ptr<DataTransferClient>*)>;
74   virtual ~DataTransferClient() = default;
75 
76   // Fetches the next element.
77   virtual Status GetElement(const GetElementRequest& req,
78                             GetElementResult& result) = 0;
79 
80   // Makes a best effort to cancel all outstanding calls in progress for the
81   // client, and causes further calls to return Cancelled status.
82   virtual void TryCancel() = 0;
83 
84   // Registers a DataTransferClient factory under `name`.
85   static void Register(std::string name, FactoryT factory);
86 
87   // Builds a DataTransferClient from the factory registered under `name`.
88   static Status Build(std::string name, Config config,
89                       std::unique_ptr<DataTransferClient>* out);
90 };
91 
92 // Server for communicating with the tf.data service transfer client.
93 class DataTransferServer {
94  public:
95   using GetElementT =
96       std::function<Status(const GetElementRequest*, GetElementResult*)>;
97   virtual ~DataTransferServer() = default;
98 
99   // Starts DataTransferServer, it should be available for requests afterwards.
100   virtual Status Start() = 0;
101 
102   // Return the port that this server is listening on.
103   virtual int get_port() = 0;
104 
105   // Register a DataTransferServer factory under `name`.
106   static void Register(
107       std::string name,
108       std::function<std::shared_ptr<DataTransferServer>(GetElementT)> factory);
109 
110   // Builds a DataTransferServer from the factory registered with `name`.
111   static Status Build(std::string name, GetElementT get_element,
112                       std::shared_ptr<DataTransferServer>* out);
113 };
114 
115 }  // namespace data
116 }  // namespace tensorflow
117 
118 #endif  // TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_
119