• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
17 
18 #include <type_traits>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/container/inlined_vector.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_join.h"
25 #include "absl/strings/str_split.h"
26 #include "llvm/ADT/Optional.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/core/framework/tensor_shape.pb.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/platform/errors.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace tensorflow {
36 
ParseOutputArrayInfo(absl::string_view array_names,std::vector<string> * outputs)37 Status ParseOutputArrayInfo(absl::string_view array_names,
38                             std::vector<string>* outputs) {
39   TF_RETURN_IF_ERROR(ParseNodeNames(array_names, *outputs));
40   return Status::OK();
41 }
42 
ParseOutputArrayInfo(const std::vector<string> & output_names,std::vector<string> * outputs)43 Status ParseOutputArrayInfo(const std::vector<string>& output_names,
44                             std::vector<string>* outputs) {
45   for (auto& output_name : output_names) {
46     if (output_name.empty()) continue;
47     outputs->push_back(output_name);
48   }
49   return Status::OK();
50 }
51 
ParseInputArrayInfo(absl::string_view array_names,absl::string_view data_types,absl::string_view shapes,GraphImportConfig::InputArrays * inputs)52 Status ParseInputArrayInfo(absl::string_view array_names,
53                            absl::string_view data_types,
54                            absl::string_view shapes,
55                            GraphImportConfig::InputArrays* inputs) {
56   std::vector<string> node_names;
57   std::vector<string> node_dtypes;
58   std::vector<llvm::Optional<std::vector<int>>> node_shapes;
59   TF_RETURN_IF_ERROR(ParseNodeNames(array_names, node_names));
60   TF_RETURN_IF_ERROR(ParseNodeDataTypes(data_types, node_dtypes));
61   TF_RETURN_IF_ERROR(ParseNodeShapes(shapes, node_shapes));
62   return ParseInputArrayInfo(node_names, node_dtypes, node_shapes, inputs);
63 }
64 
ParseInputArrayInfo(const std::vector<string> & node_names,const std::vector<string> & node_dtypes,const std::vector<llvm::Optional<std::vector<int>>> & node_shapes,GraphImportConfig::InputArrays * inputs)65 Status ParseInputArrayInfo(
66     const std::vector<string>& node_names,
67     const std::vector<string>& node_dtypes,
68     const std::vector<llvm::Optional<std::vector<int>>>& node_shapes,
69     GraphImportConfig::InputArrays* inputs) {
70   std::vector<std::string> used_node_dtypes;
71   if (node_dtypes.empty()) {
72     // Mark all the node dtypes Invalid, so the importer can handle them by
73     // using the type from the graph.
74     used_node_dtypes.resize(node_names.size(), DataType_Name(DT_INVALID));
75   } else if (node_names.size() == node_dtypes.size()) {
76     for (const auto& dtype : node_dtypes) {
77       if (dtype.empty()) {
78         used_node_dtypes.push_back(DataType_Name(DT_INVALID));
79       } else if (dtype != DataType_Name(DT_INVALID)) {
80         used_node_dtypes.push_back(dtype);
81       } else {
82         return errors::FailedPrecondition(
83             "Use '' if want to use the type from graph.");
84       }
85     }
86   } else {
87     return errors::FailedPrecondition(absl::StrCat(
88         "Unmatched node array and data type numbers (#arrays ",
89         node_names.size(), ", #data_types ", node_dtypes.size(), ")"));
90   }
91 
92   if (!node_shapes.empty() && node_names.size() != node_shapes.size()) {
93     return errors::FailedPrecondition(absl::StrCat(
94         "Unmatched node array and shape numbers (#arrays ", node_names.size(),
95         ", #input_shapes ", node_shapes.size(), ")"));
96   }
97 
98   // StringMap doesn't support reserve else reserve input map size here.
99   for (int i = 0, end = node_names.size(); i < end; i++) {
100     auto& name = node_names[i];
101     if (name.empty()) continue;
102 
103     auto it_inserted_pair = inputs->insert({name, {}});
104     if (!it_inserted_pair.second)
105       return errors::FailedPrecondition(
106           absl::StrCat("tensor ", name, " is repeated in the arrays flag"));
107 
108     ArrayInfo& info = it_inserted_pair.first->second;
109     if (!DataType_Parse(used_node_dtypes[i], &info.imported_dtype)) {
110       return errors::FailedPrecondition(
111           absl::StrCat("Invalid node type '", node_dtypes[i], "'"));
112     }
113 
114     if (!node_shapes.empty()) {
115       if (!node_shapes[i].hasValue()) {
116         info.shape.set_unknown_rank(true);
117         continue;
118       }
119       for (auto& dim : node_shapes[i].getValue()) {
120         info.shape.add_dim()->set_size(dim);
121       }
122     }
123   }
124   return Status::OK();
125 }
126 
ParseNodeShapes(absl::string_view shapes_str,std::vector<llvm::Optional<std::vector<int>>> & shapes_vector)127 Status ParseNodeShapes(
128     absl::string_view shapes_str,
129     std::vector<llvm::Optional<std::vector<int>>>& shapes_vector) {
130   shapes_vector.clear();
131   if (!shapes_str.empty()) {
132     std::vector<string> node_shapes_str = absl::StrSplit(shapes_str, ':');
133     for (int i = 0; i < node_shapes_str.size(); i++) {
134       if (node_shapes_str[i] == "*") {
135         shapes_vector.push_back(llvm::None);
136         continue;
137       }
138       std::vector<int> dims;
139       for (const absl::string_view dim_str :
140            absl::StrSplit(node_shapes_str[i], ',')) {
141         // Treats empty input shape as scalar
142         if (dim_str.empty()) continue;
143         if (dim_str == "?") {
144           dims.push_back(-1);
145           continue;
146         }
147         int size;
148         TF_RET_CHECK(absl::SimpleAtoi(dim_str, &size));
149         dims.push_back(size);
150       }
151       shapes_vector.push_back(dims);
152     }
153   }
154   return Status::OK();
155 }
156 
ParseNodeNames(absl::string_view names_str,std::vector<std::string> & names_vector)157 Status ParseNodeNames(absl::string_view names_str,
158                       std::vector<std::string>& names_vector) {
159   names_vector = absl::StrSplit(names_str, ',', absl::SkipEmpty());
160   return Status::OK();
161 }
162 
ParseNodeDataTypes(absl::string_view data_types_str,std::vector<std::string> & data_type_vector)163 Status ParseNodeDataTypes(absl::string_view data_types_str,
164                           std::vector<std::string>& data_type_vector) {
165   data_type_vector.clear();
166   if (!data_types_str.empty()) {
167     data_type_vector = absl::StrSplit(data_types_str, ',');
168   }
169   return Status::OK();
170 }
171 
172 }  // namespace tensorflow
173