1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/shape_refiner.h"
17 
18 #include "tensorflow/cc/framework/scope.h"
19 #include "tensorflow/cc/ops/resource_variable_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/common_runtime/function_testlib.h"
22 #include "tensorflow/core/framework/common_shape_fns.h"
23 #include "tensorflow/core/framework/function_testlib.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/graph/node_builder.h"
26 #include "tensorflow/core/graph/testlib.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/lib/strings/str_util.h"
30 #include "tensorflow/core/platform/test.h"
31 #include "tensorflow/core/public/version.h"
32 
33 namespace tensorflow {
34 
35 class ShapeRefinerTest : public ::testing::Test {
36  protected:
37   // These give access to private functions of DimensionHandle and ShapeHandle.
38   bool SameHandle(shape_inference::DimensionHandle a,
39                   shape_inference::DimensionHandle b) {
40     return a.SameHandle(b);
41   }
42 
43   bool SameHandle(shape_inference::ShapeHandle a,
44                   shape_inference::ShapeHandle b) {
45     return a.SameHandle(b);
46   }
47 
48   // These give access to private functions of ShapeRefiner.
49   bool SameDefinedShape(shape_inference::InferenceContext* c,
50                         shape_inference::ShapeHandle s0,
51                         shape_inference::ShapeHandle s1) {
52     return ShapeRefiner::SameDefinedShape(c, s0, s1);
53   }
54 
55   bool IsUpdatedShapesOrTypes(
56       shape_inference::InferenceContext* c,
57       const std::vector<shape_inference::ShapeAndType>& existing,
58       const std::vector<shape_inference::ShapeAndType>& updated) {
59     return ShapeRefiner::IsUpdatedShapesOrTypes(c, existing, updated);
60   }
61 
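  // Mirror of ShapeRefiner::kMaxTensorSize: the upper bound (in elements) on
  // constant tensors the refiner will materialize when evaluating constant
  // subgraphs; used below to build one constant under and one over the limit.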
62   static constexpr int64_t kMaxTensorSize = ShapeRefiner::kMaxTensorSize;
63 
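  // Helper for the ConstantValueAsShape_StridedSlice* tests below: builds
  //   Placeholder -> Shape -> StridedSlice -> TensorAsShapeInt32
  // and checks that the refiner resolves the sliced shape tensor to the
  // expected shape debug string.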
64   void TestStridedSlice(const PartialTensorShape& input_shape, int begin,
65                         int end, int stride, const char* expected,
66                         int begin_mask = 0, int end_mask = 0,
67                         int ellipsis_mask = 0) {
68     Scope root = Scope::DisabledShapeInferenceScope();
69     auto placeholder =
70         ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
71     auto input = ops::Shape(root, placeholder);
72     auto begin_op = ops::Const(root, {begin});
73     auto end_op = ops::Const(root, {end});
74     auto stride_op = ops::Const(root, {stride});
75     auto slice = ops::StridedSlice(root, input, begin_op, end_op, stride_op,
76                                    ops::StridedSlice::BeginMask(begin_mask)
77                                        .EndMask(end_mask)
78                                        .EllipsisMask(ellipsis_mask));
79     Node* result;
80     TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
81                      .Input(slice.node())
82                      .Finalize(root.graph(), &result));
83 
84     ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
85     TF_ASSERT_OK(m.AddNode(placeholder.node()));
86     TF_ASSERT_OK(m.AddNode(input.node()));
87     TF_ASSERT_OK(m.AddNode(begin_op.node()));
88     TF_ASSERT_OK(m.AddNode(end_op.node()));
89     TF_ASSERT_OK(m.AddNode(stride_op.node()));
90     TF_ASSERT_OK(m.AddNode(slice.node()));
91     TF_ASSERT_OK(m.AddNode(result));
92 
93     shape_inference::InferenceContext* ctx = m.GetContext(result);
94     EXPECT_EQ(ctx->DebugString(ctx->output(0)), expected);
95   }
96 };
97 
98 namespace {
99 
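// Asserts that output IDX of node OP, as known to ShapeRefiner M, has the
// given shape debug string, e.g. EXPECT_SHAPE("[2,1]", m, a, 0).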
100 #define EXPECT_SHAPE(EXPECTED, M, OP, IDX)                            \
101   do {                                                                \
102     shape_inference::InferenceContext* ctx = M.GetContext(OP.node()); \
103     EXPECT_EQ(EXPECTED, ctx->DebugString(ctx->output(IDX)));          \
104   } while (0);
105 
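// Resource-handle variants of EXPECT_SHAPE: they assert that output IDX
// carries exactly one handle shape/type entry and check its shape (or dtype).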
106 #define EXPECT_RESOURCE_SINGLE_SHAPE(EXPECTED, M, OP, IDX)            \
107   do {                                                                \
108     shape_inference::InferenceContext* ctx = M.GetContext(OP.node()); \
109     auto* v = ctx->output_handle_shapes_and_types(IDX);               \
110     EXPECT_NE(v, nullptr);                                            \
111     EXPECT_EQ(v->size(), 1);                                          \
112     EXPECT_EQ(EXPECTED, ctx->DebugString((*v)[0].shape));             \
113   } while (0);
114 
115 #define EXPECT_RESOURCE_SINGLE_TYPE(EXPECTED, M, OP, IDX)             \
116   do {                                                                \
117     shape_inference::InferenceContext* ctx = M.GetContext(OP.node()); \
118     auto* v = ctx->output_handle_shapes_and_types(IDX);               \
119     EXPECT_NE(v, nullptr);                                            \
120     EXPECT_EQ(v->size(), 1);                                          \
121     EXPECT_EQ(EXPECTED, (*v)[0].dtype);                               \
122   } while (0);
123 
124 TEST_F(ShapeRefinerTest, Constant) {
125   // Create a constant node and validate that adding it is successful
126   // and that its shape is correct.
127   Scope root = Scope::NewRootScope();
128   auto c = ops::Const(root, 42.0f);
129   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
130   TF_ASSERT_OK(m.AddNode(c.node()));
131 
132   EXPECT_SHAPE("[]", m, c, 0);
133 }
134 
135 TEST_F(ShapeRefinerTest, MatMul) {
136   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
137 
138   Scope root = Scope::NewRootScope();
139   auto a = ops::Const(root, {{1.0f}, {2.0f}});
140   auto b = ops::Const(root, {{1.0f, 2.0f}});
141   auto mm = ops::MatMul(root, a, b);
142 
143   TF_ASSERT_OK(m.AddNode(a.node()));
144   TF_ASSERT_OK(m.AddNode(b.node()));
145   TF_ASSERT_OK(m.AddNode(mm.node()));
146 
147   EXPECT_SHAPE("[2,1]", m, a, 0);
148   EXPECT_SHAPE("[1,2]", m, b, 0);
149   EXPECT_SHAPE("[2,2]", m, mm, 0);
150 }
151 
152 TEST_F(ShapeRefinerTest, BadShapes) {
153   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
154   Scope root = Scope::NewRootScope();
155   auto a = ops::Const(root, {{1.0f}, {2.0f}});
156   auto b = ops::Const(root, {{1.0f}, {2.0f}});
157   auto mm = ops::MatMul(root, a, b);
158 
159   TF_ASSERT_OK(m.AddNode(a.node()));
160   TF_ASSERT_OK(m.AddNode(b.node()));
161   // The shapes of the inputs are not compatible, so we should expect
162   // an error.
163   Status s = m.AddNode(mm.node());
164   ASSERT_FALSE(s.ok());
165   ASSERT_TRUE(absl::StrContains(s.error_message(),
166                                 "Dimensions must be equal, but are 1 and 2"));
167 }
168 
169 TEST_F(ShapeRefinerTest, SetShape) {
170   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
171 
172   Scope root = Scope::NewRootScope();
173   auto a = ops::Placeholder(root, DT_FLOAT);
174 
175   TF_ASSERT_OK(m.AddNode(a.node()));
176 
177   auto ic = m.GetContext(a.node());
178   ASSERT_NE(nullptr, ic);
179   shape_inference::ShapeHandle h = ic->MakeShape({2, ic->UnknownDim()});
180   TF_ASSERT_OK(m.SetShape(a.node(), 0, h));
181   EXPECT_SHAPE("[2,?]", m, a, 0);
182 
183   // Check that shapes are merged with the existing shape.
184   shape_inference::ShapeHandle h2 = ic->MakeShape({ic->UnknownDim(), 2});
185   TF_ASSERT_OK(m.SetShape(a.node(), 0, h2));
186   EXPECT_SHAPE("[2,2]", m, a, 0);
187 
188   // Out of range.
189   ASSERT_FALSE(m.SetShape(a.node(), 1, h).ok());
190   ASSERT_FALSE(m.SetShape(a.node(), -1, h).ok());
191 
192   auto b = ops::Const(root, {{1.0f}, {2.0f}});
193   // Deliberately skip adding 'b' to the refiner first; SetShape should fail.
194   ASSERT_FALSE(m.SetShape(b.node(), 0, h).ok());
195 
196   // Set an incompatible shape (3 vs 2)
197   h = ic->MakeShape({3, ic->UnknownDim()});
198   ASSERT_FALSE(m.SetShape(a.node(), 0, h).ok());
199 }
200 
201 namespace {
202 
203 // An op with no shape function.
204 REGISTER_OP("TestOpWithNoShapeFn").Input("a: int32").Output("o: int32");
205 
206 }  // namespace
207 
208 TEST_F(ShapeRefinerTest, MissingShapeInferenceFns) {
209   Scope root = Scope::NewRootScope();
210   auto a = ops::Const(root, 42);
211   Node* b;
212   TF_ASSERT_OK(NodeBuilder("b", "TestOpWithNoShapeFn")
213                    .Input(a.node())
214                    .Finalize(root.graph(), &b));
215   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
216   TF_ASSERT_OK(m.AddNode(a.node()));
217   EXPECT_FALSE(m.AddNode(b).ok());
218   m.set_require_shape_inference_fns(false);
219   TF_EXPECT_OK(m.AddNode(b));
220 }
221 
222 TEST_F(ShapeRefinerTest, PropagateConstants) {
223   // The reduction dimension is a variable, so its value is unknown;
224   // the output shape is therefore unknown (though its rank is known).
225   {
226     Scope root = Scope::NewRootScope();
227     // 3x2 input
228     auto input = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
229     // Reduce along unspecified dimension
230     auto dim = ops::Variable(root, {}, DT_INT32);
231 
232     auto am = ops::ArgMax(root, input, dim);
233     ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
234     TF_ASSERT_OK(m.AddNode(input.node()));
235     TF_ASSERT_OK(m.AddNode(dim.node()));
236     TF_ASSERT_OK(m.AddNode(am.node()));
237     EXPECT_SHAPE("[?]", m, am, 0);
238   }
239 
240   // The reduction dimension is a constant that can be materialized,
241   // so the shape function can be more precise about the output.
242   {
243     Scope root = Scope::NewRootScope();
244     // 3x2 input
245     auto input = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
246     // Reduce along 2nd dimension
247     auto dim = ops::Const(root, 1);
248 
249     auto am = ops::ArgMax(root, input, dim);
250     ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
251     TF_ASSERT_OK(m.AddNode(input.node()));
252     TF_ASSERT_OK(m.AddNode(dim.node()));
253     TF_ASSERT_OK(m.AddNode(am.node()));
254     EXPECT_SHAPE("[3]", m, am, 0);
255   }
256 
257   // Reduce along known first dimension.
258   {
259     Scope root = Scope::NewRootScope();
260     // 3x2 input
261     auto input = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
262     // Reduce along 1st dimension
263     auto dim = ops::Const(root, 0);
264 
265     auto am = ops::ArgMax(root, input, dim);
266     ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
267     TF_ASSERT_OK(m.AddNode(input.node()));
268     TF_ASSERT_OK(m.AddNode(dim.node()));
269     TF_ASSERT_OK(m.AddNode(am.node()));
270     EXPECT_SHAPE("[2]", m, am, 0);
271   }
272 }
273 
274 TEST_F(ShapeRefinerTest, ExtractConstantSubgraphMultiOutput) {
275   // Test when a node yields two outputs, one of which has a constant
276   // value that is small enough to be cached, and one which does not.
277   //
278   // ShapeVectorForAllElements nodes are used here to call
279   // input_tensor from the shape function.
280   {
281     Scope root = Scope::NewRootScope();
282     auto small = ops::Const(root, {static_cast<int32>(1), TensorShape({1, 1})});
283     auto large = ops::Const(
284         root, {static_cast<int32>(2), TensorShape({4, kMaxTensorSize / 2})});
285     Node* multi;
286     TF_ASSERT_OK(NodeBuilder("MI", "MultiIdentity")
287                      .Input(std::vector<NodeBuilder::NodeOut>{small.node(),
288                                                               large.node()})
289                      .Attr("N", 2)
290                      .Finalize(root.graph(), &multi));
291 
292     Node* shape_v;
293     TF_ASSERT_OK(NodeBuilder("Test", "ShapeVectorForAllElements")
294                      .Input(multi, 0)
295                      .Finalize(root.graph(), &shape_v));
296 
297     auto add = ops::Add(root, Output(multi, 0), Output(multi, 1));
298     Node* shape_v2;
299     TF_ASSERT_OK(NodeBuilder("Test", "ShapeVectorForAllElements")
300                      .Input(add.node())
301                      .Finalize(root.graph(), &shape_v2));
302     ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
303     TF_ASSERT_OK(m.AddNode(small.node()));
304     TF_ASSERT_OK(m.AddNode(large.node()));
305     TF_ASSERT_OK(m.AddNode(multi));
306     TF_ASSERT_OK(m.AddNode(shape_v));
307     TF_ASSERT_OK(m.AddNode(add.node()));
308     TF_ASSERT_OK(m.AddNode(shape_v2));
309 
310     // The output shape is a vector whose length is the sum of all elements
311     // of the add's output. The add produces kMaxTensorSize*2 elements, each
312     // equal to 1 + 2 = 3, so the expected length is kMaxTensorSize*2*3.
313     shape_inference::InferenceContext* ctx = m.GetContext(shape_v2);
314     EXPECT_EQ(strings::StrCat("[", kMaxTensorSize * 2 * 3, "]"),
315               ctx->DebugString(ctx->output(0)));
316   }
317 }
318 
319 namespace {
320 
321 // An op with a shape function whose outputs depend in a complex
322 // way on whether input tensors are available.
323 REGISTER_OP("TestOp")
324     .Input("a: float")
325     .Input("b: float")
326     .Output("o: float")
327     .SetShapeFn([](shape_inference::InferenceContext* c) {
328       if (c->input_tensor(0)) {
329         if (c->input_tensor(1)) {
330           c->set_output(0, c->Matrix(10, 10));
331           return Status::OK();
332         }
333         return shape_inference::ScalarShape(c);
334       }
335       return shape_inference::UnknownShape(c);
336     });
337 
338 }  // namespace
339 
340 TEST_F(ShapeRefinerTest, InputTensorDependencies) {
341   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
342   Graph graph(OpRegistry::Global());
343   Node* node;
344 
345   Tensor a(DT_FLOAT, TensorShape({}));
346   a.scalar<float>()() = 1.0;
347 
348   Tensor b(DT_FLOAT, TensorShape({}));
349   b.scalar<float>()() = 2.0;
350 
351   Node* input_a = test::graph::Constant(&graph, a);
352   Node* input_b = test::graph::Constant(&graph, b);
353   TF_ASSERT_OK(NodeBuilder("Test", "TestOp")
354                    .Input(input_a)
355                    .Input(input_b)
356                    .Finalize(&graph, &node));
357 
358   TF_ASSERT_OK(m.AddNode(input_a));
359   TF_ASSERT_OK(m.AddNode(input_b));
360   TF_ASSERT_OK(m.AddNode(node));
361   shape_inference::InferenceContext* ctx = m.GetContext(node);
362   EXPECT_EQ("[10,10]", ctx->DebugString(ctx->output(0)));
363 }
364 
365 namespace {
366 
367 // An op with a shape function that looks at its input tensor
368 // data and makes a Shape out of it.
369 REGISTER_OP("ShapeData")
370     .Input("a: int32")
371     .Output("o: int32")
372     .SetShapeFn([](shape_inference::InferenceContext* c) {
373       const Tensor* shape_data = c->input_tensor(0);
374       if (shape_data == nullptr) {
375         return shape_inference::UnknownShape(c);
376       }
377 
378       std::vector<shape_inference::DimensionHandle> dims;
379       dims.reserve(shape_data->NumElements());
380       for (int i = 0; i < shape_data->NumElements(); ++i) {
381         dims.emplace_back(c->MakeDim(shape_data->flat<int32>()(i)));
382       }
383 
384       c->set_output(0, c->MakeShape(dims));
385       return Status::OK();
386     });
387 
388 REGISTER_OP("ShapeDataInt64")
389     .Input("a: int64")
390     .Output("o: int64")
391     .SetShapeFn([](shape_inference::InferenceContext* c) {
392       const Tensor* shape_data = c->input_tensor(0);
393       if (shape_data == nullptr) {
394         return shape_inference::UnknownShape(c);
395       }
396 
397       std::vector<shape_inference::DimensionHandle> dims;
398       dims.reserve(shape_data->NumElements());
399       for (int i = 0; i < shape_data->NumElements(); ++i) {
400         dims.emplace_back(c->MakeDim(shape_data->flat<int64>()(i)));
401       }
402 
403       c->set_output(0, c->MakeShape(dims));
404       return Status::OK();
405     });
406 
407 // An op with a shape function that looks at its input tensor
408 // data and makes a rank 1 shape out of the sum of all input values.
409 REGISTER_OP("ShapeVectorForAllElements")
410     .Input("a: int32")
411     .Output("o: int32")
412     .SetShapeFn([](shape_inference::InferenceContext* c) {
413       const Tensor* shape_data = c->input_tensor(0);
414       if (shape_data == nullptr) {
415         return shape_inference::UnknownShape(c);
416       }
417       int64_t total = 0;
418       for (int i = 0; i < shape_data->NumElements(); ++i) {
419         total += shape_data->flat<int32>()(i);
420       }
421 
422       c->set_output(0, c->Vector(total));
423       return Status::OK();
424     });
425 
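// An N-ary identity op that forwards each input to the corresponding output;
// used above to feed one small and one large constant through a single node.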
426 REGISTER_OP("MultiIdentity")
427     .Input("a: N * int32")
428     .Output("o: N * int32")
429     .Attr("N: int >= 1")
430     .SetShapeFn([](shape_inference::InferenceContext* c) {
431       for (int i = 0; i < c->num_inputs(); ++i) {
432         c->set_output(i, c->input(i));
433       }
434       return Status::OK();
435     });
436 
437 class MultiIdentity : public OpKernel {
438  public:
439   explicit MultiIdentity(OpKernelConstruction* c) : OpKernel(c) {}
440 
441   void Compute(OpKernelContext* c) override {
442     for (int i = 0; i < c->num_inputs(); ++i) {
443       c->set_output(i, c->input(i));
444     }
445   }
446 };
447 REGISTER_KERNEL_BUILDER(Name("MultiIdentity").Device(DEVICE_CPU),
448                         MultiIdentity);
449 
450 }  // namespace
451 
452 TEST_F(ShapeRefinerTest, PropagateShapeAcrossTensorContent) {
453   Scope root = Scope::NewRootScope();
454 
455   // Create variable 2x4 tensor.
456   auto input = ops::Variable(root, {2, 4}, DT_INT32);
457 
458   // Shape is a vector of 2 elements (2,4)
459   auto shape = ops::Shape(root, input);
460 
461   // Begin and size of the slice are both {1}, selecting the 4.
462   auto ones = ops::Const(root, {1});
463 
464   // Slice an element of the shape (4).
465   auto sliced = ops::Slice(root, shape, ones, ones);
466 
467   Node* shape_data;
468   TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
469                    .Input(sliced.node())
470                    .Finalize(root.graph(), &shape_data));
471 
472   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
473   TF_ASSERT_OK(m.AddNode(ones.node()));
474   TF_ASSERT_OK(m.AddNode(input.node()));
475   TF_ASSERT_OK(m.AddNode(shape.node()));
476   TF_ASSERT_OK(m.AddNode(sliced.node()));
477   TF_ASSERT_OK(m.AddNode(shape_data));
478 
479   shape_inference::InferenceContext* ctx = m.GetContext(shape_data);
480   EXPECT_EQ("[4]", ctx->DebugString(ctx->output(0)));
481 }
482 
483 TEST_F(ShapeRefinerTest, PropagateShapeAcrossTensorContentInt64) {
484   Scope root = Scope::NewRootScope();
485 
486   // Create a variable with shape [2, 4, 2*int32_max].
487   auto input = ops::Variable(
488       root, {2, 4, static_cast<int64>(std::numeric_limits<int32>::max()) * 2},
489       DT_INT64);
490 
491   // Shape is a vector of 3 elements (2, 4, 2*int32_max).
492   auto attrs = ops::Shape::OutType(DT_INT64);
493   auto shape = ops::Shape(root, input, attrs);
494 
495   // Begin and size of the slice are both {1}, selecting the 4.
496   auto ones = ops::Const(root, {1});
497 
498   // Slice an element of the shape (4).
499   auto sliced = ops::Slice(root, shape, ones, ones);
500 
501   Node* shape_data;
502   TF_ASSERT_OK(NodeBuilder("Test", "ShapeDataInt64")
503                    .Input(sliced.node())
504                    .Finalize(root.graph(), &shape_data));
505 
506   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
507   TF_ASSERT_OK(m.AddNode(ones.node()));
508   TF_ASSERT_OK(m.AddNode(input.node()));
509   TF_ASSERT_OK(m.AddNode(shape.node()));
510   TF_ASSERT_OK(m.AddNode(sliced.node()));
511   TF_ASSERT_OK(m.AddNode(shape_data));
512 
513   shape_inference::InferenceContext* ctx = m.GetContext(shape_data);
514   EXPECT_EQ("[4]", ctx->DebugString(ctx->output(0)));
515 }
516 
517 TEST_F(ShapeRefinerTest, PropagateShapeAcrossTensorContentInt32Overflow) {
518   Scope root = Scope::NewRootScope();
519 
520   // Create a variable with shape [2, 4, 2*int32_max].
521   auto input = ops::Variable(
522       root, {2, 4, static_cast<int64>(std::numeric_limits<int32>::max()) * 2},
523       DT_INT32);
524 
525   // Shape is a vector of 3 elements; the last one does not fit in int32.
526   auto shape = ops::Shape(root, input);
527 
528   // Begin and size of the slice are both {1}, selecting the 4.
529   auto ones = ops::Const(root, {1});
530 
531   // Slice an element of the shape (4).
532   auto sliced = ops::Slice(root, shape, ones, ones);
533 
534   Node* shape_data;
535   TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
536                    .Input(sliced.node())
537                    .Finalize(root.graph(), &shape_data));
538 
539   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
540   TF_ASSERT_OK(m.AddNode(ones.node()));
541   TF_ASSERT_OK(m.AddNode(input.node()));
542   TF_ASSERT_OK(m.AddNode(shape.node()));
543   TF_ASSERT_OK(m.AddNode(sliced.node()));
544 
545   // Expect an error since there's an overflow.
546   EXPECT_FALSE(m.AddNode(shape_data).ok());
547 }
548 
549 TEST_F(ShapeRefinerTest, PropagateRankAcrossTensorContent) {
550   Scope root = Scope::NewRootScope();
551 
552   // Create variable 2x4x3 tensor.
553   auto input = ops::Variable(root, {2, 4, 3}, DT_INT32);
554 
555   // Rank 3.
556   auto rank = ops::Rank(root, input);
557 
558   auto identity = ops::Identity(root, rank);
559 
560   Node* shape_data;
561   TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
562                    .Input(identity.node())
563                    .Finalize(root.graph(), &shape_data));
564 
565   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
566   TF_ASSERT_OK(m.AddNode(input.node()));
567   TF_ASSERT_OK(m.AddNode(rank.node()));
568   TF_ASSERT_OK(m.AddNode(identity.node()));
569   TF_ASSERT_OK(m.AddNode(shape_data));
570 
571   shape_inference::InferenceContext* ctx = m.GetContext(shape_data);
572   EXPECT_EQ("[3]", ctx->DebugString(ctx->output(0)));
573 }
574 
575 TEST_F(ShapeRefinerTest, PropagateSizeAcrossTensorContent) {
576   Scope root = Scope::NewRootScope();
577 
578   // Create variable.
579   auto input = ops::Variable(root, {1, 2, 3, 4, 5}, DT_INT32);
580 
581   // 5!.
582   auto size = ops::Size(root, input);
583 
584   auto identity = ops::Identity(root, size);
585 
586   Node* shape_data;
587   TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
588                    .Input(identity.node())
589                    .Finalize(root.graph(), &shape_data));
590 
591   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
592   TF_ASSERT_OK(m.AddNode(input.node()));
593   TF_ASSERT_OK(m.AddNode(size.node()));
594   TF_ASSERT_OK(m.AddNode(identity.node()));
595   TF_ASSERT_OK(m.AddNode(shape_data));
596 
597   shape_inference::InferenceContext* ctx = m.GetContext(shape_data);
598   EXPECT_EQ("[120]", ctx->DebugString(ctx->output(0)));
599 }
600 
601 TEST_F(ShapeRefinerTest, PropagateSizeAcrossTensorContentInt64) {
602   Scope root = Scope::NewRootScope();
603 
604   // Create variable.
605   auto input =
606       ops::Variable(root,
607                     {1, 2, 3, 4, 5,
608                      static_cast<int64>(std::numeric_limits<int32>::max()) * 2},
609                     DT_INT64);
610 
611   // 5! * int32_max_value * 2.
612   auto attrs = ops::Size::OutType(DT_INT64);
613   auto size = ops::Size(root, input, attrs);
614 
615   auto identity = ops::Identity(root, size);
616 
617   Node* shape_data;
618   TF_ASSERT_OK(NodeBuilder("Test", "ShapeDataInt64")
619                    .Input(identity.node())
620                    .Finalize(root.graph(), &shape_data));
621 
622   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
623   TF_ASSERT_OK(m.AddNode(input.node()));
624   TF_ASSERT_OK(m.AddNode(size.node()));
625   TF_ASSERT_OK(m.AddNode(identity.node()));
626   TF_ASSERT_OK(m.AddNode(shape_data));
627 
628   shape_inference::InferenceContext* ctx = m.GetContext(shape_data);
629   EXPECT_EQ("[515396075280]", ctx->DebugString(ctx->output(0)));
630 }
631 
632 TEST_F(ShapeRefinerTest, PropagateSizeAcrossTensorContentInt32Overflow) {
633   Scope root = Scope::NewRootScope();
634 
635   // Create variable.
636   auto input =
637       ops::Variable(root,
638                     {1, 2, 3, 4, 5,
639                      static_cast<int64>(std::numeric_limits<int32>::max()) * 2},
640                     DT_INT32);
641 
642   // 5! * int32_max * 2, which overflows int32.
643   auto size = ops::Size(root, input);
644 
645   auto identity = ops::Identity(root, size);
646 
647   Node* shape_data;
648   TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
649                    .Input(identity.node())
650                    .Finalize(root.graph(), &shape_data));
651 
652   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
653   TF_ASSERT_OK(m.AddNode(input.node()));
654   TF_ASSERT_OK(m.AddNode(size.node()));
655   TF_ASSERT_OK(m.AddNode(identity.node()));
656   EXPECT_FALSE(m.AddNode(shape_data).ok());
657 }
658 
659 TEST_F(ShapeRefinerTest, PropagateShape) {
660   Scope root = Scope::NewRootScope();
661   // 3x2 input
662   auto input = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
663 
664   // Shape is a vector of 2 elements (3,2)
665   auto shape = ops::Shape(root, input);
666 
667   Node* shape_data;
668   TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
669                    .Input(shape.node())
670                    .Finalize(root.graph(), &shape_data));
671 
672   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
673   TF_ASSERT_OK(m.AddNode(input.node()));
674   TF_ASSERT_OK(m.AddNode(shape.node()));
675   TF_ASSERT_OK(m.AddNode(shape_data));
676 
677   shape_inference::InferenceContext* ctx = m.GetContext(shape_data);
678   EXPECT_EQ("[3,2]", ctx->DebugString(ctx->output(0)));
679 }
680 
681 TEST_F(ShapeRefinerTest, PropagateSize) {
682   Scope root = Scope::NewRootScope();
683   // 3x2 input
684   auto input = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
685 
686   auto size = ops::Size(root, input);
687 
688   Node* shape_data;
689   TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
690                    .Input(size.node())
691                    .Finalize(root.graph(), &shape_data));
692 
693   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
694   TF_ASSERT_OK(m.AddNode(input.node()));
695   TF_ASSERT_OK(m.AddNode(size.node()));
696   TF_ASSERT_OK(m.AddNode(shape_data));
697 
698   shape_inference::InferenceContext* ctx = m.GetContext(shape_data);
699   EXPECT_EQ("[6]", ctx->DebugString(ctx->output(0)));
700 }
701 
702 TEST_F(ShapeRefinerTest, PropagateRank) {
703   Scope root = Scope::NewRootScope();
704   // 3x2 input
705   auto input = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
706 
707   auto rank = ops::Rank(root, input);
708 
709   Node* shape_data;
710   TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
711                    .Input(rank.node())
712                    .Finalize(root.graph(), &shape_data));
713 
714   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
715   TF_ASSERT_OK(m.AddNode(input.node()));
716   TF_ASSERT_OK(m.AddNode(rank.node()));
717   TF_ASSERT_OK(m.AddNode(shape_data));
718 
719   shape_inference::InferenceContext* ctx = m.GetContext(shape_data);
720   EXPECT_EQ("[2]", ctx->DebugString(ctx->output(0)));
721 }
722 
723 TEST_F(ShapeRefinerTest, PropagateRange) {
724   Scope root = Scope::NewRootScope();
725   auto begin = ops::Const(root, 1);
726   auto limit = ops::Const(root, 11);
727   auto delta = ops::Const(root, 3);
728   auto range = ops::Range(root, begin, limit, delta);
729 
730   Node* shape_data;
731   TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
732                    .Input(range.node())
733                    .Finalize(root.graph(), &shape_data));
734 
735   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
736   TF_ASSERT_OK(m.AddNode(begin.node()));
737   TF_ASSERT_OK(m.AddNode(limit.node()));
738   TF_ASSERT_OK(m.AddNode(delta.node()));
739   TF_ASSERT_OK(m.AddNode(range.node()));
740   TF_ASSERT_OK(m.AddNode(shape_data));
741 
742   shape_inference::InferenceContext* ctx = m.GetContext(shape_data);
743   EXPECT_EQ("[1,4,7,10]", ctx->DebugString(ctx->output(0)));
744 }
745 
746 // Make sure PlaceholderWithDefault nodes aren't treated as constants.
747 TEST_F(ShapeRefinerTest, NoPropagatePlaceholderWithDefault) {
748   Scope root = Scope::NewRootScope();
749   auto constant = ops::Const<int>(root, 2);
750   auto placeholder =
751       ops::PlaceholderWithDefault(root, constant, PartialTensorShape());
752   Node* shape_data;
753   TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
754                    .Input(placeholder.node())
755                    .Finalize(root.graph(), &shape_data));
756 
757   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
758   TF_ASSERT_OK(m.AddNode(constant.node()));
759   TF_ASSERT_OK(m.AddNode(placeholder.node()));
760   TF_ASSERT_OK(m.AddNode(shape_data));
761   shape_inference::InferenceContext* ic = m.GetContext(shape_data);
762   EXPECT_EQ(ic->DebugString(ic->output(0)), "?");
763 }
764 
765 TEST_F(ShapeRefinerTest, ConstantValueTwoInputsToSameNode) {
766   Scope root = Scope::NewRootScope();
767   // This node is used as two inputs to 'range'.
768   auto begin_and_delta = ops::Const(root, 1);
769   auto limit = ops::Const(root, 4);
770   auto range = ops::Range(root, begin_and_delta, limit, begin_and_delta);
771 
772   Node* shape_data;
773   TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
774                    .Input(range.node())
775                    .Finalize(root.graph(), &shape_data));
776 
777   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
778   TF_ASSERT_OK(m.AddNode(begin_and_delta.node()));
779   TF_ASSERT_OK(m.AddNode(limit.node()));
780   TF_ASSERT_OK(m.AddNode(range.node()));
781   TF_ASSERT_OK(m.AddNode(shape_data));
782 
783   shape_inference::InferenceContext* ctx = m.GetContext(shape_data);
784   EXPECT_EQ("[1,2,3]", ctx->DebugString(ctx->output(0)));
785 }
786 
787 // Creates a graph where 'begin' is visited again during constant value
788 // evaluation after it has already been processed once.
789 TEST_F(ShapeRefinerTest, ConstantValueVisitNodeTwice) {
790   Scope root = Scope::NewRootScope();
791   auto begin = ops::Const(root, 1);
792   auto limit = ops::Const(root, 8);
793   auto delta = ops::Const(root, 3);
794 
795   auto d1 = ops::Add(root, begin, limit);  // 9
796   auto d2 = ops::Add(root, begin, delta);  // 4
797   // Visiting flimit's children will visit 'begin' before 'd1'.
798   // It will then visit d1, whose child is 'begin'.  That edge still
799   // must be visited.
800   auto flimit = ops::Sub(root, begin, d1);  // 1-9=-8
801   auto fdelta = ops::Sub(root, begin, d2);  // 1-4=-3
802   auto nl = ops::Abs(root, flimit);         // 8
803   auto nd = ops::Abs(root, fdelta);         // 3
804 
805   auto range = ops::Range(root, begin, nl, nd);
806 
807   Node* shape_data;
808   TF_ASSERT_OK(NodeBuilder("Test", "ShapeData")
809                    .Input(range.node())
810                    .Finalize(root.graph(), &shape_data));
811 
812   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
813   TF_ASSERT_OK(m.AddNode(begin.node()));
814   TF_ASSERT_OK(m.AddNode(limit.node()));
815   TF_ASSERT_OK(m.AddNode(delta.node()));
816   TF_ASSERT_OK(m.AddNode(d1.node()));
817   TF_ASSERT_OK(m.AddNode(d2.node()));
818   TF_ASSERT_OK(m.AddNode(flimit.node()));
819   TF_ASSERT_OK(m.AddNode(fdelta.node()));
820   TF_ASSERT_OK(m.AddNode(nl.node()));
821   TF_ASSERT_OK(m.AddNode(nd.node()));
822   TF_ASSERT_OK(m.AddNode(range.node()));
823   TF_ASSERT_OK(m.AddNode(shape_data));
824 
825   shape_inference::InferenceContext* ctx = m.GetContext(shape_data);
826   EXPECT_EQ("[1,4,7]", ctx->DebugString(ctx->output(0)));
827 }
828 
829 namespace {
830 
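// Shape function that interprets the value of input 0 as a shape and sets it
// as output 0; shared by the TensorAsShapeInt32/Int64 test ops below.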
831 Status TensorAsShapeShapeFn(shape_inference::InferenceContext* c) {
832   shape_inference::ShapeHandle out;
833   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0 /* input_idx */, &out));
834   c->set_output(0, out);
835   return Status::OK();
836 }
837 
838 // Register ops used by the ConstantValueAsShape* tests.
839 
840 REGISTER_OP("TensorAsShapeInt32")
841     .Input("a: int32")
842     .Output("o: int32")
843     .SetShapeFn(TensorAsShapeShapeFn);
844 
845 REGISTER_OP("TensorAsShapeInt64")
846     .Input("a: int64")
847     .Output("o: int64")
848     .SetShapeFn(TensorAsShapeShapeFn);
849 
850 REGISTER_OP("NonConstScalarInt32")
851     .Output("o: int32")
852     .SetDoNotOptimize()
853     .SetShapeFn(shape_inference::ScalarShape);
854 
855 REGISTER_OP("NonConstScalarInt64")
856     .Output("o: int64")
857     .SetDoNotOptimize()
858     .SetShapeFn(shape_inference::ScalarShape);
859 
860 REGISTER_OP("WithEmptyVectorShape")
861     .Output("o: int32")
862     .SetDoNotOptimize()
863     .SetShapeFn([](shape_inference::InferenceContext* c) {
864       c->set_output(0, c->Vector(0));
865       return Status::OK();
866     });
867 
868 REGISTER_OP("WithPartialShape")
869     .Output("o: int32")
870     .SetDoNotOptimize()
871     .SetShapeFn([](shape_inference::InferenceContext* c) {
872       c->set_output(
873           0, c->MakeShape({1, shape_inference::InferenceContext::kUnknownDim, 3,
874                            shape_inference::InferenceContext::kUnknownDim, 5}));
875       return Status::OK();
876     });
877 
878 REGISTER_OP("WithPartialShape2")
879     .Output("o: int32")
880     .SetDoNotOptimize()
881     .SetShapeFn([](shape_inference::InferenceContext* c) {
882       c->set_output(
883           0,
884           c->MakeShape({6, shape_inference::InferenceContext::kUnknownDim, 8}));
885       return Status::OK();
886     });
887 
888 REGISTER_OP("WithUnknownShape")
889     .Output("o: int32")
890     .SetDoNotOptimize()
891     .SetShapeFn([](shape_inference::InferenceContext* c) {
892       c->set_output(0, c->UnknownShape());
893       return Status::OK();
894     });
895 
896 }  // namespace
897 
898 TEST_F(ShapeRefinerTest, ConstantValueAsShape_EmptyVector) {
899   Scope root = Scope::NewRootScope();
900   Node* input;
901   TF_ASSERT_OK(
902       NodeBuilder("in", "WithEmptyVectorShape").Finalize(root.graph(), &input));
903   Node* result;
904   TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
905                    .Input(input)
906                    .Finalize(root.graph(), &result));
907 
908   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
909   TF_ASSERT_OK(m.AddNode(input));
910   TF_ASSERT_OK(m.AddNode(result));
911 
912   shape_inference::InferenceContext* ctx = m.GetContext(result);
913   EXPECT_EQ("[]", ctx->DebugString(ctx->output(0)));
914 }
915 
916 TEST_F(ShapeRefinerTest, ConstantValueAsShape_Shape) {
917   for (int pass = 0; pass < 2; ++pass) {
918     Scope root = Scope::NewRootScope();
919     Node* input;
920     TF_ASSERT_OK(
921         NodeBuilder("in", pass == 0 ? "WithPartialShape" : "WithUnknownShape")
922             .Finalize(root.graph(), &input));
923     auto shape = ops::Shape(root, Output(input));
924     Node* result;
925     TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
926                      .Input(shape.node())
927                      .Finalize(root.graph(), &result));
928 
929     ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
930     TF_ASSERT_OK(m.AddNode(input));
931     TF_ASSERT_OK(m.AddNode(shape.node()));
932     TF_ASSERT_OK(m.AddNode(result));
933 
934     shape_inference::InferenceContext* ctx = m.GetContext(result);
935     if (pass == 0) {
936       EXPECT_EQ("[1,?,3,?,5]", ctx->DebugString(ctx->output(0)));
937     } else {
938       EXPECT_EQ("?", ctx->DebugString(ctx->output(0)));
939     }
940   }
941 }
942 
943 TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) {
944   Scope root = Scope::DisabledShapeInferenceScope();
945   Node* scalar_non_const;
946   TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
947                    .Finalize(root.graph(), &scalar_non_const));
948 
949   InputList inputs{
950       // clang-format off
951       Input(ops::Const<int32>(root, 10)),
952       Input(ops::Const<int32>(root, 20)),
953       Input(Output(scalar_non_const)),
954       Input(ops::Const<int32>(root, 40)),
955   };  // clang-format on
956   auto pack = ops::Stack(root, inputs);
957   TF_ASSERT_OK(root.status());
958 
959   Node* result;
960   TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
961                    .Input(pack.node())
962                    .Finalize(root.graph(), &result));
963 
964   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
965   for (const auto& input : inputs) {
966     TF_ASSERT_OK(m.AddNode(input.node()));
967   }
968   TF_ASSERT_OK(m.AddNode(pack.node()));
969   TF_ASSERT_OK(m.AddNode(result));
970 
971   shape_inference::InferenceContext* ctx = m.GetContext(result);
972   EXPECT_EQ("[10,20,?,40]", ctx->DebugString(ctx->output(0)));
973 }
974 
975 TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt64) {
976   Scope root = Scope::DisabledShapeInferenceScope();
977   Node* scalar_non_const;
978   TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt64")
979                    .Finalize(root.graph(), &scalar_non_const));
980 
981   InputList inputs{
982       // clang-format off
983       Input(ops::Const<int64>(root, int64{10})),
984       Input(ops::Const<int64>(root, int64{20})),
985       Input(Output(scalar_non_const)),
986       Input(ops::Const<int64>(root, int64{1} << 40)),
987   };  // clang-format on
988   auto pack = ops::Stack(root, inputs);
989   TF_ASSERT_OK(root.status());
990 
991   Node* result;
992   TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt64")
993                    .Input(pack.node())
994                    .Finalize(root.graph(), &result));
995 
996   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
997   for (const auto& input : inputs) {
998     TF_ASSERT_OK(m.AddNode(input.node()));
999   }
1000   TF_ASSERT_OK(m.AddNode(pack.node()));
1001   TF_ASSERT_OK(m.AddNode(result));
1002 
1003   shape_inference::InferenceContext* ctx = m.GetContext(result);
1004   EXPECT_EQ("[10,20,?,1099511627776]", ctx->DebugString(ctx->output(0)));
1005 }
1006 
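// A -1 element in the packed tensor is treated as an unknown dimension.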
1007 TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackUnknownDim) {
1008   Scope root = Scope::NewRootScope();
1009 
1010   InputList inputs{
1011       Input(ops::Const<int64>(root, int64{10})),
1012       Input(ops::Const<int64>(root, int64{-1})),
1013   };
1014   auto pack = ops::Stack(root, inputs);
1015   TF_ASSERT_OK(root.status());
1016 
1017   Node* result;
1018   TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt64")
1019                    .Input(pack.node())
1020                    .Finalize(root.graph(), &result));
1021 
1022   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
1023   for (const auto& input : inputs) {
1024     TF_ASSERT_OK(m.AddNode(input.node()));
1025   }
1026   TF_ASSERT_OK(m.AddNode(pack.node()));
1027   TF_ASSERT_OK(m.AddNode(result));
1028 
1029   shape_inference::InferenceContext* ctx = m.GetContext(result);
1030   EXPECT_EQ("[10,?]", ctx->DebugString(ctx->output(0)));
1031 }
1032 
1033 TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) {
1034   Scope root = Scope::NewRootScope();
1035 
1036   // Inputs are length 2 vectors instead of scalars.
1037   InputList inputs{
1038       Input(ops::Const<int64>(root, {int64{10}, int64{20}})),
1039       Input(ops::Const<int64>(root, {int64{10}, int64{21}})),
1040   };
1041   auto pack = ops::Stack(root, inputs);
1042   TF_ASSERT_OK(root.status());
1043 
1044   Node* result;
1045   TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt64")
1046                    .Input(pack.node())
1047                    .Finalize(root.graph(), &result));
1048 
1049   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
1050   for (const auto& input : inputs) {
1051     TF_ASSERT_OK(m.AddNode(input.node()));
1052   }
1053   TF_ASSERT_OK(m.AddNode(pack.node()));
1054   EXPECT_TRUE(
1055       absl::StrContains(m.AddNode(result).error_message(), "but is rank 2"));
1056 }
1057 
1058 TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) {
1059   Scope root = Scope::DisabledShapeInferenceScope();
1060   Graph* g = root.graph();
1061   Node* partial_1;
1062   Node* partial_2;
1063   TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape").Finalize(g, &partial_1));
1064   TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape2").Finalize(g, &partial_2));
1065   auto const_input = ops::Const(root, {9, 10, 11});
1066   OutputList concat_inputs{
1067       // clang-format off
1068       ops::Shape(root, Output(partial_1)),
1069       ops::Shape(root, Output(partial_2)),
1070       const_input,
1071   };  // clang-format on
1072   auto concat_dim = ops::Const(root, 0);
1073   auto concat = ops::Concat(root, concat_inputs, concat_dim);
1074   TF_ASSERT_OK(root.status());
1075 
1076   Node* result;
1077   TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
1078                    .Input(concat.node())
1079                    .Finalize(g, &result));
1080 
1081   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
1082   TF_ASSERT_OK(m.AddNode(partial_1));
1083   TF_ASSERT_OK(m.AddNode(partial_2));
1084   for (const auto& o : concat_inputs) {
1085     TF_ASSERT_OK(m.AddNode(o.node()));
1086   }
1087   TF_ASSERT_OK(m.AddNode(concat_dim.node()));
1088   TF_ASSERT_OK(m.AddNode(concat.node()));
1089   TF_ASSERT_OK(m.AddNode(result));
1090 
1091   shape_inference::InferenceContext* ctx = m.GetContext(result);
1092   EXPECT_EQ("[1,?,3,?,5,6,?,8,9,10,11]", ctx->DebugString(ctx->output(0)));
1093 }
1094 
1095 TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
1096   Scope root = Scope::DisabledShapeInferenceScope();
1097   Graph* g = root.graph();
1098   Node* scalar_non_const;
1099   TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
1100                    .Finalize(root.graph(), &scalar_non_const));
1101 
1102   Node* partial_1;
1103   Node* partial_2;
1104   Node* unknown;
1105   TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape").Finalize(g, &partial_1));
1106   TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape2").Finalize(g, &partial_2));
1107   TF_ASSERT_OK(NodeBuilder("in", "WithUnknownShape").Finalize(g, &unknown));
1108   OutputList concat_inputs{
1109       // clang-format off
1110       ops::Shape(root, Output(partial_1)),
1111       ops::Shape(root, Output(partial_2)),
1112       ops::Shape(root, Output(unknown)),
1113   };  // clang-format on
1114   auto concat_dim = ops::Const(root, 0);
1115   auto concat = ops::Concat(root, concat_inputs, concat_dim);
1116   TF_ASSERT_OK(root.status());
1117 
1118   Node* result;
1119   TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
1120                    .Input(concat.node())
1121                    .Finalize(g, &result));
1122 
1123   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
1124   TF_ASSERT_OK(m.AddNode(partial_1));
1125   TF_ASSERT_OK(m.AddNode(partial_2));
1126   TF_ASSERT_OK(m.AddNode(unknown));
1127   for (const auto& o : concat_inputs) {
1128     TF_ASSERT_OK(m.AddNode(o.node()));
1129   }
1130   TF_ASSERT_OK(m.AddNode(concat_dim.node()));
1131   TF_ASSERT_OK(m.AddNode(concat.node()));
1132   TF_ASSERT_OK(m.AddNode(result));
1133 
1134   shape_inference::InferenceContext* ctx = m.GetContext(result);
1135   EXPECT_EQ("?", ctx->DebugString(ctx->output(0)));
1136 }
1137 
1138 TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
1139   Scope root = Scope::DisabledShapeInferenceScope();
1140   Graph* g = root.graph();
1141   Node* scalar_non_const;
1142   TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
1143                    .Finalize(root.graph(), &scalar_non_const));
1144 
1145   Node* partial_1;
1146   Node* partial_2;
1147   TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape").Finalize(g, &partial_1));
1148   TF_ASSERT_OK(NodeBuilder("in", "WithPartialShape2").Finalize(g, &partial_2));
1149   auto const_input = ops::Const(root, {9, -2, 11});
1150   OutputList concat_inputs{
1151       // clang-format off
1152       ops::Shape(root, Output(partial_1)),
1153       ops::Shape(root, Output(partial_2)),
1154       const_input,
1155   };  // clang-format on
1156   auto concat_dim = ops::Const(root, 0);
1157   auto concat = ops::Concat(root, concat_inputs, concat_dim);
1158   TF_ASSERT_OK(root.status());
1159 
1160   Node* result;
1161   TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
1162                    .Input(concat.node())
1163                    .Finalize(g, &result));
1164 
1165   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
1166   TF_ASSERT_OK(m.AddNode(partial_1));
1167   TF_ASSERT_OK(m.AddNode(partial_2));
1168   for (const auto& o : concat_inputs) {
1169     TF_ASSERT_OK(m.AddNode(o.node()));
1170   }
1171   TF_ASSERT_OK(m.AddNode(concat_dim.node()));
1172   TF_ASSERT_OK(m.AddNode(concat.node()));
1173   EXPECT_EQ("Invalid value in tensor used for shape: -2",
1174             m.AddNode(result).error_message());
1175 }
1176 
1177 TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSlice) {
1178   TestStridedSlice(
1179       /*input_shape=*/{1, -1, 3, -1, 5},
1180       /*begin=*/2,
1181       /*end=*/5,
1182       /*stride=*/1,
1183       /*expected=*/"[3,?,5]");
1184 }
1185 
1186 TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceNegativeStride) {
1187   // clang-format off
1188   TestStridedSlice(
1189       /*input_shape=*/{1, -1, 3, -1, 5},
1190       /*begin=*/10,
1191       /*end=*/0,
1192       /*stride=*/-1,
1193       /*expected=*/"[5,?,3,?]");
1194   // clang-format on
1195 }
1196 
1197 TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceMasks) {
1198   TestStridedSlice(
1199       /*input_shape=*/{1, -1, 3, -1, 5},
1200       /*begin=*/3,
1201       /*end=*/4,
1202       /*stride=*/1,
1203       /*expected=*/"[1,?,3,?,5]",
1204       /*begin_mask=*/1,
1205       /*end_mask=*/1);
1206 }
1207 
1208 TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceInvalidMask) {
1209   TestStridedSlice(
1210       /*input_shape=*/{1, -1, 3},
1211       /*begin=*/2,
1212       /*end=*/3,
1213       /*stride=*/1,
1214       /*expected=*/"[?,?,?]",
1215       /*begin_mask=*/0,
1216       /*end_mask=*/0,
1217       /*ellipsis_mask=*/1);
1218 }
1219 
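// Slicing with multi-element begin/end/stride vectors is not handled by the
// constant-value-as-shape path, so the result below stays unknown.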
1220 TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceMulti) {
1221   Scope root = Scope::DisabledShapeInferenceScope();
1222   auto input = ops::Placeholder(root, DT_INT32);
1223   auto begin = ops::Const(root, {0, 0});
1224   auto end = ops::Const(root, {2, 2});
1225   auto stride = ops::Const(root, {1, 1});
1226   auto slice = ops::StridedSlice(root, input, begin, end, stride);
1227   Node* result;
1228   TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
1229                    .Input(slice.node())
1230                    .Finalize(root.graph(), &result));
1231 
1232   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
1233   TF_ASSERT_OK(m.AddNode(input.node()));
1234   TF_ASSERT_OK(m.AddNode(begin.node()));
1235   TF_ASSERT_OK(m.AddNode(end.node()));
1236   TF_ASSERT_OK(m.AddNode(stride.node()));
1237   TF_ASSERT_OK(m.AddNode(slice.node()));
1238   TF_ASSERT_OK(m.AddNode(result));
1239 
1240   shape_inference::InferenceContext* ctx = m.GetContext(result);
1241   EXPECT_EQ(ctx->DebugString(ctx->output(0)), "?");
1242 }
1243 
1244 namespace {
1245 
1246 // Dummy op to test ShapeRefiner util functions
1247 REGISTER_OP("Dummy");
1248 
1249 }  // namespace
1250 
1251 TEST_F(ShapeRefinerTest, SameDefinedShape) {
1252   Scope root = Scope::NewRootScope();
1253   Graph* g = root.graph();
1254   Node* test;
1255   TF_CHECK_OK(NodeBuilder("test", "Dummy").Finalize(g, &test));
1256   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
1257   m.set_require_shape_inference_fns(false);
1258   TF_ASSERT_OK(m.AddNode(test));
1259   shape_inference::InferenceContext* ctx = m.GetContext(test);
1260 
1261   auto unknown = ctx->UnknownShape();
1262   auto unknown_b = ctx->UnknownShape();
1263   auto s_1_2 = ctx->MakeShape({1, 2});
1264   auto s_1_2_b = ctx->MakeShape({1, 2});
1265   auto s_2_2 = ctx->MakeShape({2, 2});
1266   auto s_unknown_2 = ctx->MakeShape({-1, 2});
1267   auto s_unknown_2_b = ctx->MakeShape({-1, 2});
1268 
1269   EXPECT_TRUE(SameDefinedShape(ctx, unknown, unknown));
1270   EXPECT_FALSE(SameDefinedShape(ctx, unknown, unknown_b));
1271   EXPECT_FALSE(SameDefinedShape(ctx, unknown, s_1_2));
1272   EXPECT_TRUE(SameDefinedShape(ctx, s_1_2, s_1_2_b));
1273   EXPECT_FALSE(SameDefinedShape(ctx, s_1_2, s_2_2));
1274   EXPECT_TRUE(SameDefinedShape(ctx, s_unknown_2, s_unknown_2));
1275   EXPECT_FALSE(SameDefinedShape(ctx, s_unknown_2, s_unknown_2_b));
1276 }
1277 
1278 TEST_F(ShapeRefinerTest, IsUpdatedShapesOrTypes) {
1279   Scope root = Scope::NewRootScope();
1280   Graph* g = root.graph();
1281   Node* test;
1282   TF_CHECK_OK(NodeBuilder("test", "Dummy").Finalize(g, &test));
1283   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
1284   m.set_require_shape_inference_fns(false);
1285   TF_ASSERT_OK(m.AddNode(test));
1286   shape_inference::InferenceContext* ctx = m.GetContext(test);
1287 
1288   shape_inference::ShapeHandle unknown = ctx->UnknownShape();
1289   std::vector<shape_inference::ShapeAndType> t0{
1290       {ctx->MakeShape({1, 2, 3}), DT_FLOAT},
1291       {unknown, DT_INVALID},
1292       {ctx->MakeShape({4, 3, 2, 1}), DT_INT32}};
1293 
1294   std::vector<shape_inference::ShapeAndType> t1{
1295       {ctx->MakeShape({1, 2, 3}), DT_FLOAT},
1296       {unknown, DT_INVALID},
1297       {ctx->MakeShape({4, 3, 2, 1}), DT_INT32}};
1298 
1299   std::vector<shape_inference::ShapeAndType> t2{
1300       {ctx->MakeShape({1, 2, 4}), DT_FLOAT},
1301       {ctx->UnknownShape(), DT_INVALID},
1302       {ctx->MakeShape({4, 3, 2, 1}), DT_INT32}};
1303 
1304   std::vector<shape_inference::ShapeAndType> t3{
1305       {ctx->MakeShape({1, 2, 3}), DT_INT32},
1306       {ctx->UnknownShape(), DT_INVALID},
1307       {ctx->MakeShape({4, 3, 2, 1}), DT_INT32}};
1308 
1309   EXPECT_FALSE(IsUpdatedShapesOrTypes(ctx, t0, t1));
1310 
1311   // A shape has been modified
1312   EXPECT_TRUE(IsUpdatedShapesOrTypes(ctx, t0, t2));
1313 
1314   // A type has been modified
1315   EXPECT_TRUE(IsUpdatedShapesOrTypes(ctx, t0, t3));
1316 }
1317 
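// Exercises ShapeRefiner::UpdateNode: handle shapes injected on the queue are
// incrementally merged (relax=false) or relaxed (relax=true) into the shapes
// reported for the dequeue op.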
1318 TEST_F(ShapeRefinerTest, IncrementalUpdates) {
1319   Scope root = Scope::NewRootScope();
1320   Graph* g = root.graph();
1321   Node* queue;
1322   TF_CHECK_OK(NodeBuilder("queue", "FIFOQueueV2")
1323                   .Attr("component_types", {DT_FLOAT})
1324                   .Finalize(g, &queue));
1325   Node* dequeue;
1326   TF_CHECK_OK(NodeBuilder("dequeue", "QueueDequeueV2")
1327                   .Attr("component_types", {DT_FLOAT})
1328                   .Input(queue)
1329                   .Finalize(g, &dequeue));
1330   ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
1331   TF_ASSERT_OK(m.AddNode(queue));
1332   TF_ASSERT_OK(m.AddNode(dequeue));
1333 
1334   // At this point, the shape of the dequeued tensor is unknown.
1335   shape_inference::InferenceContext* ctx = m.GetContext(dequeue);
1336   EXPECT_EQ("?", ctx->DebugString(ctx->output(0)));
1337 
1338   // Inject a shape, and incrementally propagate it to the dequeue op.
1339   ctx = m.GetContext(queue);
1340   shape_inference::ShapeHandle shp = ctx->MakeShape({3, 7});
1341   ctx->set_output_handle_shapes_and_types(
1342       0, std::vector<shape_inference::ShapeAndType>{{shp, DT_FLOAT}});
1343   bool refined = false;
1344   TF_ASSERT_OK(m.UpdateNode(dequeue, false /* relax */, &refined));
1345   EXPECT_TRUE(refined);
1346   ctx = m.GetContext(dequeue);
1347   EXPECT_EQ("[3,7]", ctx->DebugString(ctx->output(0)));
1348 
1349   // Inject another shape, but relax instead of merge.
1350   ctx = m.GetContext(queue);
1351   shp = ctx->MakeShape({2, 7});
1352   ctx->set_output_handle_shapes_and_types(
1353       0, std::vector<shape_inference::ShapeAndType>{{shp, DT_FLOAT}});
1354   refined = false;
1355   TF_ASSERT_OK(m.UpdateNode(dequeue, true /* relax */, &refined));
1356   EXPECT_TRUE(refined);
1357   ctx = m.GetContext(dequeue);
1358   EXPECT_EQ("[?,7]", ctx->DebugString(ctx->output(0)));
1359 
1360   // Inject another partially unknown shape and attempt to relax it.
1361   ctx = m.GetContext(queue);
1362   shp = ctx->MakeShape({shape_inference::InferenceContext::kUnknownDim, 7});
1363   ctx->set_output_handle_shapes_and_types(
1364       0, std::vector<shape_inference::ShapeAndType>{{shp, DT_FLOAT}});
1365   refined = false;
1366   TF_ASSERT_OK(m.UpdateNode(dequeue, true /* relax */, &refined));
1367   EXPECT_TRUE(refined);
1368   ctx = m.GetContext(dequeue);
1369   EXPECT_EQ("[?,7]", ctx->DebugString(ctx->output(0)));
1370   EXPECT_TRUE(SameHandle(ctx->Dim(ctx->output(0), 0), ctx->Dim(shp, 0)));
1371 
1372   // Inject a shape of the same handle and expect refined to not change.
1373   ctx = m.GetContext(queue);
1374   shape_inference::ShapeHandle shp2 = shp;
1375   ctx->set_output_handle_shapes_and_types(
1376       0, std::vector<shape_inference::ShapeAndType>{{shp2, DT_FLOAT}});
1377   refined = false;
1378   TF_ASSERT_OK(m.UpdateNode(dequeue, /*relax=*/false, &refined));
1379   EXPECT_FALSE(refined);
1380   EXPECT_TRUE(SameHandle(ctx->Dim(shp, 0), ctx->Dim(shp2, 0)));
1381 }
1382 
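// Runs the XTimesTwo function through the refiner and checks that the call
// node's output shape is inferred only when function shape inference is
// enabled on the refiner.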
1383 void TestSimpleFunctionInference(bool enable_function_inference) {
1384   FunctionDefLibrary f_lib_proto;
1385   *(f_lib_proto.add_function()) = test::function::XTimesTwo();
1386   FunctionLibraryDefinition f_lib(OpRegistry::Global(), f_lib_proto);
1387 
1388   Scope root = Scope::NewRootScope();
1389   TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
1390   auto x = ops::Const(root, {{1.0f, 2.0f}});
1391   auto x2 = test::function::Call(&root, "x2", "XTimesTwo", {x});
1392 
1393   ShapeRefiner m(TF_GRAPH_DEF_VERSION, &f_lib);
1394   if (enable_function_inference) {
1395     m.set_function_library_for_shape_inference(&f_lib);
1396   }
1397 
1398   TF_ASSERT_OK(m.AddNode(x.node()));
1399   TF_ASSERT_OK(m.AddNode(x2.node()));
1400 
1401   EXPECT_SHAPE("[1,2]", m, x, 0);
1402 
1403   if (enable_function_inference) {
1404     EXPECT_SHAPE("[1,2]", m, x2, 0);
1405   } else {
1406     // Default inference behavior: function output shapes are unknown.
1407     EXPECT_SHAPE("?", m, x2, 0);
1408   }
1409 }
1410 
1411 TEST_F(ShapeRefinerTest, SimpleFunctionShapeInference_Disabled) {
1412   // The nesting flag doesn't matter when function inference is disabled.
1413   TestSimpleFunctionInference(false /* enable_function_inference */);
1414 }
1415 
1416 TEST_F(ShapeRefinerTest, SimpleFunctionShapeInference) {
1417   TestSimpleFunctionInference(true /* enable_function_inference */);
1418 }
1419 
1420 TEST_F(ShapeRefinerTest, FunctionShapeInferenceFallback) {
1421   // Test that function inference falls back to returning unknown shapes,
1422   // if the function lookup fails.
1423 
1424   FunctionDefLibrary f_lib_proto;
1425   *(f_lib_proto.add_function()) = test::function::XTimesTwo();
1426   FunctionLibraryDefinition f_lib(OpRegistry::Global(), f_lib_proto);
1427 
1428   Scope root = Scope::NewRootScope();
1429   TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
1430   auto x = ops::Const(root, {{.0f, .0f}});
1431   auto x2 = test::function::Call(&root, "x2", "XTimesTwo", {x});
1432 
1433   FunctionDefLibrary empty_f_lib_proto;
1434   FunctionLibraryDefinition empty_f_lib(OpRegistry::Global(),
1435                                         empty_f_lib_proto);
1436 
1437   ShapeRefiner m(TF_GRAPH_DEF_VERSION, &f_lib);
1438   m.set_function_library_for_shape_inference(&empty_f_lib);
1439 
1440   TF_ASSERT_OK(m.AddNode(x.node()));
1441   TF_ASSERT_OK(m.AddNode(x2.node()));
1442 
1443   EXPECT_SHAPE("[1,2]", m, x, 0);
1444 
1445   // Default inference behavior: function output shapes are unknown.
1446   EXPECT_SHAPE("?", m, x2, 0);
1447 }
1448 
1449 TEST_F(ShapeRefinerTest, ChainedFunctionShapeInferenceWithMultipleInputs) {
1450   FunctionDefLibrary f_lib_proto;
1451   *(f_lib_proto.add_function()) = test::function::XTimesTwo();
1452   *(f_lib_proto.add_function()) = test::function::XTimesFour();
1453   *(f_lib_proto.add_function()) = test::function::XTimes16();
1454   *(f_lib_proto.add_function()) = test::function::WXPlusB();
1455   FunctionLibraryDefinition f_lib(OpRegistry::Global(), f_lib_proto);
1456 
1457   Scope root = Scope::NewRootScope();
1458   TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
1459   auto w = ops::Const(root, {{.0f}, {.0f}, {.0f}});
1460   auto x = ops::Const(root, {{.0f, .0f, .0f}});
1461   auto b = ops::Const(root, {{.0f}});
1462 
1463   auto wxplusb = test::function::Call(&root, "wxplusb", "WXPlusB", {w, x, b});
1464   auto wxplusb16 =
1465       test::function::Call(&root, "wxplusb16", "XTimes16", {wxplusb});
1466 
1467   ShapeRefiner m(TF_GRAPH_DEF_VERSION, &f_lib);
1468   m.set_function_library_for_shape_inference(&f_lib);
1469 
1470   TF_ASSERT_OK(m.AddNode(w.node()));
1471   TF_ASSERT_OK(m.AddNode(x.node()));
1472   TF_ASSERT_OK(m.AddNode(b.node()));
1473   TF_ASSERT_OK(m.AddNode(wxplusb.node()));
1474   TF_ASSERT_OK(m.AddNode(wxplusb16.node()));
1475 
1476   EXPECT_SHAPE("[3,1]", m, w, 0);
1477   EXPECT_SHAPE("[1,3]", m, x, 0);
1478   EXPECT_SHAPE("[1,1]", m, b, 0);
1479   EXPECT_SHAPE("[3,3]", m, wxplusb, 0);
1480   EXPECT_SHAPE("[3,3]", m, wxplusb16, 0);
1481 }
1482 
1483 TEST_F(ShapeRefinerTest, FunctionShapeInferenceWorksForResourceHandles) {
1484   FunctionDefLibrary f_lib_proto;
1485   *(f_lib_proto.add_function()) = test::function::Swap();
1486 
1487   FunctionLibraryDefinition f_lib(OpRegistry::Global(), f_lib_proto);
1488 
1489   Scope root = Scope::NewRootScope().ExitOnError();
1490   TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
1491 
1492   auto x1 = ops::VarHandleOp(root, DataType::DT_FLOAT, TensorShape({128, 256}));
1493   auto x2 = ops::VarHandleOp(root, DataType::DT_DOUBLE, TensorShape({1024}));
1494   auto swap = test::function::Call(&root, "swap", "Swap", {x1, x2});
1495 
1496   EXPECT_EQ(swap.node()->num_outputs(), 2);
1497 
1498   ShapeRefiner m(TF_GRAPH_DEF_VERSION, &f_lib);
1499   m.set_function_library_for_shape_inference(&f_lib);
1500 
1501   TF_ASSERT_OK(m.AddNode(x1.node()));
1502   TF_ASSERT_OK(m.AddNode(x2.node()));
1503   TF_ASSERT_OK(m.AddNode(swap.node()));
1504 
1505   EXPECT_EQ(m.GetContext(swap.node())->num_outputs(), 2);
1506 
1507   EXPECT_RESOURCE_SINGLE_SHAPE("[128,256]", m, x1, 0);
1508   EXPECT_RESOURCE_SINGLE_SHAPE("[1024]", m, x2, 0);
1509   EXPECT_RESOURCE_SINGLE_SHAPE("[1024]", m, swap, 0);
1510   EXPECT_RESOURCE_SINGLE_SHAPE("[128,256]", m, swap, 1);
1511   EXPECT_RESOURCE_SINGLE_TYPE(DataType::DT_DOUBLE, m, swap, 0);
1512   EXPECT_RESOURCE_SINGLE_TYPE(DataType::DT_FLOAT, m, swap, 1);
1513 }
1514 
1515 }  // namespace
1516 }  // namespace tensorflow
1517