• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/tools/freeze_saved_model.h"
17 
18 #include "tensorflow/cc/ops/resource_variable_ops.h"
19 #include "tensorflow/cc/ops/standard_ops.h"
20 #include "tensorflow/core/framework/function_testlib.h"
21 #include "tensorflow/core/framework/graph.pb.h"
22 #include "tensorflow/core/framework/tensor_testutil.h"
23 #include "tensorflow/core/framework/versions.pb.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/public/session.h"
27 #include "tensorflow/core/public/session_options.h"
28 
29 namespace tensorflow {
30 namespace {
31 
32 class FreezeTest : public ::testing::Test {
33  protected:
GraphDefEqual(const GraphDef & actual,const GraphDef & expected)34   void GraphDefEqual(const GraphDef& actual, const GraphDef& expected) {
35     EXPECT_EQ(actual.ShortDebugString(), expected.ShortDebugString());
36   }
37 
38   // Builds a SignatureDef with the provided `inputs` and `outputs`.
BuildSignatureDef(const std::unordered_set<string> & inputs,const std::unordered_set<string> & outputs)39   SignatureDef BuildSignatureDef(const std::unordered_set<string>& inputs,
40                                  const std::unordered_set<string>& outputs) {
41     SignatureDef signature_def;
42     for (const string& input : inputs) {
43       (*signature_def.mutable_inputs())[input].set_name(input);
44     }
45     for (const string& output : outputs) {
46       (*signature_def.mutable_outputs())[output].set_name(output);
47     }
48     return signature_def;
49   }
50 
51   // Adds `signature_def` to `saved_model_bundle` under `key`.
AddSignatureDefToSavedModelBundle(const SignatureDef & signature_def,const string & key,SavedModelBundle * saved_model_bundle)52   void AddSignatureDefToSavedModelBundle(const SignatureDef& signature_def,
53                                          const string& key,
54                                          SavedModelBundle* saved_model_bundle) {
55     MetaGraphDef* meta_graph_def = &saved_model_bundle->meta_graph_def;
56     (*meta_graph_def->mutable_signature_def())[key] = signature_def;
57   }
58 
59   // Adds an initialized session to `saved_model_bundle` using `graph_def` and
60   // initializing with `init_node`.
InitializeSavedModelBundleSession(const GraphDef & graph_def,const string & init_node,SavedModelBundle * saved_model_bundle)61   Status InitializeSavedModelBundleSession(
62       const GraphDef& graph_def, const string& init_node,
63       SavedModelBundle* saved_model_bundle) {
64     SessionOptions session_options;
65     saved_model_bundle->session.reset(NewSession(session_options));
66     TF_RETURN_IF_ERROR(saved_model_bundle->session->Create(graph_def));
67     if (!init_node.empty()) {
68       std::vector<Tensor> outputs;
69       return saved_model_bundle->session->Run(
70           /* inputs */ {}, /* output_tensors */ {}, {init_node}, &outputs);
71     }
72     return Status::OK();
73   }
74 
75   // Adds `graph_def` to `saved_model_bundle` and initializes a session with
76   // `init_node`.
AddGraphDefToSavedModelBundle(const GraphDef & graph_def,const string & init_node,SavedModelBundle * saved_model_bundle)77   Status AddGraphDefToSavedModelBundle(const GraphDef& graph_def,
78                                        const string& init_node,
79                                        SavedModelBundle* saved_model_bundle) {
80     MetaGraphDef* meta_graph_def = &saved_model_bundle->meta_graph_def;
81     *meta_graph_def->mutable_graph_def() = graph_def;
82     return InitializeSavedModelBundleSession(graph_def, init_node,
83                                              saved_model_bundle);
84   }
85 
86   // Adds `graph_def` and `outputs` as the GraphDef and SignatureDef in
87   // `saved_model_bundle` and initializes a session with `init_node`.
AddGraphDefWithOutputsToSavedModelBundle(const GraphDef & graph_def,const std::unordered_set<string> & outputs,const string & init_node,SavedModelBundle * saved_model_bundle)88   Status AddGraphDefWithOutputsToSavedModelBundle(
89       const GraphDef& graph_def, const std::unordered_set<string>& outputs,
90       const string& init_node, SavedModelBundle* saved_model_bundle) {
91     SignatureDef signature_def =
92         BuildSignatureDef(std::unordered_set<string>(), outputs);
93     AddSignatureDefToSavedModelBundle(signature_def, "signature_def",
94                                       saved_model_bundle);
95     return AddGraphDefToSavedModelBundle(graph_def, init_node,
96                                          saved_model_bundle);
97   }
98 
99   // Runs and compares the outputs of `tensor_name` on both the
100   // `unfrozen_session` and the `frozen_graph_def.
RunAndCompareFrozenAndUnfrozenGraphs(Session * unfrozen_session,const GraphDef & frozen_graph_def,const string & tensor_name)101   void RunAndCompareFrozenAndUnfrozenGraphs(Session* unfrozen_session,
102                                             const GraphDef& frozen_graph_def,
103                                             const string& tensor_name) {
104     std::vector<Tensor> unfrozen_outputs;
105     TF_ASSERT_OK(unfrozen_session->Run(/* inputs */ {}, {tensor_name},
106                                        /* targets */ {}, &unfrozen_outputs));
107 
108     SessionOptions session_options;
109     std::unique_ptr<Session> frozen_session(NewSession(session_options));
110     TF_ASSERT_OK(frozen_session->Create(frozen_graph_def));
111     std::vector<Tensor> frozen_outputs;
112     TF_ASSERT_OK(frozen_session->Run(/* inputs */ {}, {tensor_name},
113                                      /* targets */ {}, &frozen_outputs));
114 
115     test::ExpectTensorEqual<float>(unfrozen_outputs[0], frozen_outputs[0]);
116   }
117 
TestFreezeGraphWithoutDependentVariables(bool use_resource)118   void TestFreezeGraphWithoutDependentVariables(bool use_resource) {
119     // Test freezing a graph with variables that are not needed by the outputs
120     // in the SignatureDef. The resulting graph shouldn't be frozen, but
121     // non-dependent nodes should be pruned.
122     SavedModelBundle saved_model_bundle;
123     GraphDef graph_def;
124     Scope scope = Scope::NewRootScope();
125     Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
126     Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
127     Output c = ops::Mul(scope.WithOpName("c"), a, b);
128     if (use_resource) {
129       Output var =
130           ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {});
131       Output read_var = ops::ReadVariableOp(
132           scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT);
133       auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a);
134     } else {
135       Output var =
136           ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
137       Output assign = ops::Assign(scope.WithOpName("assign"), var, a);
138     }
139 
140     TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
141     // "c" isn't dependent on the variable, so nothing should be frozen.
142     TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
143         graph_def, {"c:0"}, "assign", &saved_model_bundle));
144 
145     GraphDef frozen_graph_def;
146     std::unordered_set<string> inputs;
147     std::unordered_set<string> outputs;
148     TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def,
149                                   &inputs, &outputs));
150 
151     GraphDef expected_graph_def;
152     Scope expected_scope = Scope::NewRootScope();
153     Output expected_a = ops::Const(expected_scope.WithOpName("a"), 10.0f, {});
154     Output expected_b = ops::Const(expected_scope.WithOpName("b"), 10.0f, {});
155     Output expected_c =
156         ops::Mul(expected_scope.WithOpName("c"), expected_a, expected_b);
157     TF_ASSERT_OK(expected_scope.ToGraphDef(&expected_graph_def));
158 
159     GraphDefEqual(frozen_graph_def, expected_graph_def);
160 
161     RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
162                                          frozen_graph_def, "c:0");
163   }
164 
TestFreezeGraphWithDependentVariables(bool use_resource)165   void TestFreezeGraphWithDependentVariables(bool use_resource) {
166     // Test freezing a graph with variables that are needed by outputs in the
167     // SignatureDef. The variables should be frozen.
168     SavedModelBundle saved_model_bundle;
169     GraphDef graph_def;
170     Scope scope = Scope::NewRootScope();
171     Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
172     Output read_var;
173     if (use_resource) {
174       Output var =
175           ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {});
176       read_var = ops::ReadVariableOp(
177           scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT);
178       auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a);
179     } else {
180       Output read_var =
181           ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
182       Output assign = ops::Assign(scope.WithOpName("assign"), read_var, a);
183     }
184     Output c = ops::Mul(scope.WithOpName("c"), a, read_var);
185     TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
186     // "c" isn't dependent on the variable, so nothing should be frozen.
187     TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
188         graph_def, {"c:0"}, "assign", &saved_model_bundle));
189 
190     GraphDef frozen_graph_def;
191     std::unordered_set<string> inputs;
192     std::unordered_set<string> outputs;
193     TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def,
194                                   &inputs, &outputs));
195 
196     // If using normal variables there should be 3 nodes in the resulting
197     // graph_def. If using resource variables there should be 4 nodes in the
198     // resulting graph_def.
199     // In both cases, none should be variables.
200     size_t expected_nodes = use_resource ? 4 : 3;
201     EXPECT_EQ(frozen_graph_def.node_size(), expected_nodes);
202     for (const NodeDef& node : frozen_graph_def.node()) {
203       EXPECT_NE(node.op(), "Variable") << node.name();
204       EXPECT_NE(node.op(), "VariableV2") << node.name();
205       EXPECT_NE(node.op(), "VarHandleOp") << node.name();
206       EXPECT_NE(node.op(), "ReadVariableOp") << node.name();
207     }
208 
209     RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
210                                          frozen_graph_def, "c:0");
211   }
212 
TestFreezeGraphWithAndWithoutDependentVariables(bool use_resource)213   void TestFreezeGraphWithAndWithoutDependentVariables(bool use_resource) {
214     // Test freezing a graph with some variables that are needed and not needed
215     // by
216     // the outputs in the SignatureDef. The resulting graph should only freeze
217     // dependent variables.
218     SavedModelBundle saved_model_bundle;
219     GraphDef graph_def;
220     Scope scope = Scope::NewRootScope();
221     Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
222     Output read_var;
223 
224     if (use_resource) {
225       Output var =
226           ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {});
227       read_var = ops::ReadVariableOp(
228           scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT);
229       auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a);
230       Output var_1 =
231           ops::VarHandleOp(scope.WithOpName("var_1"), DataType::DT_FLOAT, {});
232       Output read_var_1 =
233           ops::ReadVariableOp(scope.WithOpName("var_1/Read/ReadVariableOp"),
234                               var, DataType::DT_FLOAT);
235       auto assign_1 =
236           ops::AssignVariableOp(scope.WithOpName("assign_1"), var_1, a);
237     } else {
238       read_var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
239       Output assign = ops::Assign(scope.WithOpName("assign"), read_var, a);
240       Output var_1 =
241           ops::Variable(scope.WithOpName("var_1"), {}, DataType::DT_FLOAT);
242       Output assign_1 = ops::Assign(scope.WithOpName("assign_1"), var_1, a);
243     }
244 
245     Output c = ops::Mul(scope.WithOpName("c"), a, read_var);
246     TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
247     // "c" isn't dependent on the variable, so nothing should be frozen.
248     TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
249         graph_def, {"c:0"}, "assign", &saved_model_bundle));
250 
251     GraphDef frozen_graph_def;
252     std::unordered_set<string> inputs;
253     std::unordered_set<string> outputs;
254     TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def,
255                                   &inputs, &outputs));
256 
257     // There should be 3 nodes in the resulting graph_def, and none should be
258     // variables.
259     size_t expected_nodes = use_resource ? 4 : 3;
260     EXPECT_EQ(frozen_graph_def.node_size(), expected_nodes);
261     for (const NodeDef& node : frozen_graph_def.node()) {
262       EXPECT_NE(node.op(), "Variable") << node.name();
263       EXPECT_NE(node.op(), "VariableV2") << node.name();
264       EXPECT_NE(node.op(), "VarHandleOp") << node.name();
265       EXPECT_NE(node.op(), "ReadVariableOp") << node.name();
266     }
267 
268     RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
269                                          frozen_graph_def, "c:0");
270   }
271 };
272 
TEST_F(FreezeTest, InputsAndOutputsSingleSignatureDef) {
  // With exactly one SignatureDef in the bundle, the input and output tensor
  // names reported by FreezeSavedModel must match that SignatureDef.
  const std::unordered_set<string> want_inputs = {"input0:0", "input1:0"};
  const std::unordered_set<string> want_outputs = {"output0:0", "output1:0"};

  SavedModelBundle bundle;
  AddSignatureDefToSavedModelBundle(
      BuildSignatureDef(want_inputs, want_outputs), "signature_def", &bundle);

  GraphDef frozen;
  std::unordered_set<string> got_inputs;
  std::unordered_set<string> got_outputs;
  TF_ASSERT_OK(FreezeSavedModel(bundle, &frozen, &got_inputs, &got_outputs));

  EXPECT_EQ(want_inputs, got_inputs);
  EXPECT_EQ(want_outputs, got_outputs);
}
291 
TEST_F(FreezeTest, InputsAndOutputsMultipleSignatureDefs) {
  // With several SignatureDefs in the bundle, FreezeSavedModel must report the
  // merged set of input names and the merged set of output names.
  SavedModelBundle bundle;
  AddSignatureDefToSavedModelBundle(
      BuildSignatureDef({"input0:0"}, {"output0:0"}), "signature_def_0",
      &bundle);
  AddSignatureDefToSavedModelBundle(
      BuildSignatureDef({"input1:0"}, {"output1:0"}), "signature_def_1",
      &bundle);

  GraphDef frozen;
  std::unordered_set<string> got_inputs;
  std::unordered_set<string> got_outputs;
  TF_ASSERT_OK(FreezeSavedModel(bundle, &frozen, &got_inputs, &got_outputs));

  const std::unordered_set<string> want_inputs = {"input0:0", "input1:0"};
  const std::unordered_set<string> want_outputs = {"output0:0", "output1:0"};
  EXPECT_EQ(want_inputs, got_inputs);
  EXPECT_EQ(want_outputs, got_outputs);
}
312 
TEST_F(FreezeTest, GraphDefVersionsAndLibrary) {
  // The frozen graph must carry over the source GraphDef's version fields and
  // its function library.
  GraphDef source_graph;
  VersionDef* versions = source_graph.mutable_versions();
  versions->set_producer(1234);
  versions->set_min_consumer(1234);
  *source_graph.mutable_library()->add_function() = test::function::NonZero();

  SavedModelBundle bundle;
  // An empty init node means no initialization target is run.
  TF_ASSERT_OK(AddGraphDefToSavedModelBundle(source_graph, "", &bundle));

  GraphDef frozen;
  std::unordered_set<string> got_inputs;
  std::unordered_set<string> got_outputs;
  TF_ASSERT_OK(FreezeSavedModel(bundle, &frozen, &got_inputs, &got_outputs));

  GraphDefEqual(frozen, source_graph);
}
332 
TEST_F(FreezeTest, GraphDefWithNoVariables) {
  // A graph made only of constants and a Mul has nothing to freeze, so the
  // frozen graph is expected to equal the source graph.
  Scope root = Scope::NewRootScope();
  Output lhs = ops::Const(root.WithOpName("a"), 10.0f, {});
  Output rhs = ops::Const(root.WithOpName("b"), 10.0f, {});
  Output product = ops::Mul(root.WithOpName("c"), lhs, rhs);

  GraphDef source_graph;
  TF_ASSERT_OK(root.ToGraphDef(&source_graph));

  SavedModelBundle bundle;
  TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(source_graph, {"c:0"},
                                                        "", &bundle));

  GraphDef frozen;
  std::unordered_set<string> got_inputs;
  std::unordered_set<string> got_outputs;
  TF_ASSERT_OK(FreezeSavedModel(bundle, &frozen, &got_inputs, &got_outputs));

  GraphDefEqual(frozen, source_graph);
}
353 
TEST_F(FreezeTest, GraphDefWithMultiOutputOperation) {
  // Consumers of a multi-output op reference its tensors with an index suffix
  // (split:0, split:1). Check that freezing follows such inputs correctly by
  // feeding the second Split output into the signature's output node.
  Scope root = Scope::NewRootScope();
  Output values = ops::Const(root.WithOpName("a"), {10.0f, 10.0f}, {2});
  Output split_axis = ops::Const(root.WithOpName("axis"), 0, {});
  OutputList pieces =
      ops::Split(root.WithOpName("split"), split_axis, values, 2).output;
  Output factor = ops::Const(root.WithOpName("b"), 10.0f, {});
  Output product = ops::Mul(root.WithOpName("c"), pieces[1], factor);

  GraphDef source_graph;
  TF_ASSERT_OK(root.ToGraphDef(&source_graph));

  SavedModelBundle bundle;
  TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(source_graph, {"c:0"},
                                                        "", &bundle));

  GraphDef frozen;
  std::unordered_set<string> got_inputs;
  std::unordered_set<string> got_outputs;
  TF_ASSERT_OK(FreezeSavedModel(bundle, &frozen, &got_inputs, &got_outputs));

  GraphDefEqual(frozen, source_graph);
}
378 
TEST_F(FreezeTest, GraphDefWithControlDependency) {
  // Control-dependency inputs are written with a caret prefix
  // (^control_dependency). Check that freezing follows such inputs correctly
  // by making "a" control-dependent on "source".
  Scope root = Scope::NewRootScope();
  Output source = ops::Const(root.WithOpName("source"), 10.0f, {});
  Output lhs = ops::Const(root.WithOpName("a").WithControlDependencies(source),
                          {10.0f, 10.0f}, {2});
  Output rhs = ops::Const(root.WithOpName("b"), 10.0f, {});
  Output product = ops::Mul(root.WithOpName("c"), lhs, rhs);

  GraphDef source_graph;
  TF_ASSERT_OK(root.ToGraphDef(&source_graph));

  SavedModelBundle bundle;
  TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(source_graph, {"c:0"},
                                                        "", &bundle));

  GraphDef frozen;
  std::unordered_set<string> got_inputs;
  std::unordered_set<string> got_outputs;
  TF_ASSERT_OK(FreezeSavedModel(bundle, &frozen, &got_inputs, &got_outputs));

  GraphDefEqual(frozen, source_graph);
}
403 
// Output does not depend on the variable; graph built with ops::Variable.
TEST_F(FreezeTest, GraphDefWithoutDependentVariables) {
  TestFreezeGraphWithoutDependentVariables(/*use_resource=*/false);
}
407 
// Output does not depend on the variable; graph built with resource variable
// ops.
TEST_F(FreezeTest, GraphDefWithoutDependentResourceVariables) {
  TestFreezeGraphWithoutDependentVariables(/*use_resource=*/true);
}
411 
// Output depends on the variable; graph built with ops::Variable.
TEST_F(FreezeTest, GraphDefWithDependentVariables) {
  TestFreezeGraphWithDependentVariables(/*use_resource=*/false);
}
415 
// Output depends on the variable; graph built with resource variable ops.
TEST_F(FreezeTest, GraphDefWithDependentResourceVariables) {
  TestFreezeGraphWithDependentVariables(/*use_resource=*/true);
}
419 
// One variable feeds the output and one does not; graph built with
// ops::Variable.
TEST_F(FreezeTest, GraphDefWithAndWithoutDependentVariables) {
  TestFreezeGraphWithAndWithoutDependentVariables(/*use_resource=*/false);
}
423 
// One variable feeds the output and one does not; graph built with resource
// variable ops.
TEST_F(FreezeTest, GraphDefWithAndWithoutDependentResourceVariables) {
  TestFreezeGraphWithAndWithoutDependentVariables(/*use_resource=*/true);
}
427 
TEST_F(FreezeTest, InputsAndOutputsCompositeTensorSignatureDef) {
  // When a SignatureDef entry is a composite tensor, each of its component
  // tensor names must show up in the reported inputs / outputs.
  SignatureDef signature_def;

  auto* input_composite =
      (*signature_def.mutable_inputs())["input_arg"].mutable_composite_tensor();
  input_composite->add_components()->set_name("input1:0");
  input_composite->add_components()->set_name("input2:0");

  auto* output_composite = (*signature_def.mutable_outputs())["output_arg"]
                               .mutable_composite_tensor();
  output_composite->add_components()->set_name("output2:0");
  output_composite->add_components()->set_name("output1:0");

  SavedModelBundle bundle;
  AddSignatureDefToSavedModelBundle(signature_def, "signature_def", &bundle);

  GraphDef frozen;
  std::unordered_set<string> got_inputs;
  std::unordered_set<string> got_outputs;
  TF_ASSERT_OK(FreezeSavedModel(bundle, &frozen, &got_inputs, &got_outputs));

  const std::unordered_set<string> want_inputs = {"input1:0", "input2:0"};
  const std::unordered_set<string> want_outputs = {"output1:0", "output2:0"};
  EXPECT_EQ(want_inputs, got_inputs);
  EXPECT_EQ(want_outputs, got_outputs);
}
454 
TEST_F(FreezeTest, InputsAndOutputsSparseCooSignatureDef) {
  // When a SignatureDef entry is a COO sparse tensor, its values, indices and
  // dense-shape tensor names must all show up in the reported inputs /
  // outputs.
  SignatureDef signature_def;

  auto* input_sparse =
      (*signature_def.mutable_inputs())["input_arg"].mutable_coo_sparse();
  input_sparse->set_values_tensor_name("input1:0");
  input_sparse->set_indices_tensor_name("input2:0");
  input_sparse->set_dense_shape_tensor_name("input3:0");

  auto* output_sparse =
      (*signature_def.mutable_outputs())["output_arg"].mutable_coo_sparse();
  output_sparse->set_values_tensor_name("output1:0");
  output_sparse->set_indices_tensor_name("output2:0");
  output_sparse->set_dense_shape_tensor_name("output3:0");

  SavedModelBundle bundle;
  AddSignatureDefToSavedModelBundle(signature_def, "signature_def", &bundle);

  GraphDef frozen;
  std::unordered_set<string> got_inputs;
  std::unordered_set<string> got_outputs;
  TF_ASSERT_OK(FreezeSavedModel(bundle, &frozen, &got_inputs, &got_outputs));

  const std::unordered_set<string> want_inputs = {"input1:0", "input2:0",
                                                  "input3:0"};
  const std::unordered_set<string> want_outputs = {"output1:0", "output2:0",
                                                   "output3:0"};
  EXPECT_EQ(want_inputs, got_inputs);
  EXPECT_EQ(want_outputs, got_outputs);
}
485 
486 }  // namespace
487 }  // namespace tensorflow
488