1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/tools/freeze_saved_model.h"
17
18 #include "tensorflow/cc/ops/resource_variable_ops.h"
19 #include "tensorflow/cc/ops/standard_ops.h"
20 #include "tensorflow/core/framework/function_testlib.h"
21 #include "tensorflow/core/framework/graph.pb.h"
22 #include "tensorflow/core/framework/tensor_testutil.h"
23 #include "tensorflow/core/framework/versions.pb.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/public/session.h"
27 #include "tensorflow/core/public/session_options.h"
28
29 namespace tensorflow {
30 namespace {
31
32 class FreezeTest : public ::testing::Test {
33 protected:
GraphDefEqual(const GraphDef & actual,const GraphDef & expected)34 void GraphDefEqual(const GraphDef& actual, const GraphDef& expected) {
35 EXPECT_EQ(actual.ShortDebugString(), expected.ShortDebugString());
36 }
37
38 // Builds a SignatureDef with the provided `inputs` and `outputs`.
BuildSignatureDef(const std::unordered_set<string> & inputs,const std::unordered_set<string> & outputs)39 SignatureDef BuildSignatureDef(const std::unordered_set<string>& inputs,
40 const std::unordered_set<string>& outputs) {
41 SignatureDef signature_def;
42 for (const string& input : inputs) {
43 (*signature_def.mutable_inputs())[input].set_name(input);
44 }
45 for (const string& output : outputs) {
46 (*signature_def.mutable_outputs())[output].set_name(output);
47 }
48 return signature_def;
49 }
50
51 // Adds `signature_def` to `saved_model_bundle` under `key`.
AddSignatureDefToSavedModelBundle(const SignatureDef & signature_def,const string & key,SavedModelBundle * saved_model_bundle)52 void AddSignatureDefToSavedModelBundle(const SignatureDef& signature_def,
53 const string& key,
54 SavedModelBundle* saved_model_bundle) {
55 MetaGraphDef* meta_graph_def = &saved_model_bundle->meta_graph_def;
56 (*meta_graph_def->mutable_signature_def())[key] = signature_def;
57 }
58
59 // Adds an initialized session to `saved_model_bundle` using `graph_def` and
60 // initializing with `init_node`.
InitializeSavedModelBundleSession(const GraphDef & graph_def,const string & init_node,SavedModelBundle * saved_model_bundle)61 Status InitializeSavedModelBundleSession(
62 const GraphDef& graph_def, const string& init_node,
63 SavedModelBundle* saved_model_bundle) {
64 SessionOptions session_options;
65 saved_model_bundle->session.reset(NewSession(session_options));
66 TF_RETURN_IF_ERROR(saved_model_bundle->session->Create(graph_def));
67 if (!init_node.empty()) {
68 std::vector<Tensor> outputs;
69 return saved_model_bundle->session->Run(
70 /* inputs */ {}, /* output_tensors */ {}, {init_node}, &outputs);
71 }
72 return Status::OK();
73 }
74
75 // Adds `graph_def` to `saved_model_bundle` and initializes a session with
76 // `init_node`.
AddGraphDefToSavedModelBundle(const GraphDef & graph_def,const string & init_node,SavedModelBundle * saved_model_bundle)77 Status AddGraphDefToSavedModelBundle(const GraphDef& graph_def,
78 const string& init_node,
79 SavedModelBundle* saved_model_bundle) {
80 MetaGraphDef* meta_graph_def = &saved_model_bundle->meta_graph_def;
81 *meta_graph_def->mutable_graph_def() = graph_def;
82 return InitializeSavedModelBundleSession(graph_def, init_node,
83 saved_model_bundle);
84 }
85
86 // Adds `graph_def` and `outputs` as the GraphDef and SignatureDef in
87 // `saved_model_bundle` and initializes a session with `init_node`.
AddGraphDefWithOutputsToSavedModelBundle(const GraphDef & graph_def,const std::unordered_set<string> & outputs,const string & init_node,SavedModelBundle * saved_model_bundle)88 Status AddGraphDefWithOutputsToSavedModelBundle(
89 const GraphDef& graph_def, const std::unordered_set<string>& outputs,
90 const string& init_node, SavedModelBundle* saved_model_bundle) {
91 SignatureDef signature_def =
92 BuildSignatureDef(std::unordered_set<string>(), outputs);
93 AddSignatureDefToSavedModelBundle(signature_def, "signature_def",
94 saved_model_bundle);
95 return AddGraphDefToSavedModelBundle(graph_def, init_node,
96 saved_model_bundle);
97 }
98
99 // Runs and compares the outputs of `tensor_name` on both the
100 // `unfrozen_session` and the `frozen_graph_def.
RunAndCompareFrozenAndUnfrozenGraphs(Session * unfrozen_session,const GraphDef & frozen_graph_def,const string & tensor_name)101 void RunAndCompareFrozenAndUnfrozenGraphs(Session* unfrozen_session,
102 const GraphDef& frozen_graph_def,
103 const string& tensor_name) {
104 std::vector<Tensor> unfrozen_outputs;
105 TF_ASSERT_OK(unfrozen_session->Run(/* inputs */ {}, {tensor_name},
106 /* targets */ {}, &unfrozen_outputs));
107
108 SessionOptions session_options;
109 std::unique_ptr<Session> frozen_session(NewSession(session_options));
110 TF_ASSERT_OK(frozen_session->Create(frozen_graph_def));
111 std::vector<Tensor> frozen_outputs;
112 TF_ASSERT_OK(frozen_session->Run(/* inputs */ {}, {tensor_name},
113 /* targets */ {}, &frozen_outputs));
114
115 test::ExpectTensorEqual<float>(unfrozen_outputs[0], frozen_outputs[0]);
116 }
117
TestFreezeGraphWithoutDependentVariables(bool use_resource)118 void TestFreezeGraphWithoutDependentVariables(bool use_resource) {
119 // Test freezing a graph with variables that are not needed by the outputs
120 // in the SignatureDef. The resulting graph shouldn't be frozen, but
121 // non-dependent nodes should be pruned.
122 SavedModelBundle saved_model_bundle;
123 GraphDef graph_def;
124 Scope scope = Scope::NewRootScope();
125 Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
126 Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
127 Output c = ops::Mul(scope.WithOpName("c"), a, b);
128 if (use_resource) {
129 Output var =
130 ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {});
131 Output read_var = ops::ReadVariableOp(
132 scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT);
133 auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a);
134 } else {
135 Output var =
136 ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
137 Output assign = ops::Assign(scope.WithOpName("assign"), var, a);
138 }
139
140 TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
141 // "c" isn't dependent on the variable, so nothing should be frozen.
142 TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
143 graph_def, {"c:0"}, "assign", &saved_model_bundle));
144
145 GraphDef frozen_graph_def;
146 std::unordered_set<string> inputs;
147 std::unordered_set<string> outputs;
148 TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def,
149 &inputs, &outputs));
150
151 GraphDef expected_graph_def;
152 Scope expected_scope = Scope::NewRootScope();
153 Output expected_a = ops::Const(expected_scope.WithOpName("a"), 10.0f, {});
154 Output expected_b = ops::Const(expected_scope.WithOpName("b"), 10.0f, {});
155 Output expected_c =
156 ops::Mul(expected_scope.WithOpName("c"), expected_a, expected_b);
157 TF_ASSERT_OK(expected_scope.ToGraphDef(&expected_graph_def));
158
159 GraphDefEqual(frozen_graph_def, expected_graph_def);
160
161 RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
162 frozen_graph_def, "c:0");
163 }
164
TestFreezeGraphWithDependentVariables(bool use_resource)165 void TestFreezeGraphWithDependentVariables(bool use_resource) {
166 // Test freezing a graph with variables that are needed by outputs in the
167 // SignatureDef. The variables should be frozen.
168 SavedModelBundle saved_model_bundle;
169 GraphDef graph_def;
170 Scope scope = Scope::NewRootScope();
171 Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
172 Output read_var;
173 if (use_resource) {
174 Output var =
175 ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {});
176 read_var = ops::ReadVariableOp(
177 scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT);
178 auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a);
179 } else {
180 Output read_var =
181 ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
182 Output assign = ops::Assign(scope.WithOpName("assign"), read_var, a);
183 }
184 Output c = ops::Mul(scope.WithOpName("c"), a, read_var);
185 TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
186 // "c" isn't dependent on the variable, so nothing should be frozen.
187 TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
188 graph_def, {"c:0"}, "assign", &saved_model_bundle));
189
190 GraphDef frozen_graph_def;
191 std::unordered_set<string> inputs;
192 std::unordered_set<string> outputs;
193 TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def,
194 &inputs, &outputs));
195
196 // If using normal variables there should be 3 nodes in the resulting
197 // graph_def. If using resource variables there should be 4 nodes in the
198 // resulting graph_def.
199 // In both cases, none should be variables.
200 size_t expected_nodes = use_resource ? 4 : 3;
201 EXPECT_EQ(frozen_graph_def.node_size(), expected_nodes);
202 for (const NodeDef& node : frozen_graph_def.node()) {
203 EXPECT_NE(node.op(), "Variable") << node.name();
204 EXPECT_NE(node.op(), "VariableV2") << node.name();
205 EXPECT_NE(node.op(), "VarHandleOp") << node.name();
206 EXPECT_NE(node.op(), "ReadVariableOp") << node.name();
207 }
208
209 RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
210 frozen_graph_def, "c:0");
211 }
212
TestFreezeGraphWithAndWithoutDependentVariables(bool use_resource)213 void TestFreezeGraphWithAndWithoutDependentVariables(bool use_resource) {
214 // Test freezing a graph with some variables that are needed and not needed
215 // by
216 // the outputs in the SignatureDef. The resulting graph should only freeze
217 // dependent variables.
218 SavedModelBundle saved_model_bundle;
219 GraphDef graph_def;
220 Scope scope = Scope::NewRootScope();
221 Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
222 Output read_var;
223
224 if (use_resource) {
225 Output var =
226 ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {});
227 read_var = ops::ReadVariableOp(
228 scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT);
229 auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a);
230 Output var_1 =
231 ops::VarHandleOp(scope.WithOpName("var_1"), DataType::DT_FLOAT, {});
232 Output read_var_1 =
233 ops::ReadVariableOp(scope.WithOpName("var_1/Read/ReadVariableOp"),
234 var, DataType::DT_FLOAT);
235 auto assign_1 =
236 ops::AssignVariableOp(scope.WithOpName("assign_1"), var_1, a);
237 } else {
238 read_var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
239 Output assign = ops::Assign(scope.WithOpName("assign"), read_var, a);
240 Output var_1 =
241 ops::Variable(scope.WithOpName("var_1"), {}, DataType::DT_FLOAT);
242 Output assign_1 = ops::Assign(scope.WithOpName("assign_1"), var_1, a);
243 }
244
245 Output c = ops::Mul(scope.WithOpName("c"), a, read_var);
246 TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
247 // "c" isn't dependent on the variable, so nothing should be frozen.
248 TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
249 graph_def, {"c:0"}, "assign", &saved_model_bundle));
250
251 GraphDef frozen_graph_def;
252 std::unordered_set<string> inputs;
253 std::unordered_set<string> outputs;
254 TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def,
255 &inputs, &outputs));
256
257 // There should be 3 nodes in the resulting graph_def, and none should be
258 // variables.
259 size_t expected_nodes = use_resource ? 4 : 3;
260 EXPECT_EQ(frozen_graph_def.node_size(), expected_nodes);
261 for (const NodeDef& node : frozen_graph_def.node()) {
262 EXPECT_NE(node.op(), "Variable") << node.name();
263 EXPECT_NE(node.op(), "VariableV2") << node.name();
264 EXPECT_NE(node.op(), "VarHandleOp") << node.name();
265 EXPECT_NE(node.op(), "ReadVariableOp") << node.name();
266 }
267
268 RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
269 frozen_graph_def, "c:0");
270 }
271 };
272
TEST_F(FreezeTest, InputsAndOutputsSingleSignatureDef) {
  // With a single SignatureDef, FreezeSavedModel should report exactly the
  // tensor names listed in that SignatureDef's inputs and outputs.
  const std::unordered_set<string> want_inputs = {"input0:0", "input1:0"};
  const std::unordered_set<string> want_outputs = {"output0:0", "output1:0"};

  SavedModelBundle bundle;
  AddSignatureDefToSavedModelBundle(
      BuildSignatureDef(want_inputs, want_outputs), "signature_def", &bundle);

  GraphDef frozen;
  std::unordered_set<string> got_inputs;
  std::unordered_set<string> got_outputs;
  TF_ASSERT_OK(FreezeSavedModel(bundle, &frozen, &got_inputs, &got_outputs));

  EXPECT_EQ(want_inputs, got_inputs);
  EXPECT_EQ(want_outputs, got_outputs);
}
291
TEST_F(FreezeTest, InputsAndOutputsMultipleSignatureDefs) {
  // When several SignatureDefs are present, their inputs and outputs should be
  // merged into one set each.
  SavedModelBundle bundle;
  AddSignatureDefToSavedModelBundle(
      BuildSignatureDef({"input0:0"}, {"output0:0"}), "signature_def_0",
      &bundle);
  AddSignatureDefToSavedModelBundle(
      BuildSignatureDef({"input1:0"}, {"output1:0"}), "signature_def_1",
      &bundle);

  GraphDef frozen;
  std::unordered_set<string> got_inputs;
  std::unordered_set<string> got_outputs;
  TF_ASSERT_OK(FreezeSavedModel(bundle, &frozen, &got_inputs, &got_outputs));

  const std::unordered_set<string> want_inputs = {"input0:0", "input1:0"};
  const std::unordered_set<string> want_outputs = {"output0:0", "output1:0"};
  EXPECT_EQ(want_inputs, got_inputs);
  EXPECT_EQ(want_outputs, got_outputs);
}
312
TEST_F(FreezeTest, GraphDefVersionsAndLibrary) {
  // The frozen graph should carry over the source GraphDef's versions and
  // function library unchanged.
  GraphDef source;
  auto* versions = source.mutable_versions();
  versions->set_producer(1234);
  versions->set_min_consumer(1234);
  *source.mutable_library()->add_function() = test::function::NonZero();

  SavedModelBundle bundle;
  TF_ASSERT_OK(AddGraphDefToSavedModelBundle(source, /*init_node=*/"", &bundle));

  GraphDef frozen;
  std::unordered_set<string> got_inputs;
  std::unordered_set<string> got_outputs;
  TF_ASSERT_OK(FreezeSavedModel(bundle, &frozen, &got_inputs, &got_outputs));

  GraphDefEqual(frozen, source);
}
332
TEST_F(FreezeTest, GraphDefWithNoVariables) {
  // A graph without variables has nothing to freeze, so it should come back
  // unchanged.
  Scope root = Scope::NewRootScope();
  Output lhs = ops::Const(root.WithOpName("a"), 10.0f, {});
  Output rhs = ops::Const(root.WithOpName("b"), 10.0f, {});
  Output product = ops::Mul(root.WithOpName("c"), lhs, rhs);

  GraphDef source;
  TF_ASSERT_OK(root.ToGraphDef(&source));

  SavedModelBundle bundle;
  TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
      source, {"c:0"}, /*init_node=*/"", &bundle));

  GraphDef frozen;
  std::unordered_set<string> got_inputs;
  std::unordered_set<string> got_outputs;
  TF_ASSERT_OK(FreezeSavedModel(bundle, &frozen, &got_inputs, &got_outputs));

  GraphDefEqual(frozen, source);
}
353
TEST_F(FreezeTest, GraphDefWithMultiOutputOperation) {
  // Consumers of a multi-output op reference its tensors with an index suffix
  // (e.g. "split:1"). The dependency traversal must resolve such references.
  Scope root = Scope::NewRootScope();
  Output pair = ops::Const(root.WithOpName("a"), {10.0f, 10.0f}, {2});
  Output split_axis = ops::Const(root.WithOpName("axis"), 0, {});
  ops::Split halves(root.WithOpName("split"), split_axis, pair, 2);
  Output scale = ops::Const(root.WithOpName("b"), 10.0f, {});
  Output product = ops::Mul(root.WithOpName("c"), halves.output[1], scale);

  GraphDef source;
  TF_ASSERT_OK(root.ToGraphDef(&source));

  SavedModelBundle bundle;
  TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
      source, {"c:0"}, /*init_node=*/"", &bundle));

  GraphDef frozen;
  std::unordered_set<string> got_inputs;
  std::unordered_set<string> got_outputs;
  TF_ASSERT_OK(FreezeSavedModel(bundle, &frozen, &got_inputs, &got_outputs));

  GraphDefEqual(frozen, source);
}
378
TEST_F(FreezeTest, GraphDefWithControlDependency) {
  // Control-dependency inputs are written with a caret prefix (e.g.
  // "^source"). The dependency traversal must follow them as well.
  Scope root = Scope::NewRootScope();
  Output trigger = ops::Const(root.WithOpName("source"), 10.0f, {});
  Output gated = ops::Const(root.WithOpName("a").WithControlDependencies(trigger),
                            {10.0f, 10.0f}, {2});
  Output scale = ops::Const(root.WithOpName("b"), 10.0f, {});
  Output product = ops::Mul(root.WithOpName("c"), gated, scale);

  GraphDef source;
  TF_ASSERT_OK(root.ToGraphDef(&source));

  SavedModelBundle bundle;
  TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
      source, {"c:0"}, /*init_node=*/"", &bundle));

  GraphDef frozen;
  std::unordered_set<string> got_inputs;
  std::unordered_set<string> got_outputs;
  TF_ASSERT_OK(FreezeSavedModel(bundle, &frozen, &got_inputs, &got_outputs));

  GraphDefEqual(frozen, source);
}
403
TEST_F(FreezeTest, GraphDefWithoutDependentVariables) {
  // Reference (non-resource) variables that the output does not depend on.
  TestFreezeGraphWithoutDependentVariables(/*use_resource=*/false);
}
407
TEST_F(FreezeTest, GraphDefWithoutDependentResourceVariables) {
  // Resource variables that the output does not depend on.
  TestFreezeGraphWithoutDependentVariables(/*use_resource=*/true);
}
411
TEST_F(FreezeTest, GraphDefWithDependentVariables) {
  // Reference (non-resource) variable that the output depends on.
  TestFreezeGraphWithDependentVariables(/*use_resource=*/false);
}
415
TEST_F(FreezeTest, GraphDefWithDependentResourceVariables) {
  // Resource variable that the output depends on.
  TestFreezeGraphWithDependentVariables(/*use_resource=*/true);
}
419
TEST_F(FreezeTest, GraphDefWithAndWithoutDependentVariables) {
  // Mix of dependent and non-dependent reference (non-resource) variables.
  TestFreezeGraphWithAndWithoutDependentVariables(/*use_resource=*/false);
}
423
TEST_F(FreezeTest, GraphDefWithAndWithoutDependentResourceVariables) {
  // Mix of dependent and non-dependent resource variables.
  TestFreezeGraphWithAndWithoutDependentVariables(/*use_resource=*/true);
}
427
TEST_F(FreezeTest, InputsAndOutputsCompositeTensorSignatureDef) {
  // Every component of a composite-tensor input or output should be reported
  // as its own tensor name. The output components are deliberately added out
  // of name order.
  SignatureDef signature_def;

  auto* input_composite =
      (*signature_def.mutable_inputs())["input_arg"].mutable_composite_tensor();
  input_composite->add_components()->set_name("input1:0");
  input_composite->add_components()->set_name("input2:0");

  auto* output_composite =
      (*signature_def.mutable_outputs())["output_arg"]
          .mutable_composite_tensor();
  output_composite->add_components()->set_name("output2:0");
  output_composite->add_components()->set_name("output1:0");

  SavedModelBundle bundle;
  AddSignatureDefToSavedModelBundle(signature_def, "signature_def", &bundle);

  GraphDef frozen;
  std::unordered_set<string> got_inputs;
  std::unordered_set<string> got_outputs;
  TF_ASSERT_OK(FreezeSavedModel(bundle, &frozen, &got_inputs, &got_outputs));

  const std::unordered_set<string> want_inputs = {"input1:0", "input2:0"};
  const std::unordered_set<string> want_outputs = {"output1:0", "output2:0"};
  EXPECT_EQ(want_inputs, got_inputs);
  EXPECT_EQ(want_outputs, got_outputs);
}
454
TEST_F(FreezeTest, InputsAndOutputsSparseCooSignatureDef) {
  // Test that inputs and outputs get correctly populated for a SignatureDef
  // whose input and output are COO sparse tensors. The values, indices and
  // dense-shape tensor names of each should all be reported.
  SavedModelBundle saved_model_bundle;
  SignatureDef signature_def;

  TensorInfo& in = (*signature_def.mutable_inputs())["input_arg"];
  in.mutable_coo_sparse()->set_values_tensor_name("input1:0");
  in.mutable_coo_sparse()->set_indices_tensor_name("input2:0");
  in.mutable_coo_sparse()->set_dense_shape_tensor_name("input3:0");

  TensorInfo& out = (*signature_def.mutable_outputs())["output_arg"];
  out.mutable_coo_sparse()->set_values_tensor_name("output1:0");
  out.mutable_coo_sparse()->set_indices_tensor_name("output2:0");
  out.mutable_coo_sparse()->set_dense_shape_tensor_name("output3:0");

  AddSignatureDefToSavedModelBundle(signature_def, "signature_def",
                                    &saved_model_bundle);
  GraphDef frozen_graph_def;
  std::unordered_set<string> inputs;
  std::unordered_set<string> outputs;
  TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
                                &outputs));
  std::unordered_set<string> expected_inputs = {"input1:0", "input2:0",
                                                "input3:0"};
  std::unordered_set<string> expected_outputs = {"output1:0", "output2:0",
                                                 "output3:0"};
  EXPECT_EQ(expected_inputs, inputs);
  EXPECT_EQ(expected_outputs, outputs);
}
485
486 } // namespace
487 } // namespace tensorflow
488