1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference_testutil.h"
16
17 #include "tensorflow/core/framework/node_def_util.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/lib/gtl/map_util.h"
20 #include "tensorflow/core/lib/strings/numbers.h"
21 #include "tensorflow/core/lib/strings/scanner.h"
22 #include "tensorflow/core/lib/strings/str_util.h"
23
24 namespace tensorflow {
25 namespace shape_inference {
26
27 using errors::Unknown;
28
InferShapes(ShapeInferenceTestOp op,const string & ins,const string & expected_outs)29 Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
30 const string& ins,
31 const string& expected_outs) {
32 const OpRegistrationData* op_reg_data;
33 TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
34
35 std::vector<string> ins_v = str_util::Split(ins, ';');
36
37 InferenceContext::ShapeManager manager;
38 std::vector<ShapeHandle> in_shapes;
39 for (const string& spec : ins_v) {
40 ShapeHandle shape;
41 TF_RETURN_IF_ERROR(MakeShapeFromString(&manager, spec, &shape));
42 in_shapes.push_back(shape);
43 }
44
45 std::vector<std::unique_ptr<std::vector<shape_inference::ShapeAndType>>>
46 input_resource_handle_shapes_and_types;
47 for (const auto p : op.input_resource_handle_shapes_and_types) {
48 if (p == nullptr) {
49 input_resource_handle_shapes_and_types.push_back(nullptr);
50 } else {
51 std::unique_ptr<std::vector<ShapeAndType>> v(
52 new std::vector<ShapeAndType>());
53 for (const auto& shape_and_type : *p) {
54 ShapeHandle shape;
55 TF_RETURN_IF_ERROR(
56 MakeShapeFromString(&manager, shape_and_type.first, &shape));
57 v->emplace_back(shape, shape_and_type.second);
58 }
59 input_resource_handle_shapes_and_types.emplace_back(v.release());
60 }
61 }
62 shape_inference::InferenceContext c(
63 op.graph_def_version, op.node_def, op_reg_data->op_def, in_shapes,
64 op.input_tensors, {}, std::move(input_resource_handle_shapes_and_types));
65 TF_RETURN_IF_ERROR(c.construction_status());
66 if (op_reg_data->shape_inference_fn == nullptr) {
67 return errors::InvalidArgument(
68 "No shape inference function exists for op '", op.name,
69 "', did you forget to define it?");
70 }
71
72 TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
73
74 const int num_outputs = c.num_outputs();
75
76 if (expected_outs == "e") {
77 return Unknown("Shape inference should have returned error");
78 }
79
80 // Verify the output shape.
81 std::vector<string> expected_outs_v = str_util::Split(expected_outs, ';');
82 if (num_outputs != expected_outs_v.size()) {
83 return Unknown("The expected output string lists the wrong number of ",
84 "outputs. It lists ", expected_outs_v.size(),
85 " but should list ", num_outputs);
86 }
87 for (int i = 0; i < num_outputs; ++i) {
88 StringPiece expected(expected_outs_v[i]);
89 shape_inference::ShapeHandle out = c.output(i);
90
91 string err_prefix = strings::StrCat("Output ", i);
92 string err_suffix =
93 strings::StrCat(". Output shape was ", c.DebugString(out));
94
95 int in_index = -1;
96 for (int i = 0; i < c.num_inputs(); ++i) {
97 if (c.input(i).SameHandle(out)) {
98 in_index = i;
99 }
100 }
101
102 if (absl::StartsWith(expected, "in")) {
103 if (in_index == -1) {
104 return Unknown(err_prefix,
105 " should have matched an input shape by "
106 "handle, but matched no input shape. This means the ",
107 "shape function was expected to pass an input "
108 "ShapeHandle through for this output, but did not",
109 err_suffix);
110 }
111 auto v = str_util::Split(expected, '|');
112 if (std::find(v.begin(), v.end(), strings::StrCat("in", in_index)) ==
113 v.end()) {
114 return Unknown(
115 err_prefix, " matched input ", in_index,
116 " by handle, but should have matched one of (", expected,
117 ") instead. This means the shape function passed the ShapeHandle ",
118 "for input ", in_index,
119 " to the output, but should have passed a different input ",
120 "ShapeHandle through", err_suffix);
121 }
122 continue;
123 }
124 if (in_index != -1) {
125 return Unknown(err_prefix, " matched input ", in_index,
126 " by ShapeHandle, but was expected to not match an input ",
127 "shape by handle", err_suffix);
128 }
129 if (expected == "?") {
130 if (c.RankKnown(out)) {
131 return Unknown(err_prefix, " expected to be unknown", err_suffix);
132 }
133 continue;
134 }
135
136 // Verify the dimensions.
137 CHECK(absl::StartsWith(expected, "[") && str_util::EndsWith(expected, "]"))
138 << expected;
139 expected.remove_prefix(1);
140 expected.remove_suffix(1);
141
142 // Split expected as a dimension.
143 auto expected_dims = str_util::Split(expected, ',');
144 if (!c.RankKnown(out)) {
145 return Unknown(err_prefix, " expected rank ", expected_dims.size(),
146 " but was ?", err_suffix);
147 }
148 if (c.Rank(out) != expected_dims.size()) {
149 return Unknown(err_prefix, " expected rank ", expected_dims.size(),
150 " but was ", c.Rank(out), err_suffix);
151 }
152 for (int j = 0; j < expected_dims.size(); ++j) {
153 err_prefix = strings::StrCat("Output dim ", i, ",", j);
154 StringPiece expected_dim(expected_dims[j]);
155 DimensionHandle out_dim = c.Dim(out, j);
156
157 std::pair<int, int> in_dim_idx(-1, -1);
158 for (int i = 0; i < c.num_inputs(); ++i) {
159 auto in = c.input(i);
160 for (int j = 0; j < c.Rank(in); ++j) {
161 if (c.Dim(in, j).SameHandle(out_dim)) {
162 in_dim_idx = std::make_pair(i, j);
163 }
164 }
165 }
166
167 if (expected_dim == "?") {
168 if (in_dim_idx.first != -1) {
169 return Unknown(err_prefix,
170 " expected to be an unknown but matched input d",
171 in_dim_idx.first, "_", in_dim_idx.second,
172 ". The shape function passed through ",
173 "a DimensionHandle from an input instead of making ",
174 "a new unknown dimension", err_suffix);
175 } else if (c.ValueKnown(out_dim)) {
176 return Unknown(err_prefix, " expected to be unknown but was ",
177 c.Value(out_dim), err_suffix);
178 }
179 } else if (absl::StartsWith(expected_dim, "d")) {
180 // Compare the dimension values.
181 auto v = str_util::Split(expected_dim, '|');
182 if (in_dim_idx.first == -1) {
183 return Unknown(
184 err_prefix, " was expected to match the dimension of an input, ",
185 "but did not match any input dimension. The shape ",
186 "function was expected to pass through a ",
187 "DimensionHandle for an input, but did not", err_suffix);
188 }
189 if (std::find(v.begin(), v.end(),
190 strings::StrCat("d", in_dim_idx.first, "_",
191 in_dim_idx.second)) == v.end()) {
192 return Unknown(err_prefix, " matched input d", in_dim_idx.first, "_",
193 in_dim_idx.second,
194 ", but should have matched one of (", expected_dim,
195 "). The shape function passed through "
196 "the DimensionHandle for an input, but ",
197 "was expected to pass a different one", err_suffix);
198 }
199 } else {
200 // Parse it as a value.
201 int64 value = -1;
202 if (!strings::safe_strto64(expected_dim, &value)) {
203 return Unknown(err_prefix, ": the expected dimension value '",
204 expected_dim, "' failed to parse as int64",
205 err_suffix);
206 }
207 if (in_dim_idx.first != -1) {
208 return Unknown( //
209 err_prefix, " expected to be ", value, " but matched input d",
210 in_dim_idx.first, "_", in_dim_idx.second,
211 ". The shape function was not expected to pass a DimensionHandle "
212 "from the input to the output, but did. Note that even if the "
213 "passed through output has the same dimension value as the "
214 "expected value, this is considered a failure for the test; "
215 "switch to using d#_# syntax if passing through the "
216 "DimensionHandle should be the expected behavior",
217 err_suffix);
218 } else if (value != c.Value(out_dim)) {
219 return Unknown(err_prefix, " expected to be ", value, " but was ",
220 c.DebugString(out_dim), err_suffix);
221 }
222 }
223 }
224 }
225 return Status::OK();
226 }
227
228 // static
MakeShapeFromString(InferenceContext::ShapeManager * manager,const string & spec,ShapeHandle * output)229 Status ShapeInferenceTestutil::MakeShapeFromString(
230 InferenceContext::ShapeManager* manager, const string& spec,
231 ShapeHandle* output) {
232 if (spec == "?") {
233 *output = manager->UnknownShape();
234 return Status::OK();
235 }
236
237 std::vector<DimensionHandle> dims;
238 strings::Scanner scanner(spec);
239 scanner.OneLiteral("[");
240 while (scanner.Peek() != ']') {
241 if (scanner.Peek() == '?') {
242 scanner.OneLiteral("?");
243 dims.push_back(manager->MakeDim(InferenceContext::kUnknownDim));
244 } else {
245 scanner.RestartCapture().Many(strings::Scanner::DIGIT);
246 StringPiece match;
247 int64 dim_size = 0;
248
249 if (!scanner.GetResult(nullptr, &match) ||
250 !strings::safe_strto64(match, &dim_size)) {
251 return errors::InvalidArgument("Could not parse number in ", spec);
252 }
253
254 dims.push_back(manager->MakeDim(dim_size));
255 }
256
257 if (scanner.Peek() == ',') {
258 scanner.OneLiteral(",");
259 } else if (scanner.Peek() != ']') {
260 return errors::InvalidArgument(
261 "Invalid input spec (] not found in dim shape): ", spec);
262 }
263 }
264 if (!scanner.OneLiteral("]").Eos().GetResult()) {
265 return errors::InvalidArgument("Malformed shape spec: did not end in ']'.");
266 }
267 *output = manager->MakeShape(dims);
268
269 return Status::OK();
270 }
271
272 } // namespace shape_inference
273 } // namespace tensorflow
274