1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_
17 #define TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_
18
19 #include <vector>
20 #include "absl/strings/numbers.h"
21 #include "absl/strings/str_split.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/xla.pb.h"
24
25 namespace xla {
26
27 template <typename T>
parse_xla_backend_extra_options(T * extra_options_map,string comma_separated_values)28 void parse_xla_backend_extra_options(T* extra_options_map,
29 string comma_separated_values) {
30 std::vector<string> extra_options_parts =
31 absl::StrSplit(comma_separated_values, ',');
32
33 // The flag contains a comma-separated list of options; some options
34 // have arguments following "=", some don't.
35 for (const auto& part : extra_options_parts) {
36 size_t eq_pos = part.find_first_of('=');
37 if (eq_pos == string::npos) {
38 (*extra_options_map)[part] = "";
39 } else {
40 string value = "";
41 if (eq_pos + 1 < part.size()) {
42 value = part.substr(eq_pos + 1);
43 }
44 (*extra_options_map)[part.substr(0, eq_pos)] = value;
45 }
46 }
47 }
48
49 // The --xla_reduce_precision option has the format "LOCATION=E,M:OPS;NAME",
50 // where LOCATION is an HloReducePrecisionOptions::location, E and M are
51 // integers for the exponent and matissa bit counts respectively, and OPS and
52 // NAMES are comma-separated of the operation types and names to which to
53 // attach the reduce-precision operations. The OPS values are matches to the
54 // strings produced by HloOpcodeString, while the NAME values are arbitrary
55 // strings subject to the requirements that they not contain any of "=,:;".
56 // The NAME string (with its preceding semicolon) is optional.
parse_xla_reduce_precision_option(HloReducePrecisionOptions * options,string option_string)57 inline bool parse_xla_reduce_precision_option(
58 HloReducePrecisionOptions* options, string option_string) {
59 // Split off "LOCATION" from remainder of string.
60 std::vector<string> eq_split = absl::StrSplit(option_string, '=');
61 if (eq_split.size() != 2) {
62 return false;
63 }
64 string& location = eq_split[0];
65 if (location == "OP_INPUTS") {
66 options->set_location(HloReducePrecisionOptions::OP_INPUTS);
67 } else if (location == "OP_OUTPUTS") {
68 options->set_location(HloReducePrecisionOptions::OP_OUTPUTS);
69 } else if (location == "UNFUSED_OP_OUTPUTS") {
70 options->set_location(HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS);
71 } else if (location == "FUSION_INPUTS_BY_CONTENT") {
72 options->set_location(HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT);
73 } else if (location == "FUSION_OUTPUTS_BY_CONTENT") {
74 options->set_location(HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT);
75 } else {
76 return false;
77 }
78
79 // Split off "E,M" from remainder of string.
80 std::vector<string> colon_split = absl::StrSplit(eq_split[1], ':');
81 if (colon_split.size() != 2) {
82 return false;
83 }
84
85 // Split E and M, and parse.
86 std::vector<int32> bitsizes;
87 for (const auto& s : absl::StrSplit(colon_split[0], ',')) {
88 bitsizes.emplace_back();
89 if (!absl::SimpleAtoi(s, &bitsizes.back())) {
90 return false;
91 }
92 }
93 options->set_exponent_bits(bitsizes[0]);
94 options->set_mantissa_bits(bitsizes[1]);
95
96 // Split off OPS comma-separated list from remainder of string, if the
97 // remainder exists.
98 std::vector<string> semicolon_split = absl::StrSplit(colon_split[1], ';');
99 if (semicolon_split.size() > 2) {
100 return false;
101 }
102 // The opcode values are either 'all' (meaning all opcodes), or matches to
103 // the strings returned by HloOpcodeString. An empty string is also
104 // interpreted as 'all', for convenience. Note that 'all' may not be part
105 // of a comma-separated list; it must stand alone.
106 string& opcode_string = semicolon_split[0];
107 if (opcode_string == "" || opcode_string == "all") {
108 for (int i = 0; i < HloOpcodeCount(); i++) {
109 options->add_opcodes_to_suffix(i);
110 }
111 } else {
112 std::vector<string> opcodes = absl::StrSplit(opcode_string, ',');
113 for (const string& opcode : opcodes) {
114 bool found = false;
115 for (int i = 0; i < HloOpcodeCount(); i++) {
116 if (opcode == HloOpcodeString(static_cast<HloOpcode>(i))) {
117 options->add_opcodes_to_suffix(i);
118 found = true;
119 break;
120 }
121 }
122 if (!found) {
123 return false;
124 }
125 }
126 }
127
128 // Process the NAMES string, if it exists.
129 if (semicolon_split.size() == 2) {
130 std::vector<string> opnames = absl::StrSplit(semicolon_split[1], ',');
131 for (const string& opname : opnames) {
132 if (opname.length() > 0) {
133 options->add_opname_substrings_to_suffix(opname);
134 }
135 }
136 }
137
138 return true;
139 }
140
141 } // namespace xla
142
143 #endif // TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_
144