• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
17 
18 #include <ostream>
19 #include <sstream>
20 #include <type_traits>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/container/inlined_vector.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_join.h"
27 #include "absl/strings/str_split.h"
28 #include "llvm/ADT/Optional.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/core/framework/tensor_shape.pb.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/framework/types.pb.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/platform/types.h"
36 
37 namespace tensorflow {
38 
str() const39 std::string GraphImportConfig::str() const {
40   std::ostringstream ss;
41 
42   ss << "graph_func_name: " << graph_func_name;
43   InputArrays inputs;
44   ss << "\ninputs: ";
45   for (auto& it : inputs) {
46     ss << "\n\t" << it.first << " -> "
47        << DataTypeString(it.second.imported_dtype) << " "
48        << it.second.shape.DebugString();
49   }
50   ss << "\noutputs:";
51   for (auto& output : outputs) ss << " " << output;
52   ss << "\ncontrol_outputs:";
53   for (auto& output : control_outputs) ss << " " << output;
54   ss << "\nprune_unused_nodes: " << prune_unused_nodes;
55   ss << "\nconvert_legacy_fed_inputs: " << convert_legacy_fed_inputs;
56   ss << "\ngraph_as_function: " << graph_as_function;
57   ss << "\nupgrade_legacy: " << upgrade_legacy;
58   ss << "\nrestrict_functionalization_to_tpu_nodes: "
59      << restrict_functionalization_to_tpu_nodes;
60   ss << "\nenable_shape_inference: " << enable_shape_inference;
61 
62   return ss.str();
63 }
64 
ParseOutputArrayInfo(absl::string_view array_names,std::vector<string> * outputs)65 Status ParseOutputArrayInfo(absl::string_view array_names,
66                             std::vector<string>* outputs) {
67   TF_RETURN_IF_ERROR(ParseNodeNames(array_names, *outputs));
68   return Status::OK();
69 }
70 
ParseOutputArrayInfo(const std::vector<string> & output_names,std::vector<string> * outputs)71 Status ParseOutputArrayInfo(const std::vector<string>& output_names,
72                             std::vector<string>* outputs) {
73   for (auto& output_name : output_names) {
74     if (output_name.empty()) continue;
75     outputs->push_back(output_name);
76   }
77   return Status::OK();
78 }
79 
ParseInputArrayInfo(absl::string_view array_names,absl::string_view data_types,absl::string_view shapes,GraphImportConfig::InputArrays * inputs)80 Status ParseInputArrayInfo(absl::string_view array_names,
81                            absl::string_view data_types,
82                            absl::string_view shapes,
83                            GraphImportConfig::InputArrays* inputs) {
84   std::vector<string> node_names;
85   std::vector<string> node_dtypes;
86   std::vector<llvm::Optional<std::vector<int>>> node_shapes;
87   TF_RETURN_IF_ERROR(ParseNodeNames(array_names, node_names));
88   TF_RETURN_IF_ERROR(ParseNodeDataTypes(data_types, node_dtypes));
89   TF_RETURN_IF_ERROR(ParseNodeShapes(shapes, node_shapes));
90   return ParseInputArrayInfo(node_names, node_dtypes, node_shapes, inputs);
91 }
92 
ParseInputArrayInfo(const std::vector<string> & node_names,const std::vector<string> & node_dtypes,const std::vector<llvm::Optional<std::vector<int>>> & node_shapes,GraphImportConfig::InputArrays * inputs)93 Status ParseInputArrayInfo(
94     const std::vector<string>& node_names,
95     const std::vector<string>& node_dtypes,
96     const std::vector<llvm::Optional<std::vector<int>>>& node_shapes,
97     GraphImportConfig::InputArrays* inputs) {
98   std::vector<std::string> used_node_dtypes;
99   if (node_dtypes.empty()) {
100     // Mark all the node dtypes Invalid, so the importer can handle them by
101     // using the type from the graph.
102     used_node_dtypes.resize(node_names.size(), DataType_Name(DT_INVALID));
103   } else if (node_names.size() == node_dtypes.size()) {
104     for (const auto& dtype : node_dtypes) {
105       if (dtype.empty()) {
106         used_node_dtypes.push_back(DataType_Name(DT_INVALID));
107       } else if (dtype != DataType_Name(DT_INVALID)) {
108         used_node_dtypes.push_back(dtype);
109       } else {
110         return errors::FailedPrecondition(
111             "Use '' if want to use the type from graph.");
112       }
113     }
114   } else {
115     return errors::FailedPrecondition(absl::StrCat(
116         "Unmatched node array and data type numbers (#arrays ",
117         node_names.size(), ", #data_types ", node_dtypes.size(), ")"));
118   }
119 
120   if (!node_shapes.empty() && node_names.size() != node_shapes.size()) {
121     return errors::FailedPrecondition(absl::StrCat(
122         "Unmatched node array and shape numbers (#arrays ", node_names.size(),
123         ", #input_shapes ", node_shapes.size(), ")"));
124   }
125 
126   // StringMap doesn't support reserve else reserve input map size here.
127   for (int i = 0, end = node_names.size(); i < end; i++) {
128     auto& name = node_names[i];
129     if (name.empty()) continue;
130 
131     auto it_inserted_pair = inputs->insert({name, {}});
132     if (!it_inserted_pair.second)
133       return errors::FailedPrecondition(
134           absl::StrCat("tensor ", name, " is repeated in the arrays flag"));
135 
136     ArrayInfo& info = it_inserted_pair.first->second;
137     if (!DataType_Parse(used_node_dtypes[i], &info.imported_dtype)) {
138       return errors::FailedPrecondition(
139           absl::StrCat("Invalid node type '", node_dtypes[i], "'"));
140     }
141 
142     if (!node_shapes.empty()) {
143       if (!node_shapes[i].hasValue()) {
144         info.shape.set_unknown_rank(true);
145         continue;
146       }
147       for (auto& dim : node_shapes[i].getValue()) {
148         info.shape.add_dim()->set_size(dim);
149       }
150     }
151   }
152   return Status::OK();
153 }
154 
ParseNodeShapes(absl::string_view shapes_str,std::vector<llvm::Optional<std::vector<int>>> & shapes_vector)155 Status ParseNodeShapes(
156     absl::string_view shapes_str,
157     std::vector<llvm::Optional<std::vector<int>>>& shapes_vector) {
158   shapes_vector.clear();
159   if (!shapes_str.empty()) {
160     std::vector<string> node_shapes_str = absl::StrSplit(shapes_str, ':');
161     for (int i = 0; i < node_shapes_str.size(); i++) {
162       if (node_shapes_str[i] == "*") {
163         shapes_vector.push_back(llvm::None);
164         continue;
165       }
166       std::vector<int> dims;
167       for (const absl::string_view dim_str :
168            absl::StrSplit(node_shapes_str[i], ',')) {
169         // Treats empty input shape as scalar
170         if (dim_str.empty()) continue;
171         if (dim_str == "?") {
172           dims.push_back(-1);
173           continue;
174         }
175         int size;
176         TF_RET_CHECK(absl::SimpleAtoi(dim_str, &size));
177         dims.push_back(size);
178       }
179       shapes_vector.push_back(dims);
180     }
181   }
182   return Status::OK();
183 }
184 
ParseNodeNames(absl::string_view names_str,std::vector<std::string> & names_vector)185 Status ParseNodeNames(absl::string_view names_str,
186                       std::vector<std::string>& names_vector) {
187   names_vector = absl::StrSplit(names_str, ',', absl::SkipEmpty());
188   return Status::OK();
189 }
190 
ParseNodeDataTypes(absl::string_view data_types_str,std::vector<std::string> & data_type_vector)191 Status ParseNodeDataTypes(absl::string_view data_types_str,
192                           std::vector<std::string>& data_type_vector) {
193   data_type_vector.clear();
194   if (!data_types_str.empty()) {
195     data_type_vector = absl::StrSplit(data_types_str, ',');
196   }
197   return Status::OK();
198 }
199 
200 }  // namespace tensorflow
201