1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/shape_refiner.h"
17
18 #include "tensorflow/cc/framework/scope.h"
19 #include "tensorflow/cc/ops/resource_variable_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/common_runtime/function_testlib.h"
22 #include "tensorflow/core/framework/common_shape_fns.h"
23 #include "tensorflow/core/framework/function_testlib.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/graph/node_builder.h"
26 #include "tensorflow/core/graph/testlib.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/lib/strings/str_util.h"
30 #include "tensorflow/core/platform/test.h"
31 #include "tensorflow/core/public/version.h"
32
33 namespace tensorflow {
34
35 class ShapeRefinerTest : public ::testing::Test {
36 protected:
37 // These give access to private functions of DimensionHandle and ShapeHandle.
SameHandle(shape_inference::DimensionHandle a,shape_inference::DimensionHandle b)38 bool SameHandle(shape_inference::DimensionHandle a,
39 shape_inference::DimensionHandle b) {
40 return a.SameHandle(b);
41 }
42
SameHandle(shape_inference::ShapeHandle a,shape_inference::ShapeHandle b)43 bool SameHandle(shape_inference::ShapeHandle a,
44 shape_inference::ShapeHandle b) {
45 return a.SameHandle(b);
46 }
47
48 // These give access to private functions of ShapeRefiner.
SameDefinedShape(shape_inference::InferenceContext * c,shape_inference::ShapeHandle s0,shape_inference::ShapeHandle s1)49 bool SameDefinedShape(shape_inference::InferenceContext* c,
50 shape_inference::ShapeHandle s0,
51 shape_inference::ShapeHandle s1) {
52 return ShapeRefiner::SameDefinedShape(c, s0, s1);
53 }
54
IsUpdatedShapesOrTypes(shape_inference::InferenceContext * c,const std::vector<shape_inference::ShapeAndType> & existing,const std::vector<shape_inference::ShapeAndType> & updated)55 bool IsUpdatedShapesOrTypes(
56 shape_inference::InferenceContext* c,
57 const std::vector<shape_inference::ShapeAndType>& existing,
58 const std::vector<shape_inference::ShapeAndType>& updated) {
59 return ShapeRefiner::IsUpdatedShapesOrTypes(c, existing, updated);
60 }
61
62 static constexpr int64_t kMaxTensorSize = ShapeRefiner::kMaxTensorSize;
63
TestStridedSlice(const PartialTensorShape & input_shape,int begin,int end,int stride,const char * expected,int begin_mask=0,int end_mask=0,int ellipsis_mask=0)64 void TestStridedSlice(const PartialTensorShape& input_shape, int begin,
65 int end, int stride, const char* expected,
66 int begin_mask = 0, int end_mask = 0,
67 int ellipsis_mask = 0) {
68 Scope root = Scope::DisabledShapeInferenceScope();
69 auto placeholder =
70 ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
71 auto input = ops::Shape(root, placeholder);
72 auto begin_op = ops::Const(root, {begin});
73 auto end_op = ops::Const(root, {end});
74 auto stride_op = ops::Const(root, {stride});
75 auto slice = ops::StridedSlice(root, input, begin_op, end_op, stride_op,
76 ops::StridedSlice::BeginMask(begin_mask)
77 .EndMask(end_mask)
78 .EllipsisMask(ellipsis_mask));
79 Node* result;
80 TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
81 .Input(slice.node())
82 .Finalize(root.graph(), &result));
83
84 ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
85 TF_ASSERT_OK(m.AddNode(placeholder.node()));
86 TF_ASSERT_OK(m.AddNode(input.node()));
87 TF_ASSERT_OK(m.AddNode(begin_op.node()));
88 TF_ASSERT_OK(m.AddNode(end_op.node()));
89 TF_ASSERT_OK(m.AddNode(stride_op.node()));
90 TF_ASSERT_OK(m.AddNode(slice.node()));
91 TF_ASSERT_OK(m.AddNode(result));
92
93 shape_inference::InferenceContext* ctx = m.GetContext(result);
94 EXPECT_EQ(ctx->DebugString(ctx->output(0)), expected);
95 }
96 };
97
98 namespace {
99
// Expects that output IDX of OP, as recorded by the ShapeRefiner M, has a
// shape whose debug string equals EXPECTED.
//
// M and OP are parenthesized so that expression arguments (for example a
// dereferenced pointer) bind as intended. The trailing semicolon after
// `while (0)` is kept on purpose for compatibility with existing call sites.
#define EXPECT_SHAPE(EXPECTED, M, OP, IDX)                                \
  do {                                                                    \
    shape_inference::InferenceContext* ctx = (M).GetContext((OP).node()); \
    EXPECT_EQ(EXPECTED, ctx->DebugString(ctx->output(IDX)));              \
  } while (0);
105
// Expects that output IDX of OP, as recorded by the ShapeRefiner M, carries
// exactly one handle shape-and-type entry and that the debug string of that
// entry's shape equals EXPECTED.
//
// The entry is only inspected when the handle data is present and non-empty,
// so a missing entry is reported as a test failure instead of dereferencing
// a null pointer or indexing an empty vector. M and OP are parenthesized so
// that expression arguments bind as intended. The trailing semicolon after
// `while (0)` is kept on purpose for compatibility with existing call sites.
#define EXPECT_RESOURCE_SINGLE_SHAPE(EXPECTED, M, OP, IDX)                \
  do {                                                                    \
    shape_inference::InferenceContext* ctx = (M).GetContext((OP).node()); \
    auto* v = ctx->output_handle_shapes_and_types(IDX);                   \
    EXPECT_NE(v, nullptr);                                                \
    if (v != nullptr) {                                                   \
      EXPECT_EQ(v->size(), 1);                                            \
      if (!v->empty()) {                                                  \
        EXPECT_EQ(EXPECTED, ctx->DebugString((*v)[0].shape));             \
      }                                                                   \
    }                                                                     \
  } while (0);
114
// Expects that output IDX of OP, as recorded by the ShapeRefiner M, carries
// exactly one handle shape-and-type entry and that the entry's dtype equals
// EXPECTED.
//
// The entry is only inspected when the handle data is present and non-empty,
// so a missing entry is reported as a test failure instead of dereferencing
// a null pointer or indexing an empty vector. M and OP are parenthesized so
// that expression arguments bind as intended. The trailing semicolon after
// `while (0)` is kept on purpose for compatibility with existing call sites.
#define EXPECT_RESOURCE_SINGLE_TYPE(EXPECTED, M, OP, IDX)                 \
  do {                                                                    \
    shape_inference::InferenceContext* ctx = (M).GetContext((OP).node()); \
    auto* v = ctx->output_handle_shapes_and_types(IDX);                   \
    EXPECT_NE(v, nullptr);                                                \
    if (v != nullptr) {                                                   \
      EXPECT_EQ(v->size(), 1);                                            \
      if (!v->empty()) {                                                  \
        EXPECT_EQ(EXPECTED, (*v)[0].dtype);                               \
      }                                                                   \
    }                                                                     \
  } while (0);
123
// A single scalar float constant must be accepted by the refiner and be
// reported with a rank-0 shape.
TEST_F(ShapeRefinerTest, Constant) {
  Scope scope = Scope::NewRootScope();
  auto scalar = ops::Const(scope, 42.0f);

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(scalar.node()));

  EXPECT_SHAPE("[]", refiner, scalar, 0);
}
134
// Multiplying a 2x1 constant by a 1x2 constant must infer a 2x2 result, and
// the operand shapes must be recorded unchanged.
TEST_F(ShapeRefinerTest, MatMul) {
  Scope scope = Scope::NewRootScope();
  auto lhs = ops::Const(scope, {{1.0f}, {2.0f}});
  auto rhs = ops::Const(scope, {{1.0f, 2.0f}});
  auto product = ops::MatMul(scope, lhs, rhs);

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(lhs.node()));
  TF_ASSERT_OK(refiner.AddNode(rhs.node()));
  TF_ASSERT_OK(refiner.AddNode(product.node()));

  EXPECT_SHAPE("[2,1]", refiner, lhs, 0);
  EXPECT_SHAPE("[1,2]", refiner, rhs, 0);
  EXPECT_SHAPE("[2,2]", refiner, product, 0);
}
151
// Two 2x1 operands cannot be matrix-multiplied; adding the MatMul node must
// fail with a message naming the mismatching dimensions.
TEST_F(ShapeRefinerTest, BadShapes) {
  Scope scope = Scope::NewRootScope();
  auto lhs = ops::Const(scope, {{1.0f}, {2.0f}});
  auto rhs = ops::Const(scope, {{1.0f}, {2.0f}});
  auto product = ops::MatMul(scope, lhs, rhs);

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(lhs.node()));
  TF_ASSERT_OK(refiner.AddNode(rhs.node()));

  // Inner dimensions are 1 and 2, so shape inference has to reject the node.
  const Status status = refiner.AddNode(product.node());
  ASSERT_FALSE(status.ok());
  ASSERT_TRUE(absl::StrContains(status.error_message(),
                                "Dimensions must be equal, but are 1 and 2"));
}
168
// Exercises ShapeRefiner::SetShape: merging with the existing shape, and the
// error cases (bad output index, unknown node, incompatible shape).
TEST_F(ShapeRefinerTest, SetShape) {
  Scope scope = Scope::NewRootScope();
  auto placeholder = ops::Placeholder(scope, DT_FLOAT);

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(placeholder.node()));

  auto inference = refiner.GetContext(placeholder.node());
  ASSERT_NE(nullptr, inference);

  // First refinement: [2,?].
  shape_inference::ShapeHandle rows_known =
      inference->MakeShape({2, inference->UnknownDim()});
  TF_ASSERT_OK(refiner.SetShape(placeholder.node(), 0, rows_known));
  EXPECT_SHAPE("[2,?]", refiner, placeholder, 0);

  // Second refinement [?,2] is merged with the shape already stored.
  shape_inference::ShapeHandle cols_known =
      inference->MakeShape({inference->UnknownDim(), 2});
  TF_ASSERT_OK(refiner.SetShape(placeholder.node(), 0, cols_known));
  EXPECT_SHAPE("[2,2]", refiner, placeholder, 0);

  // Output indices outside the node's outputs are rejected.
  ASSERT_FALSE(refiner.SetShape(placeholder.node(), 1, rows_known).ok());
  ASSERT_FALSE(refiner.SetShape(placeholder.node(), -1, rows_known).ok());

  // A node that was never added to the refiner is rejected.
  auto unregistered = ops::Const(scope, {{1.0f}, {2.0f}});
  ASSERT_FALSE(refiner.SetShape(unregistered.node(), 0, rows_known).ok());

  // A leading dimension of 3 conflicts with the stored 2.
  shape_inference::ShapeHandle conflicting =
      inference->MakeShape({3, inference->UnknownDim()});
  ASSERT_FALSE(refiner.SetShape(placeholder.node(), 0, conflicting).ok());
}
200
201 namespace {
202
// An op with no shape function: one int32 input "a" and one int32 output "o",
// registered without a SetShapeFn call. The MissingShapeInferenceFns test
// builds a node of this op.
REGISTER_OP("TestOpWithNoShapeFn").Input("a: int32").Output("o: int32");
205
206 } // namespace
207
// A node whose op has no shape function is rejected until
// set_require_shape_inference_fns(false) is called on the refiner.
TEST_F(ShapeRefinerTest, MissingShapeInferenceFns) {
  Scope scope = Scope::NewRootScope();
  auto producer = ops::Const(scope, 42);
  Node* no_shape_fn_node = nullptr;
  TF_ASSERT_OK(NodeBuilder("b", "TestOpWithNoShapeFn")
                   .Input(producer.node())
                   .Finalize(scope.graph(), &no_shape_fn_node));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(producer.node()));

  // Rejected before the requirement is relaxed ...
  EXPECT_FALSE(refiner.AddNode(no_shape_fn_node).ok());

  // ... and accepted afterwards.
  refiner.set_require_shape_inference_fns(false);
  TF_EXPECT_OK(refiner.AddNode(no_shape_fn_node));
}
221
// ArgMax over a 3x2 constant: the precision of the inferred output shape
// depends on whether the reduction dimension is available as a constant.
TEST_F(ShapeRefinerTest, PropagateConstants) {
  // Case 1: the dimension comes from a variable, so only the output rank is
  // known.
  {
    Scope scope = Scope::NewRootScope();
    auto matrix = ops::Const(scope, {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
    auto axis = ops::Variable(scope, {}, DT_INT32);
    auto arg_max = ops::ArgMax(scope, matrix, axis);

    ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
    TF_ASSERT_OK(refiner.AddNode(matrix.node()));
    TF_ASSERT_OK(refiner.AddNode(axis.node()));
    TF_ASSERT_OK(refiner.AddNode(arg_max.node()));
    EXPECT_SHAPE("[?]", refiner, arg_max, 0);
  }

  // Case 2: constant dimension 1 (the second one) can be materialized, which
  // leaves the first dimension of size 3.
  {
    Scope scope = Scope::NewRootScope();
    auto matrix = ops::Const(scope, {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
    auto axis = ops::Const(scope, 1);
    auto arg_max = ops::ArgMax(scope, matrix, axis);

    ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
    TF_ASSERT_OK(refiner.AddNode(matrix.node()));
    TF_ASSERT_OK(refiner.AddNode(axis.node()));
    TF_ASSERT_OK(refiner.AddNode(arg_max.node()));
    EXPECT_SHAPE("[3]", refiner, arg_max, 0);
  }

  // Case 3: constant dimension 0 (the first one) leaves the second dimension
  // of size 2.
  {
    Scope scope = Scope::NewRootScope();
    auto matrix = ops::Const(scope, {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
    auto axis = ops::Const(scope, 0);
    auto arg_max = ops::ArgMax(scope, matrix, axis);

    ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
    TF_ASSERT_OK(refiner.AddNode(matrix.node()));
    TF_ASSERT_OK(refiner.AddNode(axis.node()));
    TF_ASSERT_OK(refiner.AddNode(arg_max.node()));
    EXPECT_SHAPE("[2]", refiner, arg_max, 0);
  }
}
273
// A MultiIdentity node forwards two constants: a small one and one with
// 2 * kMaxTensorSize elements. Per the original note, one output is small
// enough for its constant value to be cached and the other is not.
//
// ShapeVectorForAllElements nodes are attached so that a shape function
// calls input_tensor on the values under test.
TEST_F(ShapeRefinerTest, ExtractConstantSubgraphMultiOutput) {
  Scope scope = Scope::NewRootScope();
  auto small_const =
      ops::Const(scope, {static_cast<int32>(1), TensorShape({1, 1})});
  auto large_const = ops::Const(
      scope, {static_cast<int32>(2), TensorShape({4, kMaxTensorSize / 2})});

  Node* fan_out = nullptr;
  TF_ASSERT_OK(NodeBuilder("MI", "MultiIdentity")
                   .Input(std::vector<NodeBuilder::NodeOut>{
                       small_const.node(), large_const.node()})
                   .Attr("N", 2)
                   .Finalize(scope.graph(), &fan_out));

  // Consumer of the small output only.
  Node* small_consumer = nullptr;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeVectorForAllElements")
                   .Input(fan_out, 0)
                   .Finalize(scope.graph(), &small_consumer));

  // Consumer of the sum of both outputs.
  auto sum = ops::Add(scope, Output(fan_out, 0), Output(fan_out, 1));
  Node* sum_consumer = nullptr;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeVectorForAllElements")
                   .Input(sum.node())
                   .Finalize(scope.graph(), &sum_consumer));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(small_const.node()));
  TF_ASSERT_OK(refiner.AddNode(large_const.node()));
  TF_ASSERT_OK(refiner.AddNode(fan_out));
  TF_ASSERT_OK(refiner.AddNode(small_consumer));
  TF_ASSERT_OK(refiner.AddNode(sum.node()));
  TF_ASSERT_OK(refiner.AddNode(sum_consumer));

  // The sum holds 1 + 2 = 3 in each of its kMaxTensorSize * 2 elements, and
  // ShapeVectorForAllElements yields a vector whose length is the total of
  // all elements.
  shape_inference::InferenceContext* ic = refiner.GetContext(sum_consumer);
  EXPECT_EQ(strings::StrCat("[", kMaxTensorSize * 2 * 3, "]"),
            ic->DebugString(ic->output(0)));
}
318
319 namespace {
320
321 // An op with a shape function whose outputs depend in a complex
322 // way on whether input tensors are available.
323 REGISTER_OP("TestOp")
324 .Input("a: float")
325 .Input("b: float")
326 .Output("o: float")
__anon684d22820402(shape_inference::InferenceContext* c) 327 .SetShapeFn([](shape_inference::InferenceContext* c) {
328 if (c->input_tensor(0)) {
329 if (c->input_tensor(1)) {
330 c->set_output(0, c->Matrix(10, 10));
331 return Status::OK();
332 }
333 return shape_inference::ScalarShape(c);
334 }
335 return shape_inference::UnknownShape(c);
336 });
337
338 } // namespace
339
// With both inputs of TestOp fed by constants, both input tensors are
// available to its shape function, so the 10x10 branch must be taken.
TEST_F(ShapeRefinerTest, InputTensorDependencies) {
  Graph graph(OpRegistry::Global());

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;

  Node* first_const = test::graph::Constant(&graph, first_value);
  Node* second_const = test::graph::Constant(&graph, second_value);

  Node* test_op = nullptr;
  TF_ASSERT_OK(NodeBuilder("Test", "TestOp")
                   .Input(first_const)
                   .Input(second_const)
                   .Finalize(&graph, &test_op));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(first_const));
  TF_ASSERT_OK(refiner.AddNode(second_const));
  TF_ASSERT_OK(refiner.AddNode(test_op));

  shape_inference::InferenceContext* ic = refiner.GetContext(test_op);
  EXPECT_EQ("[10,10]", ic->DebugString(ic->output(0)));
}
364
365 namespace {
366
367 // An op with a shape function that looks at its input tensor
368 // data and makes a Shape out of it.
369 REGISTER_OP("ShapeData")
370 .Input("a: int32")
371 .Output("o: int32")
__anon684d22820602(shape_inference::InferenceContext* c) 372 .SetShapeFn([](shape_inference::InferenceContext* c) {
373 const Tensor* shape_data = c->input_tensor(0);
374 if (shape_data == nullptr) {
375 return shape_inference::UnknownShape(c);
376 }
377
378 std::vector<shape_inference::DimensionHandle> dims;
379 dims.reserve(shape_data->NumElements());
380 for (int i = 0; i < shape_data->NumElements(); ++i) {
381 dims.emplace_back(c->MakeDim(shape_data->flat<int32>()(i)));
382 }
383
384 c->set_output(0, c->MakeShape(dims));
385 return Status::OK();
386 });
387
388 REGISTER_OP("ShapeDataInt64")
389 .Input("a: int64")
390 .Output("o: int64")
__anon684d22820702(shape_inference::InferenceContext* c) 391 .SetShapeFn([](shape_inference::InferenceContext* c) {
392 const Tensor* shape_data = c->input_tensor(0);
393 if (shape_data == nullptr) {
394 return shape_inference::UnknownShape(c);
395 }
396
397 std::vector<shape_inference::DimensionHandle> dims;
398 dims.reserve(shape_data->NumElements());
399 for (int i = 0; i < shape_data->NumElements(); ++i) {
400 dims.emplace_back(c->MakeDim(shape_data->flat<int64>()(i)));
401 }
402
403 c->set_output(0, c->MakeShape(dims));
404 return Status::OK();
405 });
406
407 // An op with a shape function that looks at its input tensor
408 // data and makes a rank 1 shape out of the sum of all input values.
409 REGISTER_OP("ShapeVectorForAllElements")
410 .Input("a: int32")
411 .Output("o: int32")
__anon684d22820802(shape_inference::InferenceContext* c) 412 .SetShapeFn([](shape_inference::InferenceContext* c) {
413 const Tensor* shape_data = c->input_tensor(0);
414 if (shape_data == nullptr) {
415 return shape_inference::UnknownShape(c);
416 }
417 int64_t total = 0;
418 for (int i = 0; i < shape_data->NumElements(); ++i) {
419 total += shape_data->flat<int32>()(i);
420 }
421
422 c->set_output(0, c->Vector(total));
423 return Status::OK();
424 });
425
426 REGISTER_OP("MultiIdentity")
427 .Input("a: N * int32")
428 .Output("o: N * int32")
429 .Attr("N: int >= 1")
__anon684d22820902(shape_inference::InferenceContext* c) 430 .SetShapeFn([](shape_inference::InferenceContext* c) {
431 for (int i = 0; i < c->num_inputs(); ++i) {
432 c->set_output(i, c->input(i));
433 }
434 return Status::OK();
435 });
436
437 class MultiIdentity : public OpKernel {
438 public:
MultiIdentity(OpKernelConstruction * c)439 explicit MultiIdentity(OpKernelConstruction* c) : OpKernel(c) {}
440
Compute(OpKernelContext * c)441 void Compute(OpKernelContext* c) override {
442 for (int i = 0; i < c->num_inputs(); ++i) {
443 c->set_output(i, c->input(i));
444 }
445 }
446 };
447 REGISTER_KERNEL_BUILDER(Name("MultiIdentity").Device(DEVICE_CPU),
448 MultiIdentity);
449
450 } // namespace
451
// The shape of a 2x4 variable is sliced down to its second entry, and that
// value must reach ShapeData's shape function as a constant.
TEST_F(ShapeRefinerTest, PropagateShapeAcrossTensorContent) {
  Scope scope = Scope::NewRootScope();

  // Variable of shape [2,4]; its Shape op yields the vector (2,4).
  auto variable = ops::Variable(scope, {2, 4}, DT_INT32);
  auto variable_shape = ops::Shape(scope, variable);

  // {1} serves as both begin and size of the slice, selecting the 4.
  auto one = ops::Const(scope, {1});
  auto second_dim = ops::Slice(scope, variable_shape, one, one);

  Node* shape_data = nullptr;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
                   .Input(second_dim.node())
                   .Finalize(scope.graph(), &shape_data));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(one.node()));
  TF_ASSERT_OK(refiner.AddNode(variable.node()));
  TF_ASSERT_OK(refiner.AddNode(variable_shape.node()));
  TF_ASSERT_OK(refiner.AddNode(second_dim.node()));
  TF_ASSERT_OK(refiner.AddNode(shape_data));

  shape_inference::InferenceContext* ic = refiner.GetContext(shape_data);
  EXPECT_EQ("[4]", ic->DebugString(ic->output(0)));
}
482
// Same as PropagateShapeAcrossTensorContent, but the variable has a third
// dimension of twice the int32 maximum and the shape is produced as int64.
TEST_F(ShapeRefinerTest, PropagateShapeAcrossTensorContentInt64) {
  Scope scope = Scope::NewRootScope();

  // Variable of shape [2, 4, 2 * int32 max].
  auto variable = ops::Variable(
      scope, {2, 4, static_cast<int64>(std::numeric_limits<int32>::max()) * 2},
      DT_INT64);

  // Shape op with an int64 output type: a vector of 3 elements.
  auto shape_attrs = ops::Shape::OutType(DT_INT64);
  auto variable_shape = ops::Shape(scope, variable, shape_attrs);

  // {1} serves as both begin and size of the slice, selecting the 4.
  auto one = ops::Const(scope, {1});
  auto second_dim = ops::Slice(scope, variable_shape, one, one);

  Node* shape_data = nullptr;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeDataInt64")
                   .Input(second_dim.node())
                   .Finalize(scope.graph(), &shape_data));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(one.node()));
  TF_ASSERT_OK(refiner.AddNode(variable.node()));
  TF_ASSERT_OK(refiner.AddNode(variable_shape.node()));
  TF_ASSERT_OK(refiner.AddNode(second_dim.node()));
  TF_ASSERT_OK(refiner.AddNode(shape_data));

  shape_inference::InferenceContext* ic = refiner.GetContext(shape_data);
  EXPECT_EQ("[4]", ic->DebugString(ic->output(0)));
}
516
// The variable has a dimension of twice the int32 maximum, but its shape is
// taken with the default Shape op attributes; adding the consumer is
// expected to fail because of the overflow.
TEST_F(ShapeRefinerTest, PropagateShapeAcrossTensorContentInt32Overflow) {
  Scope scope = Scope::NewRootScope();

  // Variable of shape [2, 4, 2 * int32 max].
  auto variable = ops::Variable(
      scope, {2, 4, static_cast<int64>(std::numeric_limits<int32>::max()) * 2},
      DT_INT32);

  // Shape op with default attributes (no int64 output type requested).
  auto variable_shape = ops::Shape(scope, variable);

  // {1} serves as both begin and size of the slice, selecting the 4.
  auto one = ops::Const(scope, {1});
  auto second_dim = ops::Slice(scope, variable_shape, one, one);

  Node* shape_data = nullptr;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
                   .Input(second_dim.node())
                   .Finalize(scope.graph(), &shape_data));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(one.node()));
  TF_ASSERT_OK(refiner.AddNode(variable.node()));
  TF_ASSERT_OK(refiner.AddNode(variable_shape.node()));
  TF_ASSERT_OK(refiner.AddNode(second_dim.node()));

  // Expect an error since there's an overflow.
  EXPECT_FALSE(refiner.AddNode(shape_data).ok());
}
548
// The rank (3) of a 2x4x3 variable must reach ShapeData as a constant even
// when an Identity node sits in between.
TEST_F(ShapeRefinerTest, PropagateRankAcrossTensorContent) {
  Scope scope = Scope::NewRootScope();

  auto variable = ops::Variable(scope, {2, 4, 3}, DT_INT32);
  auto variable_rank = ops::Rank(scope, variable);
  auto passthrough = ops::Identity(scope, variable_rank);

  Node* shape_data = nullptr;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
                   .Input(passthrough.node())
                   .Finalize(scope.graph(), &shape_data));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(variable.node()));
  TF_ASSERT_OK(refiner.AddNode(variable_rank.node()));
  TF_ASSERT_OK(refiner.AddNode(passthrough.node()));
  TF_ASSERT_OK(refiner.AddNode(shape_data));

  shape_inference::InferenceContext* ic = refiner.GetContext(shape_data);
  EXPECT_EQ("[3]", ic->DebugString(ic->output(0)));
}
574
// The element count (1*2*3*4*5 = 120) of a variable must reach ShapeData as
// a constant even when an Identity node sits in between.
TEST_F(ShapeRefinerTest, PropagateSizeAcrossTensorContent) {
  Scope scope = Scope::NewRootScope();

  auto variable = ops::Variable(scope, {1, 2, 3, 4, 5}, DT_INT32);
  auto variable_size = ops::Size(scope, variable);
  auto passthrough = ops::Identity(scope, variable_size);

  Node* shape_data = nullptr;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
                   .Input(passthrough.node())
                   .Finalize(scope.graph(), &shape_data));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(variable.node()));
  TF_ASSERT_OK(refiner.AddNode(variable_size.node()));
  TF_ASSERT_OK(refiner.AddNode(passthrough.node()));
  TF_ASSERT_OK(refiner.AddNode(shape_data));

  shape_inference::InferenceContext* ic = refiner.GetContext(shape_data);
  EXPECT_EQ("[120]", ic->DebugString(ic->output(0)));
}
600
// Same as PropagateSizeAcrossTensorContent, but with an extra dimension of
// twice the int32 maximum and a Size op that produces int64.
TEST_F(ShapeRefinerTest, PropagateSizeAcrossTensorContentInt64) {
  Scope scope = Scope::NewRootScope();

  // Variable of shape [1, 2, 3, 4, 5, 2 * int32 max].
  auto variable =
      ops::Variable(scope,
                    {1, 2, 3, 4, 5,
                     static_cast<int64>(std::numeric_limits<int32>::max()) * 2},
                    DT_INT64);

  // Element count is 5! * (2 * int32 max), reported as int64.
  auto size_attrs = ops::Size::OutType(DT_INT64);
  auto variable_size = ops::Size(scope, variable, size_attrs);
  auto passthrough = ops::Identity(scope, variable_size);

  Node* shape_data = nullptr;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeDataInt64")
                   .Input(passthrough.node())
                   .Finalize(scope.graph(), &shape_data));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(variable.node()));
  TF_ASSERT_OK(refiner.AddNode(variable_size.node()));
  TF_ASSERT_OK(refiner.AddNode(passthrough.node()));
  TF_ASSERT_OK(refiner.AddNode(shape_data));

  shape_inference::InferenceContext* ic = refiner.GetContext(shape_data);
  EXPECT_EQ("[515396075280]", ic->DebugString(ic->output(0)));
}
631
// The variable's element count is 5! * (2 * int32 max), but the Size op uses
// its default attributes; adding the consumer is expected to fail.
TEST_F(ShapeRefinerTest, PropagateSizeAcrossTensorContentInt32Overflow) {
  Scope scope = Scope::NewRootScope();

  // Variable of shape [1, 2, 3, 4, 5, 2 * int32 max].
  auto variable =
      ops::Variable(scope,
                    {1, 2, 3, 4, 5,
                     static_cast<int64>(std::numeric_limits<int32>::max()) * 2},
                    DT_INT32);

  // Size op with default attributes (no int64 output type requested).
  auto variable_size = ops::Size(scope, variable);
  auto passthrough = ops::Identity(scope, variable_size);

  Node* shape_data = nullptr;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
                   .Input(passthrough.node())
                   .Finalize(scope.graph(), &shape_data));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(variable.node()));
  TF_ASSERT_OK(refiner.AddNode(variable_size.node()));
  TF_ASSERT_OK(refiner.AddNode(passthrough.node()));
  EXPECT_FALSE(refiner.AddNode(shape_data).ok());
}
658
// The Shape of a 3x2 constant, the vector (3,2), must be visible to
// ShapeData's shape function, which turns it into the shape [3,2].
TEST_F(ShapeRefinerTest, PropagateShape) {
  Scope scope = Scope::NewRootScope();
  auto matrix = ops::Const(scope, {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
  auto matrix_shape = ops::Shape(scope, matrix);

  Node* shape_data = nullptr;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
                   .Input(matrix_shape.node())
                   .Finalize(scope.graph(), &shape_data));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(matrix.node()));
  TF_ASSERT_OK(refiner.AddNode(matrix_shape.node()));
  TF_ASSERT_OK(refiner.AddNode(shape_data));

  shape_inference::InferenceContext* ic = refiner.GetContext(shape_data);
  EXPECT_EQ("[3,2]", ic->DebugString(ic->output(0)));
}
680
// The Size of a 3x2 constant (6) must be visible to ShapeData's shape
// function, which turns it into the shape [6].
TEST_F(ShapeRefinerTest, PropagateSize) {
  Scope scope = Scope::NewRootScope();
  auto matrix = ops::Const(scope, {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
  auto matrix_size = ops::Size(scope, matrix);

  Node* shape_data = nullptr;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
                   .Input(matrix_size.node())
                   .Finalize(scope.graph(), &shape_data));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(matrix.node()));
  TF_ASSERT_OK(refiner.AddNode(matrix_size.node()));
  TF_ASSERT_OK(refiner.AddNode(shape_data));

  shape_inference::InferenceContext* ic = refiner.GetContext(shape_data);
  EXPECT_EQ("[6]", ic->DebugString(ic->output(0)));
}
701
// The Rank of a 3x2 constant (2) must be visible to ShapeData's shape
// function, which turns it into the shape [2].
TEST_F(ShapeRefinerTest, PropagateRank) {
  Scope scope = Scope::NewRootScope();
  auto matrix = ops::Const(scope, {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
  auto matrix_rank = ops::Rank(scope, matrix);

  Node* shape_data = nullptr;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
                   .Input(matrix_rank.node())
                   .Finalize(scope.graph(), &shape_data));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(matrix.node()));
  TF_ASSERT_OK(refiner.AddNode(matrix_rank.node()));
  TF_ASSERT_OK(refiner.AddNode(shape_data));

  shape_inference::InferenceContext* ic = refiner.GetContext(shape_data);
  EXPECT_EQ("[2]", ic->DebugString(ic->output(0)));
}
722
// Range(1, 11, 3) over constants yields 1,4,7,10; ShapeData must see those
// values and produce the shape [1,4,7,10].
TEST_F(ShapeRefinerTest, PropagateRange) {
  Scope scope = Scope::NewRootScope();
  auto start = ops::Const(scope, 1);
  auto stop = ops::Const(scope, 11);
  auto step = ops::Const(scope, 3);
  auto sequence = ops::Range(scope, start, stop, step);

  Node* shape_data = nullptr;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
                   .Input(sequence.node())
                   .Finalize(scope.graph(), &shape_data));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(start.node()));
  TF_ASSERT_OK(refiner.AddNode(stop.node()));
  TF_ASSERT_OK(refiner.AddNode(step.node()));
  TF_ASSERT_OK(refiner.AddNode(sequence.node()));
  TF_ASSERT_OK(refiner.AddNode(shape_data));

  shape_inference::InferenceContext* ic = refiner.GetContext(shape_data);
  EXPECT_EQ("[1,4,7,10]", ic->DebugString(ic->output(0)));
}
745
746 // Make sure PlaceholderWithDefaults aren't treated as constants.
TEST_F(ShapeRefinerTest,NoPropagatePlaceholderWithDefault)747 TEST_F(ShapeRefinerTest, NoPropagatePlaceholderWithDefault) {
748 Scope root = Scope::NewRootScope();
749 auto constant = ops::Const<int>(root, 2);
750 auto placeholder =
751 ops::PlaceholderWithDefault(root, constant, PartialTensorShape());
752 Node* shape_data;
753 TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
754 .Input(placeholder.node())
755 .Finalize(root.graph(), &shape_data));
756
757 ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
758 TF_ASSERT_OK(m.AddNode(constant.node()));
759 TF_ASSERT_OK(m.AddNode(placeholder.node()));
760 TF_ASSERT_OK(m.AddNode(shape_data));
761 shape_inference::InferenceContext* ic = m.GetContext(shape_data);
762 EXPECT_EQ(ic->DebugString(ic->output(0)), "?");
763 }
764
// One constant feeds two inputs (start and delta) of the same Range node;
// constant evaluation must still produce 1,2,3.
TEST_F(ShapeRefinerTest, ConstantValueTwoInputsToSameNode) {
  Scope scope = Scope::NewRootScope();
  auto one = ops::Const(scope, 1);  // Used as both start and delta.
  auto stop = ops::Const(scope, 4);
  auto sequence = ops::Range(scope, one, stop, one);

  Node* shape_data = nullptr;
  TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
                   .Input(sequence.node())
                   .Finalize(scope.graph(), &shape_data));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(one.node()));
  TF_ASSERT_OK(refiner.AddNode(stop.node()));
  TF_ASSERT_OK(refiner.AddNode(sequence.node()));
  TF_ASSERT_OK(refiner.AddNode(shape_data));

  shape_inference::InferenceContext* ic = refiner.GetContext(shape_data);
  EXPECT_EQ("[1,2,3]", ic->DebugString(ic->output(0)));
}
786
787 // Creates a graph where 'begin' is attempted to be visited during
788 // constant value evaluation after having been processed once.
TEST_F(ShapeRefinerTest,ConstantValueVisitNodeTwice)789 TEST_F(ShapeRefinerTest, ConstantValueVisitNodeTwice) {
790 Scope root = Scope::NewRootScope();
791 auto begin = ops::Const(root, 1);
792 auto limit = ops::Const(root, 8);
793 auto delta = ops::Const(root, 3);
794
795 auto d1 = ops::Add(root, begin, limit); // 9
796 auto d2 = ops::Add(root, begin, delta); // 4
797 // Visiting flimit's children will visit 'begin' before 'd1'.
798 // It will then visit d1, whose child is 'begin'. That edge still
799 // must be visited.
800 auto flimit = ops::Sub(root, begin, d1); // 1-9=-8
801 auto fdelta = ops::Sub(root, begin, d2); // 1-4=-3
802 auto nl = ops::Abs(root, flimit); // 8
803 auto nd = ops::Abs(root, fdelta); // 3
804
805 auto range = ops::Range(root, begin, nl, nd);
806
807 Node* shape_data;
808 TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
809 .Input(range.node())
810 .Finalize(root.graph(), &shape_data));
811
812 ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
813 TF_ASSERT_OK(m.AddNode(begin.node()));
814 TF_ASSERT_OK(m.AddNode(limit.node()));
815 TF_ASSERT_OK(m.AddNode(delta.node()));
816 TF_ASSERT_OK(m.AddNode(d1.node()));
817 TF_ASSERT_OK(m.AddNode(d2.node()));
818 TF_ASSERT_OK(m.AddNode(flimit.node()));
819 TF_ASSERT_OK(m.AddNode(fdelta.node()));
820 TF_ASSERT_OK(m.AddNode(nl.node()));
821 TF_ASSERT_OK(m.AddNode(nd.node()));
822 TF_ASSERT_OK(m.AddNode(range.node()));
823 TF_ASSERT_OK(m.AddNode(shape_data));
824
825 shape_inference::InferenceContext* ctx = m.GetContext(shape_data);
826 EXPECT_EQ("[1,4,7]", ctx->DebugString(ctx->output(0)));
827 }
828
829 namespace {
830
TensorAsShapeShapeFn(shape_inference::InferenceContext * c)831 Status TensorAsShapeShapeFn(shape_inference::InferenceContext* c) {
832 shape_inference::ShapeHandle out;
833 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0 /* input_idx */, &out));
834 c->set_output(0, out);
835 return Status::OK();
836 }
837
// Register ops used by the ConstantValueAsShape* tests.

// Takes an int32 tensor "a" and produces an int32 output whose inferred shape
// is computed by TensorAsShapeShapeFn from input 0.
REGISTER_OP("TensorAsShapeInt32")
    .Input("a: int32")
    .Output("o: int32")
    .SetShapeFn(TensorAsShapeShapeFn);

// int64 variant of TensorAsShapeInt32.
REGISTER_OP("TensorAsShapeInt64")
    .Input("a: int64")
    .Output("o: int64")
    .SetShapeFn(TensorAsShapeShapeFn);

// Input-less op with a scalar int32 output, marked do-not-optimize.
// NOTE(review): presumably stands in for a scalar whose value is not known
// at graph-construction time — confirm against the tests that use it.
REGISTER_OP("NonConstScalarInt32")
    .Output("o: int32")
    .SetDoNotOptimize()
    .SetShapeFn(shape_inference::ScalarShape);

// int64 variant of NonConstScalarInt32.
REGISTER_OP("NonConstScalarInt64")
    .Output("o: int64")
    .SetDoNotOptimize()
    .SetShapeFn(shape_inference::ScalarShape);
859
860 REGISTER_OP("WithEmptyVectorShape")
861 .Output("o: int32")
862 .SetDoNotOptimize()
__anon684d22820b02(shape_inference::InferenceContext* c) 863 .SetShapeFn([](shape_inference::InferenceContext* c) {
864 c->set_output(0, c->Vector(0));
865 return Status::OK();
866 });
867
868 REGISTER_OP("WithPartialShape")
869 .Output("o: int32")
870 .SetDoNotOptimize()
__anon684d22820c02(shape_inference::InferenceContext* c) 871 .SetShapeFn([](shape_inference::InferenceContext* c) {
872 c->set_output(
873 0, c->MakeShape({1, shape_inference::InferenceContext::kUnknownDim, 3,
874 shape_inference::InferenceContext::kUnknownDim, 5}));
875 return Status::OK();
876 });
877
878 REGISTER_OP("WithPartialShape2")
879 .Output("o: int32")
880 .SetDoNotOptimize()
__anon684d22820d02(shape_inference::InferenceContext* c) 881 .SetShapeFn([](shape_inference::InferenceContext* c) {
882 c->set_output(
883 0,
884 c->MakeShape({6, shape_inference::InferenceContext::kUnknownDim, 8}));
885 return Status::OK();
886 });
887
888 REGISTER_OP("WithUnknownShape")
889 .Output("o: int32")
890 .SetDoNotOptimize()
__anon684d22820e02(shape_inference::InferenceContext* c) 891 .SetShapeFn([](shape_inference::InferenceContext* c) {
892 c->set_output(0, c->UnknownShape());
893 return Status::OK();
894 });
895
896 } // namespace
897
// A shape tensor whose own shape is a length-0 vector is expected to be
// interpreted as the rank-0 shape "[]".
TEST_F(ShapeRefinerTest, ConstantValueAsShape_EmptyVector) {
  Scope scope = Scope::NewRootScope();
  Graph* graph = scope.graph();

  Node* empty_vector = nullptr;
  TF_ASSERT_OK(
      NodeBuilder("in", "WithEmptyVectorShape").Finalize(graph, &empty_vector));
  Node* as_shape = nullptr;
  TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
                   .Input(empty_vector)
                   .Finalize(graph, &as_shape));

  // Nodes are registered with the refiner producer-first.
  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  for (Node* node : {empty_vector, as_shape}) {
    TF_ASSERT_OK(refiner.AddNode(node));
  }

  shape_inference::InferenceContext* as_shape_ctx =
      refiner.GetContext(as_shape);
  EXPECT_EQ("[]", as_shape_ctx->DebugString(as_shape_ctx->output(0)));
}
915
// Feeds Shape(producer) into TensorAsShapeInt32 and checks that the producer's
// (partial or unknown) shape is what comes out the other end.
TEST_F(ShapeRefinerTest, ConstantValueAsShape_Shape) {
  struct Case {
    const char* producer_op;     // Op type of the node whose shape is taken.
    const char* expected_shape;  // Expected DebugString of the result.
  };
  const Case kCases[] = {
      {"WithPartialShape", "[1,?,3,?,5]"},
      {"WithUnknownShape", "?"},
  };

  for (const Case& test_case : kCases) {
    Scope scope = Scope::NewRootScope();
    Node* producer = nullptr;
    TF_ASSERT_OK(NodeBuilder("in", test_case.producer_op)
                     .Finalize(scope.graph(), &producer));
    auto shape_of_producer = ops::Shape(scope, Output(producer));
    Node* as_shape = nullptr;
    TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
                     .Input(shape_of_producer.node())
                     .Finalize(scope.graph(), &as_shape));

    ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
    TF_ASSERT_OK(refiner.AddNode(producer));
    TF_ASSERT_OK(refiner.AddNode(shape_of_producer.node()));
    TF_ASSERT_OK(refiner.AddNode(as_shape));

    shape_inference::InferenceContext* as_shape_ctx =
        refiner.GetContext(as_shape);
    EXPECT_EQ(test_case.expected_shape,
              as_shape_ctx->DebugString(as_shape_ctx->output(0)));
  }
}
942
// Stacks three int32 constants and one non-constant scalar; the non-constant
// element is expected to show up as an unknown dimension.
TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) {
  Scope scope = Scope::DisabledShapeInferenceScope();
  Node* non_const_scalar = nullptr;
  TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
                   .Finalize(scope.graph(), &non_const_scalar));

  InputList stack_inputs{
      // clang-format off
      Input(ops::Const<int32>(scope, 10)),
      Input(ops::Const<int32>(scope, 20)),
      Input(Output(non_const_scalar)),
      Input(ops::Const<int32>(scope, 40)),
  };  // clang-format on
  auto stacked = ops::Stack(scope, stack_inputs);
  TF_ASSERT_OK(scope.status());

  Node* as_shape = nullptr;
  TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
                   .Input(stacked.node())
                   .Finalize(scope.graph(), &as_shape));

  // Register the stack operands first, then the stack, then its consumer.
  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  for (const auto& operand : stack_inputs) {
    TF_ASSERT_OK(refiner.AddNode(operand.node()));
  }
  TF_ASSERT_OK(refiner.AddNode(stacked.node()));
  TF_ASSERT_OK(refiner.AddNode(as_shape));

  shape_inference::InferenceContext* as_shape_ctx =
      refiner.GetContext(as_shape);
  EXPECT_EQ("[10,20,?,40]",
            as_shape_ctx->DebugString(as_shape_ctx->output(0)));
}
974
// int64 variant of the pack test. The last constant is 1 << 40, which does
// not fit in 32 bits, and is expected to be carried through unchanged.
TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt64) {
  Scope scope = Scope::DisabledShapeInferenceScope();
  Node* non_const_scalar = nullptr;
  TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt64")
                   .Finalize(scope.graph(), &non_const_scalar));

  InputList stack_inputs{
      // clang-format off
      Input(ops::Const<int64>(scope, int64{10})),
      Input(ops::Const<int64>(scope, int64{20})),
      Input(Output(non_const_scalar)),
      Input(ops::Const<int64>(scope, int64{1} << 40)),
  };  // clang-format on
  auto stacked = ops::Stack(scope, stack_inputs);
  TF_ASSERT_OK(scope.status());

  Node* as_shape = nullptr;
  TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt64")
                   .Input(stacked.node())
                   .Finalize(scope.graph(), &as_shape));

  // Register the stack operands first, then the stack, then its consumer.
  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  for (const auto& operand : stack_inputs) {
    TF_ASSERT_OK(refiner.AddNode(operand.node()));
  }
  TF_ASSERT_OK(refiner.AddNode(stacked.node()));
  TF_ASSERT_OK(refiner.AddNode(as_shape));

  shape_inference::InferenceContext* as_shape_ctx =
      refiner.GetContext(as_shape);
  EXPECT_EQ("[10,20,?,1099511627776]",
            as_shape_ctx->DebugString(as_shape_ctx->output(0)));
}
1006
// A constant -1 among the stacked values is expected to be reported as an
// unknown dimension ("?").
TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackUnknownDim) {
  Scope scope = Scope::NewRootScope();

  InputList stack_inputs{
      Input(ops::Const<int64>(scope, int64{10})),
      Input(ops::Const<int64>(scope, int64{-1})),
  };
  auto stacked = ops::Stack(scope, stack_inputs);
  TF_ASSERT_OK(scope.status());

  Node* as_shape = nullptr;
  TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt64")
                   .Input(stacked.node())
                   .Finalize(scope.graph(), &as_shape));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  for (const auto& operand : stack_inputs) {
    TF_ASSERT_OK(refiner.AddNode(operand.node()));
  }
  TF_ASSERT_OK(refiner.AddNode(stacked.node()));
  TF_ASSERT_OK(refiner.AddNode(as_shape));

  shape_inference::InferenceContext* as_shape_ctx =
      refiner.GetContext(as_shape);
  EXPECT_EQ("[10,?]", as_shape_ctx->DebugString(as_shape_ctx->output(0)));
}
1032
// Stacking vectors (instead of scalars) is expected to make adding the
// consumer fail with a message mentioning "but is rank 2".
TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) {
  Scope scope = Scope::NewRootScope();

  // Each operand is a length-2 vector rather than a scalar.
  InputList stack_inputs{
      Input(ops::Const<int64>(scope, {int64{10}, int64{20}})),
      Input(ops::Const<int64>(scope, {int64{10}, int64{21}})),
  };
  auto stacked = ops::Stack(scope, stack_inputs);
  TF_ASSERT_OK(scope.status());

  Node* as_shape = nullptr;
  TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt64")
                   .Input(stacked.node())
                   .Finalize(scope.graph(), &as_shape));

  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  for (const auto& operand : stack_inputs) {
    TF_ASSERT_OK(refiner.AddNode(operand.node()));
  }
  TF_ASSERT_OK(refiner.AddNode(stacked.node()));

  const Status add_status = refiner.AddNode(as_shape);
  EXPECT_TRUE(absl::StrContains(add_status.error_message(), "but is rank 2"));
}
1057
// Concatenates two partially known shapes and a constant vector along axis 0
// and expects the pieces to appear back to back in the resulting shape.
TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) {
  Scope scope = Scope::DisabledShapeInferenceScope();
  Graph* graph = scope.graph();

  Node* first_partial = nullptr;
  Node* second_partial = nullptr;
  TF_ASSERT_OK(
      NodeBuilder("in", "WithPartialShape").Finalize(graph, &first_partial));
  TF_ASSERT_OK(
      NodeBuilder("in", "WithPartialShape2").Finalize(graph, &second_partial));

  auto known_dims = ops::Const(scope, {9, 10, 11});
  OutputList pieces{
      // clang-format off
      ops::Shape(scope, Output(first_partial)),
      ops::Shape(scope, Output(second_partial)),
      known_dims,
  };  // clang-format on
  auto axis = ops::Const(scope, 0);
  auto joined = ops::Concat(scope, pieces, axis);
  TF_ASSERT_OK(scope.status());

  Node* as_shape = nullptr;
  TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
                   .Input(joined.node())
                   .Finalize(graph, &as_shape));

  // Producers first, then the concat operands, the axis, the concat itself
  // and finally the op under test.
  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(refiner.AddNode(first_partial));
  TF_ASSERT_OK(refiner.AddNode(second_partial));
  for (const auto& piece : pieces) {
    TF_ASSERT_OK(refiner.AddNode(piece.node()));
  }
  TF_ASSERT_OK(refiner.AddNode(axis.node()));
  TF_ASSERT_OK(refiner.AddNode(joined.node()));
  TF_ASSERT_OK(refiner.AddNode(as_shape));

  shape_inference::InferenceContext* as_shape_ctx =
      refiner.GetContext(as_shape);
  EXPECT_EQ("[1,?,3,?,5,6,?,8,9,10,11]",
            as_shape_ctx->DebugString(as_shape_ctx->output(0)));
}
1094
// Concatenates the shapes of two partially known tensors with the shape of a
// tensor of unknown rank. Because one piece has unknown length, the resulting
// shape is expected to be fully unknown ("?").
TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
  Scope root = Scope::DisabledShapeInferenceScope();
  Graph* g = root.graph();

  // Producers whose shapes are concatenated below.
  Node* partial_1;
  Node* partial_2;
  Node* unknown;
  TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape").Finalize(g, &partial_1));
  TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape2").Finalize(g, &partial_2));
  TF_ASSERT_OK(NodeBuilder("in", "WithUnknownShape").Finalize(g, &unknown));
  OutputList concat_inputs{
      // clang-format off
      ops::Shape(root, Output(partial_1)),
      ops::Shape(root, Output(partial_2)),
      ops::Shape(root, Output(unknown)),
  };  // clang-format on
  auto concat_dim = ops::Const(root, 0);
  auto concat = ops::Concat(root, concat_inputs, concat_dim);
  TF_ASSERT_OK(root.status());

  Node* result;
  TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
                   .Input(concat.node())
                   .Finalize(g, &result));

  // Nodes are added to the refiner in dependency order.
  ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(m.AddNode(partial_1));
  TF_ASSERT_OK(m.AddNode(partial_2));
  TF_ASSERT_OK(m.AddNode(unknown));
  for (const auto& o : concat_inputs) {
    TF_ASSERT_OK(m.AddNode(o.node()));
  }
  TF_ASSERT_OK(m.AddNode(concat_dim.node()));
  TF_ASSERT_OK(m.AddNode(concat.node()));
  TF_ASSERT_OK(m.AddNode(result));

  shape_inference::InferenceContext* ctx = m.GetContext(result);
  EXPECT_EQ("?", ctx->DebugString(ctx->output(0)));
}
1137
// Concatenates two partially known shapes with a constant vector containing
// -2. Adding the consumer is expected to fail with an "Invalid value in tensor
// used for shape" error naming the offending value.
TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
  Scope root = Scope::DisabledShapeInferenceScope();
  Graph* g = root.graph();

  // Producers whose shapes are concatenated below.
  Node* partial_1;
  Node* partial_2;
  TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape").Finalize(g, &partial_1));
  TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape2").Finalize(g, &partial_2));
  // The -2 entry is the invalid dimension value under test.
  auto const_input = ops::Const(root, {9, -2, 11});
  OutputList concat_inputs{
      // clang-format off
      ops::Shape(root, Output(partial_1)),
      ops::Shape(root, Output(partial_2)),
      const_input,
  };  // clang-format on
  auto concat_dim = ops::Const(root, 0);
  auto concat = ops::Concat(root, concat_inputs, concat_dim);
  TF_ASSERT_OK(root.status());

  Node* result;
  TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
                   .Input(concat.node())
                   .Finalize(g, &result));

  // Everything up to the concat is expected to be accepted; only the final
  // consumer is expected to be rejected.
  ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(m.AddNode(partial_1));
  TF_ASSERT_OK(m.AddNode(partial_2));
  for (const auto& o : concat_inputs) {
    TF_ASSERT_OK(m.AddNode(o.node()));
  }
  TF_ASSERT_OK(m.AddNode(concat_dim.node()));
  TF_ASSERT_OK(m.AddNode(concat.node()));
  EXPECT_EQ("Invalid value in tensor used for shape: -2",
            m.AddNode(result).error_message());
}
1176
// Slicing positions 2..5 with stride 1 out of [1,?,3,?,5] is expected to keep
// the last three dimensions.
TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSlice) {
  TestStridedSlice(/*input_shape=*/{1, -1, 3, -1, 5}, /*begin=*/2, /*end=*/5,
                   /*stride=*/1, /*expected=*/"[3,?,5]");
}
1185
// With stride -1, begin=10 and end=0, the expected result is the input shape
// reversed, without its first dimension.
TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceNegativeStride) {
  TestStridedSlice(/*input_shape=*/{1, -1, 3, -1, 5}, /*begin=*/10, /*end=*/0,
                   /*stride=*/-1, /*expected=*/"[5,?,3,?]");
}
1196
// With begin_mask and end_mask both set to 1, the expected result is the
// entire input shape even though begin=3 and end=4 are passed.
TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceMasks) {
  TestStridedSlice(/*input_shape=*/{1, -1, 3, -1, 5}, /*begin=*/3, /*end=*/4,
                   /*stride=*/1, /*expected=*/"[1,?,3,?,5]",
                   /*begin_mask=*/1, /*end_mask=*/1);
}
1207
// With ellipsis_mask set to 1, the expected result keeps the input's rank but
// has every dimension unknown.
TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceInvalidMask) {
  TestStridedSlice(/*input_shape=*/{1, -1, 3}, /*begin=*/2, /*end=*/3,
                   /*stride=*/1, /*expected=*/"[?,?,?]",
                   /*begin_mask=*/0, /*end_mask=*/0, /*ellipsis_mask=*/1);
}
1219
// Uses two-element begin/end/stride vectors on a placeholder input. The
// resulting shape is expected to be fully unknown ("?").
TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceMulti) {
  Scope root = Scope::DisabledShapeInferenceScope();
  auto input = ops::Placeholder(root, DT_INT32);
  auto begin = ops::Const(root, {0, 0});
  auto end = ops::Const(root, {2, 2});
  auto stride = ops::Const(root, {1, 1});
  auto slice = ops::StridedSlice(root, input, begin, end, stride);
  Node* result;
  TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
                   .Input(slice.node())
                   .Finalize(root.graph(), &result));

  // Nodes are added to the refiner in dependency order.
  ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(m.AddNode(input.node()));
  TF_ASSERT_OK(m.AddNode(begin.node()));
  TF_ASSERT_OK(m.AddNode(end.node()));
  TF_ASSERT_OK(m.AddNode(stride.node()));
  TF_ASSERT_OK(m.AddNode(slice.node()));
  TF_ASSERT_OK(m.AddNode(result));

  shape_inference::InferenceContext* ctx = m.GetContext(result);
  // Expected value first, matching the other assertions in this file.
  EXPECT_EQ("?", ctx->DebugString(ctx->output(0)));
}
1243
1244 namespace {
1245
1246 // Dummy op to test ShapeRefiner util functions
1247 REGISTER_OP("Dummy");
1248
1249 } // namespace
1250
// Exercises SameDefinedShape on pairs of shapes created from one
// InferenceContext. The expectations below show that fully defined shapes
// compare by value, while shapes with unknown parts only match themselves.
TEST_F(ShapeRefinerTest, SameDefinedShape) {
  Scope root = Scope::NewRootScope();
  Graph* g = root.graph();
  Node* test;
  // TF_ASSERT_OK reports a test failure instead of aborting the whole process.
  TF_ASSERT_OK(NodeBuilder("test", "Dummy").Finalize(g, &test));
  ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  // "Dummy" is registered without a shape function, so do not require one.
  m.set_require_shape_inference_fns(false);
  TF_ASSERT_OK(m.AddNode(test));
  shape_inference::InferenceContext* ctx = m.GetContext(test);

  // The *_b variants are separately created shapes with the same contents.
  auto unknown = ctx->UnknownShape();
  auto unknown_b = ctx->UnknownShape();
  auto s_1_2 = ctx->MakeShape({1, 2});
  auto s_1_2_b = ctx->MakeShape({1, 2});
  auto s_2_2 = ctx->MakeShape({2, 2});
  auto s_unknown_2 = ctx->MakeShape({-1, 2});
  auto s_unknown_2_b = ctx->MakeShape({-1, 2});

  // Unknown shapes: same handle matches, distinct handles do not.
  EXPECT_TRUE(SameDefinedShape(ctx, unknown, unknown));
  EXPECT_FALSE(SameDefinedShape(ctx, unknown, unknown_b));
  EXPECT_FALSE(SameDefinedShape(ctx, unknown, s_1_2));
  // Fully defined shapes: equal dimensions match, different ones do not.
  EXPECT_TRUE(SameDefinedShape(ctx, s_1_2, s_1_2_b));
  EXPECT_FALSE(SameDefinedShape(ctx, s_1_2, s_2_2));
  // Partially unknown shapes: same handle matches, distinct handles do not.
  EXPECT_TRUE(SameDefinedShape(ctx, s_unknown_2, s_unknown_2));
  EXPECT_FALSE(SameDefinedShape(ctx, s_unknown_2, s_unknown_2_b));
}
1277
// Exercises IsUpdatedShapesOrTypes on lists of ShapeAndType. An equivalent
// list must not count as updated; a changed shape or a changed dtype must.
TEST_F(ShapeRefinerTest, IsUpdatedShapesOrTypes) {
  Scope root = Scope::NewRootScope();
  Graph* g = root.graph();
  Node* test;
  // TF_ASSERT_OK reports a test failure instead of aborting the whole process.
  TF_ASSERT_OK(NodeBuilder("test", "Dummy").Finalize(g, &test));
  ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  // "Dummy" is registered without a shape function, so do not require one.
  m.set_require_shape_inference_fns(false);
  TF_ASSERT_OK(m.AddNode(test));
  shape_inference::InferenceContext* ctx = m.GetContext(test);

  // t0 and t1 share this unknown-shape handle; t2 and t3 create fresh ones.
  shape_inference::ShapeHandle unknown = ctx->UnknownShape();
  std::vector<shape_inference::ShapeAndType> t0{
      {ctx->MakeShape({1, 2, 3}), DT_FLOAT},
      {unknown, DT_INVALID},
      {ctx->MakeShape({4, 3, 2, 1}), DT_INT32}};

  // Same contents as t0.
  std::vector<shape_inference::ShapeAndType> t1{
      {ctx->MakeShape({1, 2, 3}), DT_FLOAT},
      {unknown, DT_INVALID},
      {ctx->MakeShape({4, 3, 2, 1}), DT_INT32}};

  // First entry's shape differs from t0 ([1,2,4] instead of [1,2,3]).
  std::vector<shape_inference::ShapeAndType> t2{
      {ctx->MakeShape({1, 2, 4}), DT_FLOAT},
      {ctx->UnknownShape(), DT_INVALID},
      {ctx->MakeShape({4, 3, 2, 1}), DT_INT32}};

  // First entry's dtype differs from t0 (DT_INT32 instead of DT_FLOAT).
  std::vector<shape_inference::ShapeAndType> t3{
      {ctx->MakeShape({1, 2, 3}), DT_INT32},
      {ctx->UnknownShape(), DT_INVALID},
      {ctx->MakeShape({4, 3, 2, 1}), DT_INT32}};

  EXPECT_FALSE(IsUpdatedShapesOrTypes(ctx, t0, t1));

  // A shape has been modified
  EXPECT_TRUE(IsUpdatedShapesOrTypes(ctx, t0, t2));

  // A type has been modified
  EXPECT_TRUE(IsUpdatedShapesOrTypes(ctx, t0, t3));
}
1317
// Builds a FIFOQueueV2 -> QueueDequeueV2 pair, then repeatedly changes the
// queue's output handle shape and calls UpdateNode on the dequeue op, checking
// both the resulting dequeue shape and the `refined` flag after each step.
TEST_F(ShapeRefinerTest, IncrementalUpdates) {
  Scope root = Scope::NewRootScope();
  Graph* g = root.graph();
  Node* queue;
  // TF_ASSERT_OK reports a test failure instead of aborting the whole process.
  TF_ASSERT_OK(NodeBuilder("queue", "FIFOQueueV2")
                   .Attr("component_types", {DT_FLOAT})
                   .Finalize(g, &queue));
  Node* dequeue;
  TF_ASSERT_OK(NodeBuilder("dequeue", "QueueDequeueV2")
                   .Attr("component_types", {DT_FLOAT})
                   .Input(queue)
                   .Finalize(g, &dequeue));
  ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
  TF_ASSERT_OK(m.AddNode(queue));
  TF_ASSERT_OK(m.AddNode(dequeue));

  // At this point, the shapes of the dequeued tensor are unknown.
  shape_inference::InferenceContext* ctx = m.GetContext(dequeue);
  EXPECT_EQ("?", ctx->DebugString(ctx->output(0)));

  // Inject a shape, and incrementally propagate it to the dequeue op.
  ctx = m.GetContext(queue);
  shape_inference::ShapeHandle shp = ctx->MakeShape({3, 7});
  ctx->set_output_handle_shapes_and_types(
      0, std::vector<shape_inference::ShapeAndType>{{shp, DT_FLOAT}});
  bool refined = false;
  TF_ASSERT_OK(m.UpdateNode(dequeue, /*relax=*/false, &refined));
  EXPECT_TRUE(refined);
  ctx = m.GetContext(dequeue);
  EXPECT_EQ("[3,7]", ctx->DebugString(ctx->output(0)));

  // Inject another shape, but relax instead of merge. The differing first
  // dimension (3 vs. 2) is expected to become unknown.
  ctx = m.GetContext(queue);
  shp = ctx->MakeShape({2, 7});
  ctx->set_output_handle_shapes_and_types(
      0, std::vector<shape_inference::ShapeAndType>{{shp, DT_FLOAT}});
  refined = false;
  TF_ASSERT_OK(m.UpdateNode(dequeue, /*relax=*/true, &refined));
  EXPECT_TRUE(refined);
  ctx = m.GetContext(dequeue);
  EXPECT_EQ("[?,7]", ctx->DebugString(ctx->output(0)));

  // Inject another partially unknown shape and attempt to relax it. The
  // dequeue output is expected to reuse the injected unknown-dimension handle.
  ctx = m.GetContext(queue);
  shp = ctx->MakeShape({shape_inference::InferenceContext::kUnknownDim, 7});
  ctx->set_output_handle_shapes_and_types(
      0, std::vector<shape_inference::ShapeAndType>{{shp, DT_FLOAT}});
  refined = false;
  TF_ASSERT_OK(m.UpdateNode(dequeue, /*relax=*/true, &refined));
  EXPECT_TRUE(refined);
  ctx = m.GetContext(dequeue);
  EXPECT_EQ("[?,7]", ctx->DebugString(ctx->output(0)));
  EXPECT_TRUE(SameHandle(ctx->Dim(ctx->output(0), 0), ctx->Dim(shp, 0)));

  // Inject a shape of the same handle and expect refined to not change.
  ctx = m.GetContext(queue);
  shape_inference::ShapeHandle shp2 = shp;
  ctx->set_output_handle_shapes_and_types(
      0, std::vector<shape_inference::ShapeAndType>{{shp2, DT_FLOAT}});
  refined = false;
  TF_ASSERT_OK(m.UpdateNode(dequeue, /*relax=*/false, &refined));
  EXPECT_FALSE(refined);
  EXPECT_TRUE(SameHandle(ctx->Dim(shp, 0), ctx->Dim(shp2, 0)));
}
1382
TestSimpleFunctionInference(bool enable_function_inference)1383 void TestSimpleFunctionInference(bool enable_function_inference) {
1384 FunctionDefLibrary f_lib_proto;
1385 *(f_lib_proto.add_function()) = test::function::XTimesTwo();
1386 FunctionLibraryDefinition f_lib(OpRegistry::Global(), f_lib_proto);
1387
1388 Scope root = Scope::NewRootScope();
1389 TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
1390 auto x = ops::Const(root, {{1.0f, 2.0f}});
1391 auto x2 = test::function::Call(&root, "x2", "XTimesTwo", {x});
1392
1393 ShapeRefiner m(TF_GRAPH_DEF_VERSION, &f_lib);
1394 if (enable_function_inference) {
1395 m.set_function_library_for_shape_inference(&f_lib);
1396 }
1397
1398 TF_ASSERT_OK(m.AddNode(x.node()));
1399 TF_ASSERT_OK(m.AddNode(x2.node()));
1400
1401 EXPECT_SHAPE("[1,2]", m, x, 0);
1402
1403 if (enable_function_inference) {
1404 EXPECT_SHAPE("[1,2]", m, x2, 0);
1405 } else {
1406 // Default inference behavior: functions output shapes are unknown.
1407 EXPECT_SHAPE("?", m, x2, 0);
1408 }
1409 }
1410
// Runs the simple function scenario without installing a function library for
// shape inference.
TEST_F(ShapeRefinerTest, SimpleFunctionShapeInference_Disabled) {
  TestSimpleFunctionInference(/*enable_function_inference=*/false);
}
1415
// Runs the simple function scenario with the function library installed for
// shape inference.
TEST_F(ShapeRefinerTest, SimpleFunctionShapeInference) {
  TestSimpleFunctionInference(/*enable_function_inference=*/true);
}
1419
// Function inference is expected to fall back to unknown output shapes when
// the function cannot be found in the library used for shape inference.
TEST_F(ShapeRefinerTest, FunctionShapeInferenceFallback) {
  FunctionDefLibrary library_proto;
  *library_proto.add_function() = test::function::XTimesTwo();
  FunctionLibraryDefinition library(OpRegistry::Global(), library_proto);

  Scope root = Scope::NewRootScope();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library_proto));
  auto x = ops::Const(root, {{.0f, .0f}});
  auto x2 = test::function::Call(&root, "x2", "XTimesTwo", {x});

  // The library handed to shape inference deliberately contains no functions.
  FunctionDefLibrary empty_library_proto;
  FunctionLibraryDefinition empty_library(OpRegistry::Global(),
                                          empty_library_proto);

  ShapeRefiner m(TF_GRAPH_DEF_VERSION, &library);
  m.set_function_library_for_shape_inference(&empty_library);

  TF_ASSERT_OK(m.AddNode(x.node()));
  TF_ASSERT_OK(m.AddNode(x2.node()));

  EXPECT_SHAPE("[1,2]", m, x, 0);

  // Default inference behavior: functions output shapes are unknown.
  EXPECT_SHAPE("?", m, x2, 0);
}
1448
// Chains a three-input function call (WXPlusB) into XTimes16 and checks that
// shapes are inferred through both calls.
TEST_F(ShapeRefinerTest, ChainedFunctionShapeInferenceWithMultipleInputs) {
  FunctionDefLibrary library_proto;
  *library_proto.add_function() = test::function::XTimesTwo();
  *library_proto.add_function() = test::function::XTimesFour();
  *library_proto.add_function() = test::function::XTimes16();
  *library_proto.add_function() = test::function::WXPlusB();
  FunctionLibraryDefinition library(OpRegistry::Global(), library_proto);

  Scope root = Scope::NewRootScope();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library_proto));
  auto w = ops::Const(root, {{.0f}, {.0f}, {.0f}});
  auto x = ops::Const(root, {{.0f, .0f, .0f}});
  auto b = ops::Const(root, {{.0f}});

  auto wxplusb = test::function::Call(&root, "wxplusb", "WXPlusB", {w, x, b});
  auto wxplusb16 =
      test::function::Call(&root, "wxplusb16", "XTimes16", {wxplusb});

  ShapeRefiner m(TF_GRAPH_DEF_VERSION, &library);
  m.set_function_library_for_shape_inference(&library);

  // Inputs first, then the two calls in dependency order.
  for (Node* node :
       {w.node(), x.node(), b.node(), wxplusb.node(), wxplusb16.node()}) {
    TF_ASSERT_OK(m.AddNode(node));
  }

  EXPECT_SHAPE("[3,1]", m, w, 0);
  EXPECT_SHAPE("[1,3]", m, x, 0);
  EXPECT_SHAPE("[1,1]", m, b, 0);
  EXPECT_SHAPE("[3,3]", m, wxplusb, 0);
  EXPECT_SHAPE("[3,3]", m, wxplusb16, 0);
}
1482
// Passes two resource variable handles through the Swap function and checks
// that the handle shapes and dtypes come out in swapped order.
TEST_F(ShapeRefinerTest, FunctionShapeInferenceWorksForResourceHandles) {
  FunctionDefLibrary f_lib_proto;
  *(f_lib_proto.add_function()) = test::function::Swap();

  FunctionLibraryDefinition f_lib(OpRegistry::Global(), f_lib_proto);

  Scope root = Scope::NewRootScope().ExitOnError();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));

  // Two variable handles with different dtypes and shapes, so that a swap is
  // observable on both attributes.
  auto x1 = ops::VarHandleOp(root, DataType::DT_FLOAT, TensorShape({128, 256}));
  auto x2 = ops::VarHandleOp(root, DataType::DT_DOUBLE, TensorShape({1024}));
  auto swap = test::function::Call(&root, "swap", "Swap", {x1, x2});

  // Expected value first, matching the other assertions in this file.
  EXPECT_EQ(2, swap.node()->num_outputs());

  ShapeRefiner m(TF_GRAPH_DEF_VERSION, &f_lib);
  m.set_function_library_for_shape_inference(&f_lib);

  TF_ASSERT_OK(m.AddNode(x1.node()));
  TF_ASSERT_OK(m.AddNode(x2.node()));
  TF_ASSERT_OK(m.AddNode(swap.node()));

  EXPECT_EQ(2, m.GetContext(swap.node())->num_outputs());

  // Inputs keep their own handle shapes; the call's outputs carry them in
  // swapped order.
  EXPECT_RESOURCE_SINGLE_SHAPE("[128,256]", m, x1, 0);
  EXPECT_RESOURCE_SINGLE_SHAPE("[1024]", m, x2, 0);
  EXPECT_RESOURCE_SINGLE_SHAPE("[1024]", m, swap, 0);
  EXPECT_RESOURCE_SINGLE_SHAPE("[128,256]", m, swap, 1);
  EXPECT_RESOURCE_SINGLE_TYPE(DataType::DT_DOUBLE, m, swap, 0);
  EXPECT_RESOURCE_SINGLE_TYPE(DataType::DT_FLOAT, m, swap, 1);
}
1514
1515 } // namespace
1516 } // namespace tensorflow
1517