1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_ 17 #define TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_ 18 19 #include <vector> 20 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/platform/types.h" 23 24 namespace tensorflow { 25 26 struct StreamingAccuracyStats { StreamingAccuracyStatsStreamingAccuracyStats27 StreamingAccuracyStats() 28 : how_many_ground_truth_words(0), 29 how_many_ground_truth_matched(0), 30 how_many_false_positives(0), 31 how_many_correct_words(0), 32 how_many_wrong_words(0) {} 33 int32 how_many_ground_truth_words; 34 int32 how_many_ground_truth_matched; 35 int32 how_many_false_positives; 36 int32 how_many_correct_words; 37 int32 how_many_wrong_words; 38 }; 39 40 // Takes a file name, and loads a list of expected word labels and times from 41 // it, as comma-separated variables. 42 Status ReadGroundTruthFile(const string& file_name, 43 std::vector<std::pair<string, int64>>* result); 44 45 // Given ground truth labels and corresponding predictions found by a model, 46 // figure out how many were correct. Takes a time limit, so that only 47 // predictions up to a point in time are considered, in case we're evaluating 48 // accuracy when the model has only been run on part of the stream. 49 void CalculateAccuracyStats( 50 const std::vector<std::pair<string, int64>>& ground_truth_list, 51 const std::vector<std::pair<string, int64>>& found_words, 52 int64 up_to_time_ms, int64 time_tolerance_ms, 53 StreamingAccuracyStats* stats); 54 55 // Writes a human-readable description of the statistics to stdout. 56 void PrintAccuracyStats(const StreamingAccuracyStats& stats); 57 58 } // namespace tensorflow 59 60 #endif // TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_ 61