• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference_testutil.h"
16 
17 #include "tensorflow/core/framework/node_def_util.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/lib/gtl/map_util.h"
20 #include "tensorflow/core/lib/strings/numbers.h"
21 #include "tensorflow/core/lib/strings/scanner.h"
22 #include "tensorflow/core/lib/strings/str_util.h"
23 
24 namespace tensorflow {
25 namespace shape_inference {
26 
27 using errors::Unknown;
28 
InferShapes(ShapeInferenceTestOp op,const string & ins,const string & expected_outs)29 Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
30                                            const string& ins,
31                                            const string& expected_outs) {
32   const OpRegistrationData* op_reg_data;
33   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
34 
35   std::vector<string> ins_v = str_util::Split(ins, ';');
36 
37   InferenceContext::ShapeManager manager;
38   std::vector<ShapeHandle> in_shapes;
39   for (const string& spec : ins_v) {
40     ShapeHandle shape;
41     TF_RETURN_IF_ERROR(MakeShapeFromString(&manager, spec, &shape));
42     in_shapes.push_back(shape);
43   }
44 
45   std::vector<std::unique_ptr<std::vector<shape_inference::ShapeAndType>>>
46       input_resource_handle_shapes_and_types;
47   for (const auto p : op.input_resource_handle_shapes_and_types) {
48     if (p == nullptr) {
49       input_resource_handle_shapes_and_types.push_back(nullptr);
50     } else {
51       std::unique_ptr<std::vector<ShapeAndType>> v(
52           new std::vector<ShapeAndType>());
53       for (const auto& shape_and_type : *p) {
54         ShapeHandle shape;
55         TF_RETURN_IF_ERROR(
56             MakeShapeFromString(&manager, shape_and_type.first, &shape));
57         v->emplace_back(shape, shape_and_type.second);
58       }
59       input_resource_handle_shapes_and_types.emplace_back(v.release());
60     }
61   }
62   shape_inference::InferenceContext c(
63       op.graph_def_version, op.node_def, op_reg_data->op_def, in_shapes,
64       op.input_tensors, {}, std::move(input_resource_handle_shapes_and_types));
65   TF_RETURN_IF_ERROR(c.construction_status());
66   if (op_reg_data->shape_inference_fn == nullptr) {
67     return errors::InvalidArgument(
68         "No shape inference function exists for op '", op.name,
69         "', did you forget to define it?");
70   }
71 
72   TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
73 
74   const int num_outputs = c.num_outputs();
75 
76   if (expected_outs == "e") {
77     return Unknown("Shape inference should have returned error");
78   }
79 
80   // Verify the output shape.
81   std::vector<string> expected_outs_v = str_util::Split(expected_outs, ';');
82   if (num_outputs != expected_outs_v.size()) {
83     return Unknown("The expected output string lists the wrong number of ",
84                    "outputs. It lists ", expected_outs_v.size(),
85                    " but should list ", num_outputs);
86   }
87   for (int i = 0; i < num_outputs; ++i) {
88     StringPiece expected(expected_outs_v[i]);
89     shape_inference::ShapeHandle out = c.output(i);
90 
91     string err_prefix = strings::StrCat("Output ", i);
92     string err_suffix =
93         strings::StrCat(". Output shape was ", c.DebugString(out));
94 
95     int in_index = -1;
96     for (int i = 0; i < c.num_inputs(); ++i) {
97       if (c.input(i).SameHandle(out)) {
98         in_index = i;
99       }
100     }
101 
102     if (absl::StartsWith(expected, "in")) {
103       if (in_index == -1) {
104         return Unknown(err_prefix,
105                        " should have matched an input shape by "
106                        "handle, but matched no input shape. This means the ",
107                        "shape function was expected to pass an input "
108                        "ShapeHandle through for this output, but did not",
109                        err_suffix);
110       }
111       auto v = str_util::Split(expected, '|');
112       if (std::find(v.begin(), v.end(), strings::StrCat("in", in_index)) ==
113           v.end()) {
114         return Unknown(
115             err_prefix, " matched input ", in_index,
116             " by handle, but should have matched one of (", expected,
117             ") instead. This means the shape function passed the ShapeHandle ",
118             "for input ", in_index,
119             " to the output, but should have passed a different input ",
120             "ShapeHandle through", err_suffix);
121       }
122       continue;
123     }
124     if (in_index != -1) {
125       return Unknown(err_prefix, " matched input ", in_index,
126                      " by ShapeHandle, but was expected to not match an input ",
127                      "shape by handle", err_suffix);
128     }
129     if (expected == "?") {
130       if (c.RankKnown(out)) {
131         return Unknown(err_prefix, " expected to be unknown", err_suffix);
132       }
133       continue;
134     }
135 
136     // Verify the dimensions.
137     CHECK(absl::StartsWith(expected, "[") && str_util::EndsWith(expected, "]"))
138         << expected;
139     expected.remove_prefix(1);
140     expected.remove_suffix(1);
141 
142     // Split expected as a dimension.
143     auto expected_dims = str_util::Split(expected, ',');
144     if (!c.RankKnown(out)) {
145       return Unknown(err_prefix, " expected rank ", expected_dims.size(),
146                      " but was ?", err_suffix);
147     }
148     if (c.Rank(out) != expected_dims.size()) {
149       return Unknown(err_prefix, " expected rank ", expected_dims.size(),
150                      " but was ", c.Rank(out), err_suffix);
151     }
152     for (int j = 0; j < expected_dims.size(); ++j) {
153       err_prefix = strings::StrCat("Output dim ", i, ",", j);
154       StringPiece expected_dim(expected_dims[j]);
155       DimensionHandle out_dim = c.Dim(out, j);
156 
157       std::pair<int, int> in_dim_idx(-1, -1);
158       for (int i = 0; i < c.num_inputs(); ++i) {
159         auto in = c.input(i);
160         for (int j = 0; j < c.Rank(in); ++j) {
161           if (c.Dim(in, j).SameHandle(out_dim)) {
162             in_dim_idx = std::make_pair(i, j);
163           }
164         }
165       }
166 
167       if (expected_dim == "?") {
168         if (in_dim_idx.first != -1) {
169           return Unknown(err_prefix,
170                          " expected to be an unknown but matched input d",
171                          in_dim_idx.first, "_", in_dim_idx.second,
172                          ". The shape function passed through ",
173                          "a DimensionHandle from an input instead of making ",
174                          "a new unknown dimension", err_suffix);
175         } else if (c.ValueKnown(out_dim)) {
176           return Unknown(err_prefix, " expected to be unknown but was ",
177                          c.Value(out_dim), err_suffix);
178         }
179       } else if (absl::StartsWith(expected_dim, "d")) {
180         // Compare the dimension values.
181         auto v = str_util::Split(expected_dim, '|');
182         if (in_dim_idx.first == -1) {
183           return Unknown(
184               err_prefix, " was expected to match the dimension of an input, ",
185               "but did not match any input dimension. The shape ",
186               "function was expected to pass through a ",
187               "DimensionHandle for an input, but did not", err_suffix);
188         }
189         if (std::find(v.begin(), v.end(),
190                       strings::StrCat("d", in_dim_idx.first, "_",
191                                       in_dim_idx.second)) == v.end()) {
192           return Unknown(err_prefix, " matched input d", in_dim_idx.first, "_",
193                          in_dim_idx.second,
194                          ", but should have matched one of (", expected_dim,
195                          "). The shape function passed through "
196                          "the DimensionHandle for an input, but ",
197                          "was expected to pass a different one", err_suffix);
198         }
199       } else {
200         // Parse it as a value.
201         int64 value = -1;
202         if (!strings::safe_strto64(expected_dim, &value)) {
203           return Unknown(err_prefix, ": the expected dimension value '",
204                          expected_dim, "' failed to parse as int64",
205                          err_suffix);
206         }
207         if (in_dim_idx.first != -1) {
208           return Unknown(  //
209               err_prefix, " expected to be ", value, " but matched input d",
210               in_dim_idx.first, "_", in_dim_idx.second,
211               ". The shape function was not expected to pass a DimensionHandle "
212               "from the input to the output, but did. Note that even if the "
213               "passed through output has the same dimension value as the "
214               "expected value, this is considered a failure for the test; "
215               "switch to using d#_# syntax if passing through the "
216               "DimensionHandle should be the expected behavior",
217               err_suffix);
218         } else if (value != c.Value(out_dim)) {
219           return Unknown(err_prefix, " expected to be ", value, " but was ",
220                          c.DebugString(out_dim), err_suffix);
221         }
222       }
223     }
224   }
225   return Status::OK();
226 }
227 
228 // static
MakeShapeFromString(InferenceContext::ShapeManager * manager,const string & spec,ShapeHandle * output)229 Status ShapeInferenceTestutil::MakeShapeFromString(
230     InferenceContext::ShapeManager* manager, const string& spec,
231     ShapeHandle* output) {
232   if (spec == "?") {
233     *output = manager->UnknownShape();
234     return Status::OK();
235   }
236 
237   std::vector<DimensionHandle> dims;
238   strings::Scanner scanner(spec);
239   scanner.OneLiteral("[");
240   while (scanner.Peek() != ']') {
241     if (scanner.Peek() == '?') {
242       scanner.OneLiteral("?");
243       dims.push_back(manager->MakeDim(InferenceContext::kUnknownDim));
244     } else {
245       scanner.RestartCapture().Many(strings::Scanner::DIGIT);
246       StringPiece match;
247       int64 dim_size = 0;
248 
249       if (!scanner.GetResult(nullptr, &match) ||
250           !strings::safe_strto64(match, &dim_size)) {
251         return errors::InvalidArgument("Could not parse number in ", spec);
252       }
253 
254       dims.push_back(manager->MakeDim(dim_size));
255     }
256 
257     if (scanner.Peek() == ',') {
258       scanner.OneLiteral(",");
259     } else if (scanner.Peek() != ']') {
260       return errors::InvalidArgument(
261           "Invalid input spec (] not found in dim shape): ", spec);
262     }
263   }
264   if (!scanner.OneLiteral("]").Eos().GetResult()) {
265     return errors::InvalidArgument("Malformed shape spec: did not end in ']'.");
266   }
267   *output = manager->MakeShape(dims);
268 
269   return Status::OK();
270 }
271 
272 }  // namespace shape_inference
273 }  // namespace tensorflow
274