• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/core/data/service/common.h"

#include <string>

#include "absl/strings/ascii.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/framework/dataset_options.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/data_service.pb.h"
26 
27 namespace tensorflow {
28 namespace data {
29 
namespace {
// Canonical upper-case spellings of the TargetWorkers values. Both
// ParseTargetWorkers (input is upper-cased before comparison) and
// TargetWorkersToString use these, so parsing and printing stay in sync.
constexpr char kAuto[] = "AUTO";
constexpr char kAny[] = "ANY";
constexpr char kLocal[] = "LOCAL";
}  // namespace
35 
IsNoShard(const ProcessingModeDef & processing_mode)36 bool IsNoShard(const ProcessingModeDef& processing_mode) {
37   return processing_mode.sharding_policy() == ProcessingModeDef::OFF;
38 }
39 
IsDynamicShard(const ProcessingModeDef & processing_mode)40 bool IsDynamicShard(const ProcessingModeDef& processing_mode) {
41   return processing_mode.sharding_policy() == ProcessingModeDef::DYNAMIC;
42 }
43 
IsStaticShard(const ProcessingModeDef & processing_mode)44 bool IsStaticShard(const ProcessingModeDef& processing_mode) {
45   return processing_mode.sharding_policy() == ProcessingModeDef::FILE ||
46          processing_mode.sharding_policy() == ProcessingModeDef::DATA ||
47          processing_mode.sharding_policy() == ProcessingModeDef::FILE_OR_DATA ||
48          processing_mode.sharding_policy() == ProcessingModeDef::HINT;
49 }
50 
ValidateProcessingMode(const ProcessingModeDef & processing_mode)51 Status ValidateProcessingMode(const ProcessingModeDef& processing_mode) {
52   if (!IsNoShard(processing_mode) && !IsDynamicShard(processing_mode) &&
53       !IsStaticShard(processing_mode)) {
54     return errors::Internal(
55         "ProcessingMode ", processing_mode.ShortDebugString(),
56         " does not "
57         "specify a valid sharding policy. Please add the policy to either "
58         "`IsDynamicShard` or `IsStaticShard` (i.e., auto-shard).");
59   }
60   return Status::OK();
61 }
62 
ToAutoShardPolicy(const ProcessingModeDef::ShardingPolicy sharding_policy)63 StatusOr<AutoShardPolicy> ToAutoShardPolicy(
64     const ProcessingModeDef::ShardingPolicy sharding_policy) {
65   switch (sharding_policy) {
66     case ProcessingModeDef::FILE:
67       return AutoShardPolicy::FILE;
68     case ProcessingModeDef::DATA:
69       return AutoShardPolicy::DATA;
70     case ProcessingModeDef::FILE_OR_DATA:
71       return AutoShardPolicy::AUTO;
72     case ProcessingModeDef::HINT:
73       return AutoShardPolicy::HINT;
74     case ProcessingModeDef::DYNAMIC:
75     case ProcessingModeDef::OFF:
76       return AutoShardPolicy::OFF;
77     default:
78       return errors::Internal(
79           "tf.data service sharding policy ",
80           ProcessingModeDef::ShardingPolicy_Name(sharding_policy),
81           " is not convertible to a valid auto-shard policy. If you're "
82           "defining a new sharding policy, please update the policy mapping.");
83   }
84 }
85 
ParseTargetWorkers(absl::string_view s)86 StatusOr<TargetWorkers> ParseTargetWorkers(absl::string_view s) {
87   std::string str_upper = absl::AsciiStrToUpper(s);
88   if (str_upper.empty() || str_upper == kAuto) {
89     return TARGET_WORKERS_AUTO;
90   }
91   if (str_upper == kAny) {
92     return TARGET_WORKERS_ANY;
93   }
94   if (str_upper == kLocal) {
95     return TARGET_WORKERS_LOCAL;
96   }
97   return errors::InvalidArgument("Unrecognized target workers: ", s);
98 }
99 
TargetWorkersToString(TargetWorkers target_workers)100 std::string TargetWorkersToString(TargetWorkers target_workers) {
101   switch (target_workers) {
102     case TARGET_WORKERS_AUTO:
103       return kAuto;
104     case TARGET_WORKERS_ANY:
105       return kAny;
106     case TARGET_WORKERS_LOCAL:
107       return kLocal;
108     default:
109       DCHECK(false);
110       return "UNKNOWN";
111   }
112 }
113 
114 }  // namespace data
115 }  // namespace tensorflow
116