• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/runtime_fallback/bef_executor_flags.h"
17 
18 #include "absl/strings/str_split.h"
19 #include "absl/strings/string_view.h"
20 
// Command-line flags declared for the BEF executor
// (see tensorflow/core/runtime_fallback/bef_executor_flags.h).
// NOTE(review): the code that reads these flags is not in this file; the
// comments below only paraphrase each flag's help string and default.

// MLIR input file. Defaults to tfrt::kDefaultInputFilename, which is "-"
// (read from stdin, per the help text).
ABSL_FLAG(std::string, input_filename, tfrt::kDefaultInputFilename,
          "mlir input filename (default '-' for stdin)");
// Comma-separated list of dynamic libraries that provide ops.
ABSL_FLAG(std::string, shared_libs, "",
          "comma-separated list of dynamic libraries with ops");
// Comma-separated list of MLIR function names to run.
ABSL_FLAG(std::string, functions, "",
          "comma-separated list of mlir functions to run");
// Optional init function, invoked before any other MLIR function even when it
// is not listed in --functions.
ABSL_FLAG(std::string, test_init_function, "",
          "init function that will be invoked as part of "
          "initialization, before invoking any other MLIR functions even if it "
          "is not specified in --functions flag.");
// Work queue selector; "s" is the default.
ABSL_FLAG(std::string, work_queue_type, "s",
          "type of work queue (s(default), mstd, ...)");
// Host allocator selector. The text value is converted by
// tfrt::AbslParseFlag / tfrt::AbslUnparseFlag, defined later in this file.
// Defaults to the leak-checking allocator.
ABSL_FLAG(tfrt::HostAllocatorTypeWrapper, host_allocator_type,
          {tfrt::HostAllocatorType::kLeakCheckMalloc},
          "type of host allocator (malloc, profiled_allocator, "
          "leak_check_allocator(default))");
37 
38 namespace tfrt {
39 
40 const char kDefaultInputFilename[] = "-";
41 
42 // AbslParseFlag/AbslUnparseFlag need to be in tfrt namespace for ADL to work.
43 
AbslParseFlag(absl::string_view text,tfrt::HostAllocatorTypeWrapper * host_allocator_type,std::string * error)44 bool AbslParseFlag(absl::string_view text,
45                    tfrt::HostAllocatorTypeWrapper* host_allocator_type,
46                    std::string* error) {
47   if (text == "malloc") {
48     *host_allocator_type = {tfrt::HostAllocatorType::kMalloc};
49     return true;
50   }
51   if (text == "profiled_allocator") {
52     *host_allocator_type = {tfrt::HostAllocatorType::kProfiledMalloc};
53     return true;
54   }
55   if (text == "leak_check_allocator") {
56     *host_allocator_type = {tfrt::HostAllocatorType::kLeakCheckMalloc};
57     return true;
58   }
59   *error = "Unknown value for tfrt::HostAllocatorType";
60   return false;
61 }
62 
AbslUnparseFlag(tfrt::HostAllocatorTypeWrapper host_allocator_type)63 std::string AbslUnparseFlag(
64     tfrt::HostAllocatorTypeWrapper host_allocator_type) {
65   switch (host_allocator_type) {
66     case tfrt::HostAllocatorType::kMalloc:
67       return "malloc";
68     case tfrt::HostAllocatorType::kTestFixedSizeMalloc:
69       return "test_fixed_size_1k";
70     case tfrt::HostAllocatorType::kProfiledMalloc:
71       return "profiled_allocator";
72     case tfrt::HostAllocatorType::kLeakCheckMalloc:
73       return "leak_check_allocator";
74   }
75 }
76 }  // namespace tfrt
77