1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/common.h"
16
#include <string>

#include "absl/strings/ascii.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/framework/dataset_options.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/data_service.pb.h"
26
27 namespace tensorflow {
28 namespace data {
29
namespace {
// Canonical upper-case spellings of the TargetWorkers options. They are
// matched against by ParseTargetWorkers and emitted by TargetWorkersToString.
constexpr char kAuto[] = "AUTO";
constexpr char kAny[] = "ANY";
constexpr char kLocal[] = "LOCAL";
}  // namespace
35
IsNoShard(const ProcessingModeDef & processing_mode)36 bool IsNoShard(const ProcessingModeDef& processing_mode) {
37 return processing_mode.sharding_policy() == ProcessingModeDef::OFF;
38 }
39
IsDynamicShard(const ProcessingModeDef & processing_mode)40 bool IsDynamicShard(const ProcessingModeDef& processing_mode) {
41 return processing_mode.sharding_policy() == ProcessingModeDef::DYNAMIC;
42 }
43
IsStaticShard(const ProcessingModeDef & processing_mode)44 bool IsStaticShard(const ProcessingModeDef& processing_mode) {
45 return processing_mode.sharding_policy() == ProcessingModeDef::FILE ||
46 processing_mode.sharding_policy() == ProcessingModeDef::DATA ||
47 processing_mode.sharding_policy() == ProcessingModeDef::FILE_OR_DATA ||
48 processing_mode.sharding_policy() == ProcessingModeDef::HINT;
49 }
50
ValidateProcessingMode(const ProcessingModeDef & processing_mode)51 Status ValidateProcessingMode(const ProcessingModeDef& processing_mode) {
52 if (!IsNoShard(processing_mode) && !IsDynamicShard(processing_mode) &&
53 !IsStaticShard(processing_mode)) {
54 return errors::Internal(
55 "ProcessingMode ", processing_mode.ShortDebugString(),
56 " does not "
57 "specify a valid sharding policy. Please add the policy to either "
58 "`IsDynamicShard` or `IsStaticShard` (i.e., auto-shard).");
59 }
60 return Status::OK();
61 }
62
ToAutoShardPolicy(const ProcessingModeDef::ShardingPolicy sharding_policy)63 StatusOr<AutoShardPolicy> ToAutoShardPolicy(
64 const ProcessingModeDef::ShardingPolicy sharding_policy) {
65 switch (sharding_policy) {
66 case ProcessingModeDef::FILE:
67 return AutoShardPolicy::FILE;
68 case ProcessingModeDef::DATA:
69 return AutoShardPolicy::DATA;
70 case ProcessingModeDef::FILE_OR_DATA:
71 return AutoShardPolicy::AUTO;
72 case ProcessingModeDef::HINT:
73 return AutoShardPolicy::HINT;
74 case ProcessingModeDef::DYNAMIC:
75 case ProcessingModeDef::OFF:
76 return AutoShardPolicy::OFF;
77 default:
78 return errors::Internal(
79 "tf.data service sharding policy ",
80 ProcessingModeDef::ShardingPolicy_Name(sharding_policy),
81 " is not convertible to a valid auto-shard policy. If you're "
82 "defining a new sharding policy, please update the policy mapping.");
83 }
84 }
85
ParseTargetWorkers(absl::string_view s)86 StatusOr<TargetWorkers> ParseTargetWorkers(absl::string_view s) {
87 std::string str_upper = absl::AsciiStrToUpper(s);
88 if (str_upper.empty() || str_upper == kAuto) {
89 return TARGET_WORKERS_AUTO;
90 }
91 if (str_upper == kAny) {
92 return TARGET_WORKERS_ANY;
93 }
94 if (str_upper == kLocal) {
95 return TARGET_WORKERS_LOCAL;
96 }
97 return errors::InvalidArgument("Unrecognized target workers: ", s);
98 }
99
TargetWorkersToString(TargetWorkers target_workers)100 std::string TargetWorkersToString(TargetWorkers target_workers) {
101 switch (target_workers) {
102 case TARGET_WORKERS_AUTO:
103 return kAuto;
104 case TARGET_WORKERS_ANY:
105 return kAny;
106 case TARGET_WORKERS_LOCAL:
107 return kLocal;
108 default:
109 DCHECK(false);
110 return "UNKNOWN";
111 }
112 }
113
114 } // namespace data
115 } // namespace tensorflow
116