• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference_testutil.h"
16 
17 #include "tensorflow/core/framework/node_def_util.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/lib/gtl/map_util.h"
20 #include "tensorflow/core/lib/strings/numbers.h"
21 #include "tensorflow/core/lib/strings/scanner.h"
22 #include "tensorflow/core/lib/strings/str_util.h"
23 
24 namespace tensorflow {
25 namespace shape_inference {
26 
27 using errors::Unknown;
28 
InferShapes(ShapeInferenceTestOp op,const string & ins,const string & expected_outs)29 Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
30                                            const string& ins,
31                                            const string& expected_outs) {
32   const OpRegistrationData* op_reg_data;
33   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
34 
35   std::vector<string> ins_v = str_util::Split(ins, ';');
36   std::unique_ptr<const NodeDef> new_node_def;
37 
38   InferenceContext::ShapeManager manager;
39   std::vector<ShapeHandle> in_shapes;
40   for (const string& spec : ins_v) {
41     ShapeHandle shape;
42     TF_RETURN_IF_ERROR(MakeShapeFromString(&manager, spec, &shape));
43     in_shapes.push_back(shape);
44   }
45 
46   std::vector<std::unique_ptr<std::vector<shape_inference::ShapeAndType>>>
47       input_resource_handle_shapes_and_types;
48   for (const auto p : op.input_resource_handle_shapes_and_types) {
49     if (p == nullptr) {
50       input_resource_handle_shapes_and_types.push_back(nullptr);
51     } else {
52       std::unique_ptr<std::vector<ShapeAndType>> v(
53           new std::vector<ShapeAndType>());
54       for (const auto& shape_and_type : *p) {
55         ShapeHandle shape;
56         TF_RETURN_IF_ERROR(
57             MakeShapeFromString(&manager, shape_and_type.first, &shape));
58         v->emplace_back(shape, shape_and_type.second);
59       }
60       input_resource_handle_shapes_and_types.emplace_back(v.release());
61     }
62   }
63   shape_inference::InferenceContext c(
64       op.graph_def_version, &op.node_def, op_reg_data->op_def, in_shapes,
65       op.input_tensors, {}, std::move(input_resource_handle_shapes_and_types));
66   TF_RETURN_IF_ERROR(c.construction_status());
67   if (op_reg_data->shape_inference_fn == nullptr) {
68     return errors::InvalidArgument(
69         "No shape inference function exists for op '", op.name,
70         "', did you forget to define it?");
71   }
72 
73   TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
74 
75   const int num_outputs = c.num_outputs();
76 
77   if (expected_outs == "e") {
78     return Unknown("Shape inference should have returned error");
79   }
80 
81   // Verify the output shape.
82   std::vector<string> expected_outs_v = str_util::Split(expected_outs, ';');
83   if (num_outputs != expected_outs_v.size()) {
84     return Unknown("The expected output string lists the wrong number of ",
85                    "outputs. It lists ", expected_outs_v.size(),
86                    " but should list ", num_outputs);
87   }
88   for (int i = 0; i < num_outputs; ++i) {
89     StringPiece expected(expected_outs_v[i]);
90     shape_inference::ShapeHandle out = c.output(i);
91 
92     string err_prefix = strings::StrCat("Output ", i);
93     string err_suffix =
94         strings::StrCat(". Output shape was ", c.DebugString(out));
95 
96     int in_index = -1;
97     for (int i = 0; i < c.num_inputs(); ++i) {
98       if (c.input(i).SameHandle(out)) {
99         in_index = i;
100       }
101     }
102 
103     if (str_util::StartsWith(expected, "in")) {
104       if (in_index == -1) {
105         return Unknown(err_prefix,
106                        " should have matched an input shape by "
107                        "handle, but matched no input shape. This means the ",
108                        "shape function was expected to pass an input "
109                        "ShapeHandle through for this output, but did not",
110                        err_suffix);
111       }
112       auto v = str_util::Split(expected, '|');
113       if (std::find(v.begin(), v.end(), strings::StrCat("in", in_index)) ==
114           v.end()) {
115         return Unknown(
116             err_prefix, " matched input ", in_index,
117             " by handle, but should have matched one of (", expected,
118             ") instead. This means the shape function passed the ShapeHandle ",
119             "for input ", in_index,
120             " to the output, but should have passed a different input ",
121             "ShapeHandle through", err_suffix);
122       }
123       continue;
124     }
125     if (in_index != -1) {
126       return Unknown(err_prefix, " matched input ", in_index,
127                      " by ShapeHandle, but was expected to not match an input ",
128                      "shape by handle", err_suffix);
129     }
130     if (expected == "?") {
131       if (c.RankKnown(out)) {
132         return Unknown(err_prefix, " expected to be unknown", err_suffix);
133       }
134       continue;
135     }
136 
137     // Verify the dimensions.
138     CHECK(str_util::StartsWith(expected, "[") &&
139           str_util::EndsWith(expected, "]"))
140         << expected;
141     expected.remove_prefix(1);
142     expected.remove_suffix(1);
143 
144     // Split expected as a dimension.
145     auto expected_dims = str_util::Split(expected, ',');
146     if (!c.RankKnown(out)) {
147       return Unknown(err_prefix, " expected rank ", expected_dims.size(),
148                      " but was ?", err_suffix);
149     }
150     if (c.Rank(out) != expected_dims.size()) {
151       return Unknown(err_prefix, " expected rank ", expected_dims.size(),
152                      " but was ", c.Rank(out), err_suffix);
153     }
154     for (int j = 0; j < expected_dims.size(); ++j) {
155       err_prefix = strings::StrCat("Output dim ", i, ",", j);
156       StringPiece expected_dim(expected_dims[j]);
157       DimensionHandle out_dim = c.Dim(out, j);
158 
159       std::pair<int, int> in_dim_idx(-1, -1);
160       for (int i = 0; i < c.num_inputs(); ++i) {
161         auto in = c.input(i);
162         for (int j = 0; j < c.Rank(in); ++j) {
163           if (c.Dim(in, j).SameHandle(out_dim)) {
164             in_dim_idx = std::make_pair(i, j);
165           }
166         }
167       }
168 
169       if (expected_dim == "?") {
170         if (in_dim_idx.first != -1) {
171           return Unknown(err_prefix,
172                          " expected to be an unknown but matched input d",
173                          in_dim_idx.first, "_", in_dim_idx.second,
174                          ". The shape function passed through ",
175                          "a DimensionHandle from an input instead of making ",
176                          "a new unknown dimension", err_suffix);
177         } else if (c.ValueKnown(out_dim)) {
178           return Unknown(err_prefix, " expected to be unknown but was ",
179                          c.Value(out_dim), err_suffix);
180         }
181       } else if (str_util::StartsWith(expected_dim, "d")) {
182         // Compare the dimension values.
183         auto v = str_util::Split(expected_dim, '|');
184         if (in_dim_idx.first == -1) {
185           return Unknown(
186               err_prefix, " was expected to match the dimension of an input, ",
187               "but did not match any input dimension. The shape ",
188               "function was expected to pass through a ",
189               "DimensionHandle for an input, but did not", err_suffix);
190         }
191         if (std::find(v.begin(), v.end(),
192                       strings::StrCat("d", in_dim_idx.first, "_",
193                                       in_dim_idx.second)) == v.end()) {
194           return Unknown(err_prefix, " matched input d", in_dim_idx.first, "_",
195                          in_dim_idx.second,
196                          ", but should have matched one of (", expected_dim,
197                          "). The shape function passed through "
198                          "the DimensionHandle for an input, but ",
199                          "was expected to pass a different one", err_suffix);
200         }
201       } else {
202         // Parse it as a value.
203         int64 value = -1;
204         if (!strings::safe_strto64(expected_dim, &value)) {
205           return Unknown(err_prefix, ": the expected dimension value '",
206                          expected_dim, "' failed to parse as int64",
207                          err_suffix);
208         }
209         if (in_dim_idx.first != -1) {
210           return Unknown(  //
211               err_prefix, " expected to be ", value, " but matched input d",
212               in_dim_idx.first, "_", in_dim_idx.second,
213               ". The shape function was not expected to pass a DimensionHandle "
214               "from the input to the output, but did. Note that even if the "
215               "passed through output has the same dimension value as the "
216               "expected value, this is considered a failure for the test; "
217               "switch to using d#_# syntax if passing through the "
218               "DimensionHandle should be the expected behavior",
219               err_suffix);
220         } else if (value != c.Value(out_dim)) {
221           return Unknown(err_prefix, " expected to be ", value, " but was ",
222                          c.DebugString(out_dim), err_suffix);
223         }
224       }
225     }
226   }
227   return Status::OK();
228 }
229 
230 // static
MakeShapeFromString(InferenceContext::ShapeManager * manager,const string & spec,ShapeHandle * output)231 Status ShapeInferenceTestutil::MakeShapeFromString(
232     InferenceContext::ShapeManager* manager, const string& spec,
233     ShapeHandle* output) {
234   if (spec == "?") {
235     *output = manager->UnknownShape();
236     return Status::OK();
237   }
238 
239   std::vector<DimensionHandle> dims;
240   strings::Scanner scanner(spec);
241   scanner.OneLiteral("[");
242   while (scanner.Peek() != ']') {
243     if (scanner.Peek() == '?') {
244       scanner.OneLiteral("?");
245       dims.push_back(manager->MakeDim(InferenceContext::kUnknownDim));
246     } else {
247       scanner.RestartCapture().Many(strings::Scanner::DIGIT);
248       StringPiece match;
249       int64 dim_size = 0;
250 
251       if (!scanner.GetResult(nullptr, &match) ||
252           !strings::safe_strto64(match, &dim_size)) {
253         return errors::InvalidArgument("Could not parse number in ", spec);
254       }
255 
256       dims.push_back(manager->MakeDim(dim_size));
257     }
258 
259     if (scanner.Peek() == ',') {
260       scanner.OneLiteral(",");
261     } else if (scanner.Peek() != ']') {
262       return errors::InvalidArgument(
263           "Invalid input spec (] not found in dim shape): ", spec);
264     }
265   }
266   if (!scanner.OneLiteral("]").Eos().GetResult()) {
267     return errors::InvalidArgument("Malformed shape spec: did not end in ']'.");
268   }
269   *output = manager->MakeShape(dims);
270 
271   return Status::OK();
272 }
273 
274 }  // namespace shape_inference
275 }  // namespace tensorflow
276