1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_COMMON_H_ 16 #define TENSORFLOW_CORE_DATA_SERVICE_COMMON_H_ 17 18 #include <string> 19 20 #include "absl/strings/string_view.h" 21 #include "tensorflow/core/data/service/common.pb.h" 22 #include "tensorflow/core/framework/dataset_options.pb.h" 23 #include "tensorflow/core/platform/status.h" 24 #include "tensorflow/core/platform/statusor.h" 25 #include "tensorflow/core/platform/types.h" 26 #include "tensorflow/core/protobuf/data_service.pb.h" 27 28 namespace tensorflow { 29 namespace data { 30 31 // Increment this when making backwards-incompatible changes to communication 32 // between tf.data servers. 33 constexpr int kDataServiceVersion = 3; 34 35 // Returns true if `processing_mode` specifies no sharding policy. 36 bool IsNoShard(const ProcessingModeDef& processing_mode); 37 38 // Returns true if `processing_mode` is dynamic sharding. 39 bool IsDynamicShard(const ProcessingModeDef& processing_mode); 40 41 // Returns true if `processing_mode` is static sharding. 42 bool IsStaticShard(const ProcessingModeDef& processing_mode); 43 44 // Returns an internal error if `processing_mode` is invalid. 45 Status ValidateProcessingMode(const ProcessingModeDef& processing_mode); 46 47 // Converts tf.data service `sharding_policy` to `AutoShardPolicy`. Returns an 48 // internal error if `sharding_policy` is not supported. 49 StatusOr<AutoShardPolicy> ToAutoShardPolicy( 50 ProcessingModeDef::ShardingPolicy sharding_policy); 51 52 // Parses a string representing a `TargetWorkers` (case-insensitive). 53 // Returns InvalidArgument if the string is not recognized. 54 StatusOr<TargetWorkers> ParseTargetWorkers(absl::string_view s); 55 56 // Converts a `TargetWorkers` enum to string. 57 std::string TargetWorkersToString(TargetWorkers target_workers); 58 59 // Base class for data service clients. Data service clients are 60 // threadsafe. 61 class DataServiceClientBase { 62 public: DataServiceClientBase(const std::string & address,const std::string & protocol)63 DataServiceClientBase(const std::string& address, const std::string& protocol) 64 : address_(address), protocol_(protocol) {} 65 66 virtual ~DataServiceClientBase() = default; 67 // Not copyable or movable. 68 DataServiceClientBase(const DataServiceClientBase&) = delete; 69 DataServiceClientBase& operator=(const DataServiceClientBase&) = delete; 70 71 // Initializes the client. Calling `Initialize()` is not required since the 72 // first RPC will perform any necessary initialization. However, it can be 73 // useful to call `Initialize()` proactively so that any errors that happen 74 // during initialization can be surfaced earlier. Initialize()75 Status Initialize() { return EnsureInitialized(); } 76 77 protected: 78 // Initializes the client if it isn't already initialized. 79 virtual Status EnsureInitialized() = 0; 80 81 const std::string address_; 82 const std::string protocol_; 83 }; 84 85 } // namespace data 86 } // namespace tensorflow 87 88 #endif // TENSORFLOW_CORE_DATA_SERVICE_COMMON_H_ 89