1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
17
#include <ostream>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
21
22 #include "absl/algorithm/container.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/container/inlined_vector.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_join.h"
27 #include "absl/strings/str_split.h"
28 #include "llvm/ADT/Optional.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/core/framework/tensor_shape.pb.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/framework/types.pb.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/platform/types.h"
36
37 namespace tensorflow {
38
str() const39 std::string GraphImportConfig::str() const {
40 std::ostringstream ss;
41
42 ss << "graph_func_name: " << graph_func_name;
43 InputArrays inputs;
44 ss << "\ninputs: ";
45 for (auto& it : inputs) {
46 ss << "\n\t" << it.first << " -> "
47 << DataTypeString(it.second.imported_dtype) << " "
48 << it.second.shape.DebugString();
49 }
50 ss << "\noutputs:";
51 for (auto& output : outputs) ss << " " << output;
52 ss << "\ncontrol_outputs:";
53 for (auto& output : control_outputs) ss << " " << output;
54 ss << "\nprune_unused_nodes: " << prune_unused_nodes;
55 ss << "\nconvert_legacy_fed_inputs: " << convert_legacy_fed_inputs;
56 ss << "\ngraph_as_function: " << graph_as_function;
57 ss << "\nupgrade_legacy: " << upgrade_legacy;
58 ss << "\nrestrict_functionalization_to_tpu_nodes: "
59 << restrict_functionalization_to_tpu_nodes;
60 ss << "\nenable_shape_inference: " << enable_shape_inference;
61
62 return ss.str();
63 }
64
ParseOutputArrayInfo(absl::string_view array_names,std::vector<string> * outputs)65 Status ParseOutputArrayInfo(absl::string_view array_names,
66 std::vector<string>* outputs) {
67 TF_RETURN_IF_ERROR(ParseNodeNames(array_names, *outputs));
68 return Status::OK();
69 }
70
ParseOutputArrayInfo(const std::vector<string> & output_names,std::vector<string> * outputs)71 Status ParseOutputArrayInfo(const std::vector<string>& output_names,
72 std::vector<string>* outputs) {
73 for (auto& output_name : output_names) {
74 if (output_name.empty()) continue;
75 outputs->push_back(output_name);
76 }
77 return Status::OK();
78 }
79
ParseInputArrayInfo(absl::string_view array_names,absl::string_view data_types,absl::string_view shapes,GraphImportConfig::InputArrays * inputs)80 Status ParseInputArrayInfo(absl::string_view array_names,
81 absl::string_view data_types,
82 absl::string_view shapes,
83 GraphImportConfig::InputArrays* inputs) {
84 std::vector<string> node_names;
85 std::vector<string> node_dtypes;
86 std::vector<llvm::Optional<std::vector<int>>> node_shapes;
87 TF_RETURN_IF_ERROR(ParseNodeNames(array_names, node_names));
88 TF_RETURN_IF_ERROR(ParseNodeDataTypes(data_types, node_dtypes));
89 TF_RETURN_IF_ERROR(ParseNodeShapes(shapes, node_shapes));
90 return ParseInputArrayInfo(node_names, node_dtypes, node_shapes, inputs);
91 }
92
ParseInputArrayInfo(const std::vector<string> & node_names,const std::vector<string> & node_dtypes,const std::vector<llvm::Optional<std::vector<int>>> & node_shapes,GraphImportConfig::InputArrays * inputs)93 Status ParseInputArrayInfo(
94 const std::vector<string>& node_names,
95 const std::vector<string>& node_dtypes,
96 const std::vector<llvm::Optional<std::vector<int>>>& node_shapes,
97 GraphImportConfig::InputArrays* inputs) {
98 std::vector<std::string> used_node_dtypes;
99 if (node_dtypes.empty()) {
100 // Mark all the node dtypes Invalid, so the importer can handle them by
101 // using the type from the graph.
102 used_node_dtypes.resize(node_names.size(), DataType_Name(DT_INVALID));
103 } else if (node_names.size() == node_dtypes.size()) {
104 for (const auto& dtype : node_dtypes) {
105 if (dtype.empty()) {
106 used_node_dtypes.push_back(DataType_Name(DT_INVALID));
107 } else if (dtype != DataType_Name(DT_INVALID)) {
108 used_node_dtypes.push_back(dtype);
109 } else {
110 return errors::FailedPrecondition(
111 "Use '' if want to use the type from graph.");
112 }
113 }
114 } else {
115 return errors::FailedPrecondition(absl::StrCat(
116 "Unmatched node array and data type numbers (#arrays ",
117 node_names.size(), ", #data_types ", node_dtypes.size(), ")"));
118 }
119
120 if (!node_shapes.empty() && node_names.size() != node_shapes.size()) {
121 return errors::FailedPrecondition(absl::StrCat(
122 "Unmatched node array and shape numbers (#arrays ", node_names.size(),
123 ", #input_shapes ", node_shapes.size(), ")"));
124 }
125
126 // StringMap doesn't support reserve else reserve input map size here.
127 for (int i = 0, end = node_names.size(); i < end; i++) {
128 auto& name = node_names[i];
129 if (name.empty()) continue;
130
131 auto it_inserted_pair = inputs->insert({name, {}});
132 if (!it_inserted_pair.second)
133 return errors::FailedPrecondition(
134 absl::StrCat("tensor ", name, " is repeated in the arrays flag"));
135
136 ArrayInfo& info = it_inserted_pair.first->second;
137 if (!DataType_Parse(used_node_dtypes[i], &info.imported_dtype)) {
138 return errors::FailedPrecondition(
139 absl::StrCat("Invalid node type '", node_dtypes[i], "'"));
140 }
141
142 if (!node_shapes.empty()) {
143 if (!node_shapes[i].hasValue()) {
144 info.shape.set_unknown_rank(true);
145 continue;
146 }
147 for (auto& dim : node_shapes[i].getValue()) {
148 info.shape.add_dim()->set_size(dim);
149 }
150 }
151 }
152 return Status::OK();
153 }
154
ParseNodeShapes(absl::string_view shapes_str,std::vector<llvm::Optional<std::vector<int>>> & shapes_vector)155 Status ParseNodeShapes(
156 absl::string_view shapes_str,
157 std::vector<llvm::Optional<std::vector<int>>>& shapes_vector) {
158 shapes_vector.clear();
159 if (!shapes_str.empty()) {
160 std::vector<string> node_shapes_str = absl::StrSplit(shapes_str, ':');
161 for (int i = 0; i < node_shapes_str.size(); i++) {
162 if (node_shapes_str[i] == "*") {
163 shapes_vector.push_back(llvm::None);
164 continue;
165 }
166 std::vector<int> dims;
167 for (const absl::string_view dim_str :
168 absl::StrSplit(node_shapes_str[i], ',')) {
169 // Treats empty input shape as scalar
170 if (dim_str.empty()) continue;
171 if (dim_str == "?") {
172 dims.push_back(-1);
173 continue;
174 }
175 int size;
176 TF_RET_CHECK(absl::SimpleAtoi(dim_str, &size));
177 dims.push_back(size);
178 }
179 shapes_vector.push_back(dims);
180 }
181 }
182 return Status::OK();
183 }
184
ParseNodeNames(absl::string_view names_str,std::vector<std::string> & names_vector)185 Status ParseNodeNames(absl::string_view names_str,
186 std::vector<std::string>& names_vector) {
187 names_vector = absl::StrSplit(names_str, ',', absl::SkipEmpty());
188 return Status::OK();
189 }
190
ParseNodeDataTypes(absl::string_view data_types_str,std::vector<std::string> & data_type_vector)191 Status ParseNodeDataTypes(absl::string_view data_types_str,
192 std::vector<std::string>& data_type_vector) {
193 data_type_vector.clear();
194 if (!data_types_str.empty()) {
195 data_type_vector = absl::StrSplit(data_types_str, ',');
196 }
197 return Status::OK();
198 }
199
200 } // namespace tensorflow
201