1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <array>
#include <memory>
#include <string>
#include <vector>

#include <gtest/gtest.h>
#include "tensorflow/lite/c/c_api.h"
#include "tensorflow/lite/c/c_api_experimental.h"
23
24 namespace tflite {
25 namespace {
26
TEST(SignatureRunnerTest,TestMultiSignatures)27 TEST(SignatureRunnerTest, TestMultiSignatures) {
28 TfLiteModel* model = TfLiteModelCreateFromFile(
29 "tensorflow/lite/testdata/multi_signatures.bin");
30 ASSERT_NE(model, nullptr);
31
32 TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
33 ASSERT_NE(options, nullptr);
34 TfLiteInterpreterOptionsSetNumThreads(options, 2);
35
36 TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
37 ASSERT_NE(interpreter, nullptr);
38
39 // The options can be deleted immediately after interpreter creation.
40 TfLiteInterpreterOptionsDelete(options);
41
42 std::vector<std::string> signature_defs;
43 for (int i = 0; i < TfLiteInterpreterGetSignatureCount(interpreter); i++) {
44 signature_defs.push_back(TfLiteInterpreterGetSignatureKey(interpreter, i));
45 }
46 ASSERT_EQ(signature_defs.size(), 2);
47 ASSERT_EQ(signature_defs[0], "add");
48 ASSERT_EQ(signature_defs[1], "sub");
49 ASSERT_EQ(TfLiteInterpreterGetSignatureRunner(interpreter, "foo"), nullptr);
50
51 TfLiteSignatureRunner* add_runner = TfLiteInterpreterGetSignatureRunner(
52 interpreter, signature_defs[0].c_str());
53 ASSERT_NE(add_runner, nullptr);
54 std::vector<const char*> input_names;
55 for (int i = 0; i < TfLiteSignatureRunnerGetInputCount(add_runner); i++) {
56 input_names.push_back(TfLiteSignatureRunnerGetInputName(add_runner, i));
57 }
58 std::vector<const char*> output_names;
59 for (int i = 0; i < TfLiteSignatureRunnerGetOutputCount(add_runner); i++) {
60 output_names.push_back(TfLiteSignatureRunnerGetOutputName(add_runner, i));
61 }
62 ASSERT_EQ(input_names.size(), 1);
63 ASSERT_EQ(std::string(input_names[0]), "x");
64 ASSERT_EQ(output_names.size(), 1);
65 ASSERT_EQ(std::string(output_names[0]), "output_0");
66 std::array<int, 1> add_runner_input_dims{2};
67 ASSERT_EQ(TfLiteSignatureRunnerResizeInputTensor(
68 add_runner, "x", add_runner_input_dims.data(),
69 add_runner_input_dims.size()),
70 kTfLiteOk);
71 ASSERT_EQ(TfLiteSignatureRunnerAllocateTensors(add_runner), kTfLiteOk);
72 TfLiteTensor* add_input =
73 TfLiteSignatureRunnerGetInputTensor(add_runner, "x");
74 ASSERT_EQ(TfLiteSignatureRunnerGetInputTensor(add_runner, "foo"), nullptr);
75 const TfLiteTensor* add_output =
76 TfLiteSignatureRunnerGetOutputTensor(add_runner, "output_0");
77 ASSERT_EQ(TfLiteSignatureRunnerGetOutputTensor(add_runner, "foo"), nullptr);
78 ASSERT_NE(add_input, nullptr);
79 ASSERT_NE(add_output, nullptr);
80 add_input->data.f[0] = 2;
81 add_input->data.f[1] = 4;
82 ASSERT_EQ(TfLiteSignatureRunnerInvoke(add_runner), kTfLiteOk);
83 ASSERT_EQ(add_output->data.f[0], 4);
84 ASSERT_EQ(add_output->data.f[1], 6);
85 TfLiteSignatureRunnerDelete(add_runner);
86
87 TfLiteSignatureRunner* sub_runner =
88 TfLiteInterpreterGetSignatureRunner(interpreter, "sub");
89 ASSERT_NE(sub_runner, nullptr);
90 std::vector<const char*> input_names2;
91 for (int i = 0; i < TfLiteSignatureRunnerGetInputCount(sub_runner); i++) {
92 input_names2.push_back(TfLiteSignatureRunnerGetInputName(sub_runner, i));
93 }
94 std::vector<const char*> output_names2;
95 for (int i = 0; i < TfLiteSignatureRunnerGetOutputCount(sub_runner); i++) {
96 output_names2.push_back(TfLiteSignatureRunnerGetOutputName(sub_runner, i));
97 }
98 ASSERT_EQ(input_names2.size(), 1);
99 ASSERT_EQ(std::string(input_names2[0]), "x");
100 ASSERT_EQ(output_names2.size(), 1);
101 ASSERT_EQ(std::string(output_names2[0]), "output_0");
102 std::array<int, 1> sub_runner_input_dims{3};
103 ASSERT_EQ(TfLiteSignatureRunnerResizeInputTensor(
104 sub_runner, "x", sub_runner_input_dims.data(),
105 sub_runner_input_dims.size()),
106 kTfLiteOk);
107 ASSERT_EQ(TfLiteSignatureRunnerAllocateTensors(sub_runner), kTfLiteOk);
108 TfLiteTensor* sub_input =
109 TfLiteSignatureRunnerGetInputTensor(sub_runner, "x");
110 ASSERT_EQ(TfLiteSignatureRunnerGetInputTensor(sub_runner, "foo"), nullptr);
111 const TfLiteTensor* sub_output =
112 TfLiteSignatureRunnerGetOutputTensor(sub_runner, "output_0");
113 ASSERT_EQ(TfLiteSignatureRunnerGetOutputTensor(sub_runner, "foo"), nullptr);
114 ASSERT_NE(sub_input, nullptr);
115 ASSERT_NE(sub_output, nullptr);
116 sub_input->data.f[0] = 2;
117 sub_input->data.f[1] = 4;
118 sub_input->data.f[2] = 6;
119 ASSERT_EQ(TfLiteSignatureRunnerInvoke(sub_runner), kTfLiteOk);
120 ASSERT_EQ(sub_output->data.f[0], -1);
121 ASSERT_EQ(sub_output->data.f[1], 1);
122 ASSERT_EQ(sub_output->data.f[2], 3);
123 TfLiteSignatureRunnerDelete(sub_runner);
124
125 TfLiteInterpreterDelete(interpreter);
126 TfLiteModelDelete(model);
127 }
128
129 } // namespace
130 } // namespace tflite
131