1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/tools/accuracy/eval_pipeline.h"
17 #include <gtest/gtest.h>
18 #include "tensorflow/cc/ops/standard_ops.h"
19 #include "tensorflow/core/public/session.h"
20
21 namespace tensorflow {
22 namespace metrics {
23 namespace {
24
CreateFloatTensor(float value)25 Tensor CreateFloatTensor(float value) {
26 Tensor tensor(DT_FLOAT, TensorShape({}));
27 tensor.scalar<float>()() = value;
28 return tensor;
29 }
30
31 class NoOpAccuracyEval : public AccuracyEval {
32 public:
NoOpAccuracyEval(const Status & status_to_return)33 explicit NoOpAccuracyEval(const Status& status_to_return)
34 : status_to_return_(status_to_return) {}
35
ComputeEval(const std::vector<Tensor> & model_outputs,const Tensor & ground_truth)36 Status ComputeEval(const std::vector<Tensor>& model_outputs,
37 const Tensor& ground_truth) override {
38 model_outputs_ = model_outputs;
39 ground_truth_ = ground_truth;
40 was_called_ = true;
41 return status_to_return_;
42 }
43
WasCalled()44 bool WasCalled() { return was_called_; }
model_outputs()45 std::vector<Tensor> model_outputs() { return model_outputs_; }
ground_truth()46 Tensor ground_truth() { return ground_truth_; }
47
48 private:
49 std::vector<Tensor> model_outputs_;
50 Tensor ground_truth_;
51 Status status_to_return_;
52 bool was_called_ = false;
53 };
54
// Runs a graph computing `input + 1` through the pipeline and verifies that
// the accuracy evaluator is invoked with the graph's output (5 + 1 = 6) and
// with the ground truth tensor passed to Run().
TEST(EvalPipeline, AccuracyEvalIsCalled) {
  Scope scope = Scope::NewRootScope();
  // A graph that adds 1 to input.
  auto input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
  auto add_node = ops::Add(scope.WithOpName("output"), input, 1.0f);
  GraphDef graph_def;
  TF_CHECK_OK(scope.ToGraphDef(&graph_def));
  EvalPipeline::Params params;
  params.model_input_node_name = "input";
  params.model_output_node_name = "output";
  NoOpAccuracyEval accuracy_eval(Status::OK());

  EvalPipeline eval_pipeline(graph_def, params, &accuracy_eval);
  std::unique_ptr<Session> session(NewSession(SessionOptions()));
  TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session)));
  TF_CHECK_OK(eval_pipeline.Run(CreateFloatTensor(5), CreateFloatTensor(27)));

  EXPECT_TRUE(accuracy_eval.WasCalled());
  const auto outputs = accuracy_eval.model_outputs();
  // Unsigned literal: comparing a signed 1 against size() triggers
  // -Wsign-compare inside the gtest comparison helper.
  ASSERT_EQ(1u, outputs.size());
  // The output is the result of a float computation, so use the
  // floating-point comparison macro rather than exact equality.
  EXPECT_FLOAT_EQ(6.0f, outputs[0].scalar<float>()());
  // Ground truth is unchanged, so it must match exactly.
  EXPECT_EQ(27.0f, accuracy_eval.ground_truth().scalar<float>()());
}
79
// Feeds a DT_STRING tensor to a graph whose placeholder is DT_FLOAT and
// verifies that Run() reports a non-OK status without ever invoking the
// accuracy evaluator.
TEST(EvalPipeline, EvalIsNotCalledOnGraphRunFailure) {
  // Pipeline configuration: node names match the graph built below.
  EvalPipeline::Params pipeline_params;
  pipeline_params.model_input_node_name = "input";
  pipeline_params.model_output_node_name = "output";

  // A graph that adds 1 to input.
  Scope root = Scope::NewRootScope();
  auto placeholder = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  auto sum = ops::Add(root.WithOpName("output"), placeholder, 1.0f);
  GraphDef graph;
  TF_CHECK_OK(root.ToGraphDef(&graph));

  NoOpAccuracyEval evaluator(Status::OK());
  EvalPipeline pipeline(graph, pipeline_params, &evaluator);
  TF_CHECK_OK(pipeline.AttachSession(
      std::unique_ptr<Session>(NewSession(SessionOptions()))));

  // Pass a string tensor instead of a float tensor.
  Tensor wrong_type_input(DT_STRING, TensorShape{});
  const Status run_status =
      pipeline.Run(wrong_type_input, CreateFloatTensor(27.0f));

  EXPECT_FALSE(evaluator.WasCalled());
  EXPECT_FALSE(run_status.ok());
}
102
// Uses an evaluator configured to return an Internal error and verifies that
// the evaluator is invoked and that Run() reports a non-OK status.
TEST(EvalPipeline, AccuracyEvalFailureResultsInFailure) {
  // Pipeline configuration: node names match the graph built below.
  EvalPipeline::Params pipeline_params;
  pipeline_params.model_input_node_name = "input";
  pipeline_params.model_output_node_name = "output";

  // A graph that adds 1 to input.
  Scope root = Scope::NewRootScope();
  auto placeholder = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  auto sum = ops::Add(root.WithOpName("output"), placeholder, 1.0f);
  GraphDef graph;
  TF_CHECK_OK(root.ToGraphDef(&graph));

  // The evaluator always fails with this status.
  NoOpAccuracyEval failing_evaluator(errors::Internal("accuracy_fail"));
  EvalPipeline pipeline(graph, pipeline_params, &failing_evaluator);
  TF_CHECK_OK(pipeline.AttachSession(
      std::unique_ptr<Session>(NewSession(SessionOptions()))));

  const Status run_status =
      pipeline.Run(CreateFloatTensor(5.0f), CreateFloatTensor(27.0f));

  EXPECT_TRUE(failing_evaluator.WasCalled());
  EXPECT_FALSE(run_status.ok());
}
123
124 } // namespace
125
126 } // namespace metrics
127 } // namespace tensorflow
128
main(int argc,char ** argv)129 int main(int argc, char** argv) {
130 ::testing::InitGoogleTest(&argc, argv);
131
132 return RUN_ALL_TESTS();
133 }
134