• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/tools/accuracy/eval_pipeline.h"
17 #include <gtest/gtest.h>
18 #include "tensorflow/cc/ops/standard_ops.h"
19 #include "tensorflow/core/public/session.h"
20 
21 namespace tensorflow {
22 namespace metrics {
23 namespace {
24 
CreateFloatTensor(float value)25 Tensor CreateFloatTensor(float value) {
26   Tensor tensor(DT_FLOAT, TensorShape({}));
27   tensor.scalar<float>()() = value;
28   return tensor;
29 }
30 
31 class NoOpAccuracyEval : public AccuracyEval {
32  public:
NoOpAccuracyEval(const Status & status_to_return)33   explicit NoOpAccuracyEval(const Status& status_to_return)
34       : status_to_return_(status_to_return) {}
35 
ComputeEval(const std::vector<Tensor> & model_outputs,const Tensor & ground_truth)36   Status ComputeEval(const std::vector<Tensor>& model_outputs,
37                      const Tensor& ground_truth) override {
38     model_outputs_ = model_outputs;
39     ground_truth_ = ground_truth;
40     was_called_ = true;
41     return status_to_return_;
42   }
43 
WasCalled()44   bool WasCalled() { return was_called_; }
model_outputs()45   std::vector<Tensor> model_outputs() { return model_outputs_; }
ground_truth()46   Tensor ground_truth() { return ground_truth_; }
47 
48  private:
49   std::vector<Tensor> model_outputs_;
50   Tensor ground_truth_;
51   Status status_to_return_;
52   bool was_called_ = false;
53 };
54 
// Runs a graph computing `output = input + 1` through the pipeline. Checks
// that the evaluator receives the graph's single output and the unmodified
// ground truth.
TEST(EvalPipeline, AccuracyEvalIsCalled) {
  Scope scope = Scope::NewRootScope();
  // A graph that adds 1 to input.
  auto input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
  auto add_node = ops::Add(scope.WithOpName("output"), input, 1.0f);
  GraphDef graph_def;
  TF_CHECK_OK(scope.ToGraphDef(&graph_def));
  EvalPipeline::Params params;
  params.model_input_node_name = "input";
  params.model_output_node_name = "output";
  NoOpAccuracyEval accuracy_eval(Status::OK());

  EvalPipeline eval_pipeline(graph_def, params, &accuracy_eval);
  std::unique_ptr<Session> session(NewSession(SessionOptions()));
  TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session)));
  TF_CHECK_OK(eval_pipeline.Run(CreateFloatTensor(5), CreateFloatTensor(27)));

  // Everything below inspects state captured by the evaluator, so stop here
  // if it was never invoked.
  ASSERT_TRUE(accuracy_eval.WasCalled());
  const auto& outputs = accuracy_eval.model_outputs();
  // Unsigned literal avoids a signed/unsigned comparison against size().
  ASSERT_EQ(1u, outputs.size());
  EXPECT_FLOAT_EQ(6.0f, outputs[0].scalar<float>()());
  // Ground truth is unchanged.
  EXPECT_FLOAT_EQ(27.0f, accuracy_eval.ground_truth().scalar<float>()());
}
79 
// Feeds a DT_STRING tensor into the DT_FLOAT "input" placeholder. Checks that
// Run reports a non-OK status and that the evaluator is never invoked.
TEST(EvalPipeline, EvalIsNotCalledOnGraphRunFailure) {
  Scope root = Scope::NewRootScope();
  // Graph under test: output = input + 1.
  auto placeholder = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  auto plus_one = ops::Add(root.WithOpName("output"), placeholder, 1.0f);
  GraphDef graph;
  TF_CHECK_OK(root.ToGraphDef(&graph));

  EvalPipeline::Params pipeline_params;
  pipeline_params.model_input_node_name = "input";
  pipeline_params.model_output_node_name = "output";
  NoOpAccuracyEval recorder(Status::OK());

  EvalPipeline pipeline(graph, pipeline_params, &recorder);
  TF_CHECK_OK(pipeline.AttachSession(
      std::unique_ptr<Session>(NewSession(SessionOptions()))));

  // Intentionally mismatched dtype: a string tensor where a float is expected.
  const Tensor wrong_dtype_input(DT_STRING, TensorShape{});
  const auto run_status = pipeline.Run(wrong_dtype_input, CreateFloatTensor(27));

  EXPECT_FALSE(recorder.WasCalled());
  EXPECT_FALSE(run_status.ok());
}
102 
// The graph runs successfully but the evaluator returns an Internal error.
// Checks that the evaluator was called and that Run's status is not OK.
TEST(EvalPipeline, AccuracyEvalFailureResultsInFailure) {
  Scope root = Scope::NewRootScope();
  // Graph under test: output = input + 1.
  auto placeholder = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  auto plus_one = ops::Add(root.WithOpName("output"), placeholder, 1.0f);
  GraphDef graph;
  TF_CHECK_OK(root.ToGraphDef(&graph));

  EvalPipeline::Params pipeline_params;
  pipeline_params.model_input_node_name = "input";
  pipeline_params.model_output_node_name = "output";
  NoOpAccuracyEval failing_eval(errors::Internal("accuracy_fail"));

  EvalPipeline pipeline(graph, pipeline_params, &failing_eval);
  TF_CHECK_OK(pipeline.AttachSession(
      std::unique_ptr<Session>(NewSession(SessionOptions()))));
  const auto run_status =
      pipeline.Run(CreateFloatTensor(5), CreateFloatTensor(27));

  EXPECT_TRUE(failing_eval.WasCalled());
  EXPECT_FALSE(run_status.ok());
}
123 
124 }  // namespace
125 
126 }  // namespace metrics
127 }  // namespace tensorflow
128 
main(int argc,char ** argv)129 int main(int argc, char** argv) {
130   ::testing::InitGoogleTest(&argc, argv);
131 
132   return RUN_ALL_TESTS();
133 }
134