• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/examples/speech_commands/recognize_commands.h"
17 
18 namespace tensorflow {
19 
RecognizeCommands(const std::vector<string> & labels,int32_t average_window_duration_ms,float detection_threshold,int32_t suppression_ms,int32_t minimum_count)20 RecognizeCommands::RecognizeCommands(const std::vector<string>& labels,
21                                      int32_t average_window_duration_ms,
22                                      float detection_threshold,
23                                      int32_t suppression_ms,
24                                      int32_t minimum_count)
25     : labels_(labels),
26       average_window_duration_ms_(average_window_duration_ms),
27       detection_threshold_(detection_threshold),
28       suppression_ms_(suppression_ms),
29       minimum_count_(minimum_count) {
30   labels_count_ = labels.size();
31   previous_top_label_ = "_silence_";
32   previous_top_label_time_ = std::numeric_limits<int64_t>::min();
33 }
34 
ProcessLatestResults(const Tensor & latest_results,const int64_t current_time_ms,string * found_command,float * score,bool * is_new_command)35 Status RecognizeCommands::ProcessLatestResults(const Tensor& latest_results,
36                                                const int64_t current_time_ms,
37                                                string* found_command,
38                                                float* score,
39                                                bool* is_new_command) {
40   if (latest_results.NumElements() != labels_count_) {
41     return errors::InvalidArgument(
42         "The results for recognition should contain ", labels_count_,
43         " elements, but there are ", latest_results.NumElements());
44   }
45 
46   if ((!previous_results_.empty()) &&
47       (current_time_ms < previous_results_.front().first)) {
48     return errors::InvalidArgument(
49         "Results must be fed in increasing time order, but received a "
50         "timestamp of ",
51         current_time_ms, " that was earlier than the previous one of ",
52         previous_results_.front().first);
53   }
54 
55   // Add the latest results to the head of the queue.
56   previous_results_.push_back({current_time_ms, latest_results});
57 
58   // Prune any earlier results that are too old for the averaging window.
59   const int64_t time_limit = current_time_ms - average_window_duration_ms_;
60   while (previous_results_.front().first < time_limit) {
61     previous_results_.pop_front();
62   }
63 
64   // If there are too few results, assume the result will be unreliable and
65   // bail.
66   const int64_t how_many_results = previous_results_.size();
67   const int64_t earliest_time = previous_results_.front().first;
68   const int64_t samples_duration = current_time_ms - earliest_time;
69   if ((how_many_results < minimum_count_) ||
70       (samples_duration < (average_window_duration_ms_ / 4))) {
71     *found_command = previous_top_label_;
72     *score = 0.0f;
73     *is_new_command = false;
74     return OkStatus();
75   }
76 
77   // Calculate the average score across all the results in the window.
78   std::vector<float> average_scores(labels_count_);
79   for (const auto& previous_result : previous_results_) {
80     const Tensor& scores_tensor = previous_result.second;
81     auto scores_flat = scores_tensor.flat<float>();
82     for (int i = 0; i < scores_flat.size(); ++i) {
83       average_scores[i] += scores_flat(i) / how_many_results;
84     }
85   }
86 
87   // Sort the averaged results in descending score order.
88   std::vector<std::pair<int, float>> sorted_average_scores;
89   sorted_average_scores.reserve(labels_count_);
90   for (int i = 0; i < labels_count_; ++i) {
91     sorted_average_scores.push_back(
92         std::pair<int, float>({i, average_scores[i]}));
93   }
94   std::sort(sorted_average_scores.begin(), sorted_average_scores.end(),
95             [](const std::pair<int, float>& left,
96                const std::pair<int, float>& right) {
97               return left.second > right.second;
98             });
99 
100   // See if the latest top score is enough to trigger a detection.
101   const int current_top_index = sorted_average_scores[0].first;
102   const string current_top_label = labels_[current_top_index];
103   const float current_top_score = sorted_average_scores[0].second;
104   // If we've recently had another label trigger, assume one that occurs too
105   // soon afterwards is a bad result.
106   int64_t time_since_last_top;
107   if ((previous_top_label_ == "_silence_") ||
108       (previous_top_label_time_ == std::numeric_limits<int64_t>::min())) {
109     time_since_last_top = std::numeric_limits<int64_t>::max();
110   } else {
111     time_since_last_top = current_time_ms - previous_top_label_time_;
112   }
113   if ((current_top_score > detection_threshold_) &&
114       (current_top_label != previous_top_label_) &&
115       (time_since_last_top > suppression_ms_)) {
116     previous_top_label_ = current_top_label;
117     previous_top_label_time_ = current_time_ms;
118     *is_new_command = true;
119   } else {
120     *is_new_command = false;
121   }
122   *found_command = current_top_label;
123   *score = current_top_score;
124 
125   return OkStatus();
126 }
127 
128 }  // namespace tensorflow
129