/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/examples/speech_commands/recognize_commands.h"

#include <algorithm>  // std::sort
#include <limits>     // std::numeric_limits

namespace tensorflow {

RecognizeCommands::RecognizeCommands(const std::vector<string>& labels,
                                     int32_t average_window_duration_ms,
                                     float detection_threshold,
                                     int32_t suppression_ms,
                                     int32_t minimum_count)
    : labels_(labels),
      average_window_duration_ms_(average_window_duration_ms),
      detection_threshold_(detection_threshold),
      suppression_ms_(suppression_ms),
      minimum_count_(minimum_count) {
  labels_count_ = labels.size();
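  // Start from sentinel values so the very first detection is never
  // suppressed by a non-existent previous command.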
  previous_top_label_ = "_silence_";
  previous_top_label_time_ = std::numeric_limits<int64_t>::min();
}

Status RecognizeCommands::ProcessLatestResults(const Tensor& latest_results,
                                               const int64_t current_time_ms,
                                               string* found_command,
                                               float* score,
                                               bool* is_new_command) {
  if (latest_results.NumElements() != labels_count_) {
    return errors::InvalidArgument(
        "The results for recognition should contain ", labels_count_,
        " elements, but there are ", latest_results.NumElements());
  }

  if ((!previous_results_.empty()) &&
      (current_time_ms < previous_results_.front().first)) {
    return errors::InvalidArgument(
        "Results must be fed in increasing time order, but received a "
        "timestamp of ",
        current_time_ms, " that was earlier than the previous one of ",
        previous_results_.front().first);
  }

  // Add the latest results to the head of the queue.
  previous_results_.push_back({current_time_ms, latest_results});

  // Prune any earlier results that are too old for the averaging window.
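  // The deque cannot be emptied here: the entry pushed above carries a
  // timestamp of current_time_ms, which always lies inside the window.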
  const int64_t time_limit = current_time_ms - average_window_duration_ms_;
  while (previous_results_.front().first < time_limit) {
    previous_results_.pop_front();
  }

  // If there are too few results, assume the result will be unreliable and
  // bail.
  const int64_t how_many_results = previous_results_.size();
  const int64_t earliest_time = previous_results_.front().first;
  const int64_t samples_duration = current_time_ms - earliest_time;
  if ((how_many_results < minimum_count_) ||
      (samples_duration < (average_window_duration_ms_ / 4))) {
    *found_command = previous_top_label_;
    *score = 0.0f;
    *is_new_command = false;
    return OkStatus();
  }

  // Calculate the average score across all the results in the window.
  std::vector<float> average_scores(labels_count_);
  for (const auto& previous_result : previous_results_) {
    const Tensor& scores_tensor = previous_result.second;
    auto scores_flat = scores_tensor.flat<float>();
    for (int i = 0; i < scores_flat.size(); ++i) {
      average_scores[i] += scores_flat(i) / how_many_results;
    }
  }

  // Sort the averaged results in descending score order.
  std::vector<std::pair<int, float>> sorted_average_scores;
  sorted_average_scores.reserve(labels_count_);
  for (int i = 0; i < labels_count_; ++i) {
    sorted_average_scores.push_back(
        std::pair<int, float>({i, average_scores[i]}));
  }
  std::sort(sorted_average_scores.begin(), sorted_average_scores.end(),
            [](const std::pair<int, float>& left,
               const std::pair<int, float>& right) {
              return left.second > right.second;
            });

  // See if the latest top score is enough to trigger a detection.
  const int current_top_index = sorted_average_scores[0].first;
  const string current_top_label = labels_[current_top_index];
  const float current_top_score = sorted_average_scores[0].second;
  // If we've recently had another label trigger, assume one that occurs too
  // soon afterwards is a bad result.
  int64_t time_since_last_top;
  if ((previous_top_label_ == "_silence_") ||
      (previous_top_label_time_ == std::numeric_limits<int64_t>::min())) {
    time_since_last_top = std::numeric_limits<int64_t>::max();
  } else {
    time_since_last_top = current_time_ms - previous_top_label_time_;
  }
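  // Report a new command only when the averaged top score clears the
  // detection threshold, the label differs from the last one reported, and
  // we are past the suppression window.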
  if ((current_top_score > detection_threshold_) &&
      (current_top_label != previous_top_label_) &&
      (time_since_last_top > suppression_ms_)) {
    previous_top_label_ = current_top_label;
    previous_top_label_time_ = current_time_ms;
    *is_new_command = true;
  } else {
    *is_new_command = false;
  }
  *found_command = current_top_label;
  *score = current_top_score;

  return OkStatus();
}

}  // namespace tensorflow
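
// Illustrative usage sketch (not part of the original file): shows how a
// streaming caller might drive this class. The label list, window, threshold,
// suppression, and count values below are placeholder assumptions, as is the
// latest_scores_tensor variable.
//
//   std::vector<string> labels = {"_silence_", "_unknown_", "yes", "no"};
//   RecognizeCommands recognizer(labels,
//                                /*average_window_duration_ms=*/1000,
//                                /*detection_threshold=*/0.5f,
//                                /*suppression_ms=*/1500,
//                                /*minimum_count=*/3);
//   string command;
//   float score;
//   bool is_new_command;
//   Status status = recognizer.ProcessLatestResults(
//       latest_scores_tensor, current_time_ms, &command, &score,
//       &is_new_command);
//   if (status.ok() && is_new_command) {
//     // React to the freshly detected command here.
//   }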