1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"

#include <type_traits>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "llvm/ADT/Optional.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
34
35 namespace tensorflow {
36
ParseOutputArrayInfo(absl::string_view array_names,std::vector<string> * outputs)37 Status ParseOutputArrayInfo(absl::string_view array_names,
38 std::vector<string>* outputs) {
39 TF_RETURN_IF_ERROR(ParseNodeNames(array_names, *outputs));
40 return Status::OK();
41 }
42
ParseOutputArrayInfo(const std::vector<string> & output_names,std::vector<string> * outputs)43 Status ParseOutputArrayInfo(const std::vector<string>& output_names,
44 std::vector<string>* outputs) {
45 for (auto& output_name : output_names) {
46 if (output_name.empty()) continue;
47 outputs->push_back(output_name);
48 }
49 return Status::OK();
50 }
51
ParseInputArrayInfo(absl::string_view array_names,absl::string_view data_types,absl::string_view shapes,GraphImportConfig::InputArrays * inputs)52 Status ParseInputArrayInfo(absl::string_view array_names,
53 absl::string_view data_types,
54 absl::string_view shapes,
55 GraphImportConfig::InputArrays* inputs) {
56 std::vector<string> node_names;
57 std::vector<string> node_dtypes;
58 std::vector<llvm::Optional<std::vector<int>>> node_shapes;
59 TF_RETURN_IF_ERROR(ParseNodeNames(array_names, node_names));
60 TF_RETURN_IF_ERROR(ParseNodeDataTypes(data_types, node_dtypes));
61 TF_RETURN_IF_ERROR(ParseNodeShapes(shapes, node_shapes));
62 return ParseInputArrayInfo(node_names, node_dtypes, node_shapes, inputs);
63 }
64
ParseInputArrayInfo(const std::vector<string> & node_names,const std::vector<string> & node_dtypes,const std::vector<llvm::Optional<std::vector<int>>> & node_shapes,GraphImportConfig::InputArrays * inputs)65 Status ParseInputArrayInfo(
66 const std::vector<string>& node_names,
67 const std::vector<string>& node_dtypes,
68 const std::vector<llvm::Optional<std::vector<int>>>& node_shapes,
69 GraphImportConfig::InputArrays* inputs) {
70 std::vector<std::string> used_node_dtypes;
71 if (node_dtypes.empty()) {
72 // Mark all the node dtypes Invalid, so the importer can handle them by
73 // using the type from the graph.
74 used_node_dtypes.resize(node_names.size(), DataType_Name(DT_INVALID));
75 } else if (node_names.size() == node_dtypes.size()) {
76 for (const auto& dtype : node_dtypes) {
77 if (dtype.empty()) {
78 used_node_dtypes.push_back(DataType_Name(DT_INVALID));
79 } else if (dtype != DataType_Name(DT_INVALID)) {
80 used_node_dtypes.push_back(dtype);
81 } else {
82 return errors::FailedPrecondition(
83 "Use '' if want to use the type from graph.");
84 }
85 }
86 } else {
87 return errors::FailedPrecondition(absl::StrCat(
88 "Unmatched node array and data type numbers (#arrays ",
89 node_names.size(), ", #data_types ", node_dtypes.size(), ")"));
90 }
91
92 if (!node_shapes.empty() && node_names.size() != node_shapes.size()) {
93 return errors::FailedPrecondition(absl::StrCat(
94 "Unmatched node array and shape numbers (#arrays ", node_names.size(),
95 ", #input_shapes ", node_shapes.size(), ")"));
96 }
97
98 // StringMap doesn't support reserve else reserve input map size here.
99 for (int i = 0, end = node_names.size(); i < end; i++) {
100 auto& name = node_names[i];
101 if (name.empty()) continue;
102
103 auto it_inserted_pair = inputs->insert({name, {}});
104 if (!it_inserted_pair.second)
105 return errors::FailedPrecondition(
106 absl::StrCat("tensor ", name, " is repeated in the arrays flag"));
107
108 ArrayInfo& info = it_inserted_pair.first->second;
109 if (!DataType_Parse(used_node_dtypes[i], &info.imported_dtype)) {
110 return errors::FailedPrecondition(
111 absl::StrCat("Invalid node type '", node_dtypes[i], "'"));
112 }
113
114 if (!node_shapes.empty()) {
115 if (!node_shapes[i].hasValue()) {
116 info.shape.set_unknown_rank(true);
117 continue;
118 }
119 for (auto& dim : node_shapes[i].getValue()) {
120 info.shape.add_dim()->set_size(dim);
121 }
122 }
123 }
124 return Status::OK();
125 }
126
ParseNodeShapes(absl::string_view shapes_str,std::vector<llvm::Optional<std::vector<int>>> & shapes_vector)127 Status ParseNodeShapes(
128 absl::string_view shapes_str,
129 std::vector<llvm::Optional<std::vector<int>>>& shapes_vector) {
130 shapes_vector.clear();
131 if (!shapes_str.empty()) {
132 std::vector<string> node_shapes_str = absl::StrSplit(shapes_str, ':');
133 for (int i = 0; i < node_shapes_str.size(); i++) {
134 if (node_shapes_str[i] == "*") {
135 shapes_vector.push_back(llvm::None);
136 continue;
137 }
138 std::vector<int> dims;
139 for (const absl::string_view dim_str :
140 absl::StrSplit(node_shapes_str[i], ',')) {
141 // Treats empty input shape as scalar
142 if (dim_str.empty()) continue;
143 if (dim_str == "?") {
144 dims.push_back(-1);
145 continue;
146 }
147 int size;
148 TF_RET_CHECK(absl::SimpleAtoi(dim_str, &size));
149 dims.push_back(size);
150 }
151 shapes_vector.push_back(dims);
152 }
153 }
154 return Status::OK();
155 }
156
ParseNodeNames(absl::string_view names_str,std::vector<std::string> & names_vector)157 Status ParseNodeNames(absl::string_view names_str,
158 std::vector<std::string>& names_vector) {
159 names_vector = absl::StrSplit(names_str, ',', absl::SkipEmpty());
160 return Status::OK();
161 }
162
ParseNodeDataTypes(absl::string_view data_types_str,std::vector<std::string> & data_type_vector)163 Status ParseNodeDataTypes(absl::string_view data_types_str,
164 std::vector<std::string>& data_type_vector) {
165 data_type_vector.clear();
166 if (!data_types_str.empty()) {
167 data_type_vector = absl::StrSplit(data_types_str, ',');
168 }
169 return Status::OK();
170 }
171
172 } // namespace tensorflow
173