/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <array>
#include <memory>
#include <string>
#include <vector>

#include <gtest/gtest.h>
#include "tensorflow/lite/c/c_api.h"
#include "tensorflow/lite/c/c_api_experimental.h"

24 namespace tflite {
25 namespace {
26 
TEST(SignatureRunnerTest,TestMultiSignatures)27 TEST(SignatureRunnerTest, TestMultiSignatures) {
28   TfLiteModel* model = TfLiteModelCreateFromFile(
29       "tensorflow/lite/testdata/multi_signatures.bin");
30   ASSERT_NE(model, nullptr);
31 
32   TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
33   ASSERT_NE(options, nullptr);
34   TfLiteInterpreterOptionsSetNumThreads(options, 2);
35 
36   TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
37   ASSERT_NE(interpreter, nullptr);
38 
39   // The options can be deleted immediately after interpreter creation.
40   TfLiteInterpreterOptionsDelete(options);
41 
42   std::vector<std::string> signature_defs;
43   for (int i = 0; i < TfLiteInterpreterGetSignatureCount(interpreter); i++) {
44     signature_defs.push_back(TfLiteInterpreterGetSignatureKey(interpreter, i));
45   }
46   ASSERT_EQ(signature_defs.size(), 2);
47   ASSERT_EQ(signature_defs[0], "add");
48   ASSERT_EQ(signature_defs[1], "sub");
49   ASSERT_EQ(TfLiteInterpreterGetSignatureRunner(interpreter, "foo"), nullptr);
50 
51   TfLiteSignatureRunner* add_runner = TfLiteInterpreterGetSignatureRunner(
52       interpreter, signature_defs[0].c_str());
53   ASSERT_NE(add_runner, nullptr);
54   std::vector<const char*> input_names;
55   for (int i = 0; i < TfLiteSignatureRunnerGetInputCount(add_runner); i++) {
56     input_names.push_back(TfLiteSignatureRunnerGetInputName(add_runner, i));
57   }
58   std::vector<const char*> output_names;
59   for (int i = 0; i < TfLiteSignatureRunnerGetOutputCount(add_runner); i++) {
60     output_names.push_back(TfLiteSignatureRunnerGetOutputName(add_runner, i));
61   }
62   ASSERT_EQ(input_names.size(), 1);
63   ASSERT_EQ(std::string(input_names[0]), "x");
64   ASSERT_EQ(output_names.size(), 1);
65   ASSERT_EQ(std::string(output_names[0]), "output_0");
66   std::array<int, 1> add_runner_input_dims{2};
67   ASSERT_EQ(TfLiteSignatureRunnerResizeInputTensor(
68                 add_runner, "x", add_runner_input_dims.data(),
69                 add_runner_input_dims.size()),
70             kTfLiteOk);
71   ASSERT_EQ(TfLiteSignatureRunnerAllocateTensors(add_runner), kTfLiteOk);
72   TfLiteTensor* add_input =
73       TfLiteSignatureRunnerGetInputTensor(add_runner, "x");
74   ASSERT_EQ(TfLiteSignatureRunnerGetInputTensor(add_runner, "foo"), nullptr);
75   const TfLiteTensor* add_output =
76       TfLiteSignatureRunnerGetOutputTensor(add_runner, "output_0");
77   ASSERT_EQ(TfLiteSignatureRunnerGetOutputTensor(add_runner, "foo"), nullptr);
78   ASSERT_NE(add_input, nullptr);
79   ASSERT_NE(add_output, nullptr);
80   add_input->data.f[0] = 2;
81   add_input->data.f[1] = 4;
82   ASSERT_EQ(TfLiteSignatureRunnerInvoke(add_runner), kTfLiteOk);
83   ASSERT_EQ(add_output->data.f[0], 4);
84   ASSERT_EQ(add_output->data.f[1], 6);
85   TfLiteSignatureRunnerDelete(add_runner);
86 
87   TfLiteSignatureRunner* sub_runner =
88       TfLiteInterpreterGetSignatureRunner(interpreter, "sub");
89   ASSERT_NE(sub_runner, nullptr);
90   std::vector<const char*> input_names2;
91   for (int i = 0; i < TfLiteSignatureRunnerGetInputCount(sub_runner); i++) {
92     input_names2.push_back(TfLiteSignatureRunnerGetInputName(sub_runner, i));
93   }
94   std::vector<const char*> output_names2;
95   for (int i = 0; i < TfLiteSignatureRunnerGetOutputCount(sub_runner); i++) {
96     output_names2.push_back(TfLiteSignatureRunnerGetOutputName(sub_runner, i));
97   }
98   ASSERT_EQ(input_names2.size(), 1);
99   ASSERT_EQ(std::string(input_names2[0]), "x");
100   ASSERT_EQ(output_names2.size(), 1);
101   ASSERT_EQ(std::string(output_names2[0]), "output_0");
102   std::array<int, 1> sub_runner_input_dims{3};
103   ASSERT_EQ(TfLiteSignatureRunnerResizeInputTensor(
104                 sub_runner, "x", sub_runner_input_dims.data(),
105                 sub_runner_input_dims.size()),
106             kTfLiteOk);
107   ASSERT_EQ(TfLiteSignatureRunnerAllocateTensors(sub_runner), kTfLiteOk);
108   TfLiteTensor* sub_input =
109       TfLiteSignatureRunnerGetInputTensor(sub_runner, "x");
110   ASSERT_EQ(TfLiteSignatureRunnerGetInputTensor(sub_runner, "foo"), nullptr);
111   const TfLiteTensor* sub_output =
112       TfLiteSignatureRunnerGetOutputTensor(sub_runner, "output_0");
113   ASSERT_EQ(TfLiteSignatureRunnerGetOutputTensor(sub_runner, "foo"), nullptr);
114   ASSERT_NE(sub_input, nullptr);
115   ASSERT_NE(sub_output, nullptr);
116   sub_input->data.f[0] = 2;
117   sub_input->data.f[1] = 4;
118   sub_input->data.f[2] = 6;
119   ASSERT_EQ(TfLiteSignatureRunnerInvoke(sub_runner), kTfLiteOk);
120   ASSERT_EQ(sub_output->data.f[0], -1);
121   ASSERT_EQ(sub_output->data.f[1], 1);
122   ASSERT_EQ(sub_output->data.f[2], 3);
123   TfLiteSignatureRunnerDelete(sub_runner);
124 
125   TfLiteInterpreterDelete(interpreter);
126   TfLiteModelDelete(model);
127 }
128 
129 }  // namespace
130 }  // namespace tflite
131