• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h"
16 
17 #include <algorithm>
18 #include <string>
19 
20 namespace tensorflow {
21 namespace calibrator {
22 
23 ABSL_CONST_INIT absl::Mutex CalibratorSingleton::lock_(absl::kConstInit);
24 
GetInstance()25 CalibratorSingleton& CalibratorSingleton::GetInstance() {
26   static CalibratorSingleton* calibrator = new CalibratorSingleton();
27   return *calibrator;
28 }
29 
ClearCollectedInformation()30 void CalibratorSingleton::ClearCollectedInformation() {
31   absl::MutexLock lock(&lock_);
32 
33   CalibratorSingleton& instance = GetInstance();
34   instance.id_to_min_.clear();
35   instance.id_to_max_.clear();
36 }
37 
ClearData(absl::string_view id)38 void CalibratorSingleton::ClearData(absl::string_view id) {
39   absl::MutexLock lock(&lock_);
40 
41   CalibratorSingleton& instance = GetInstance();
42 
43   const std::string id_str{id};
44   instance.id_to_min_.erase(id_str);
45   instance.id_to_max_.erase(id_str);
46 }
47 
ReportMinMax(absl::string_view id,const float min_val,const float max_val)48 void CalibratorSingleton::ReportMinMax(absl::string_view id,
49                                        const float min_val,
50                                        const float max_val) {
51   absl::MutexLock lock(&lock_);
52 
53   CalibratorSingleton& instance = GetInstance();
54 
55   const std::string id_str{id};
56 
57   // Update the min value.
58   if (auto min_itr = instance.id_to_min_.find(id_str);
59       min_itr != instance.id_to_min_.end()) {
60     min_itr->second = std::min(min_val, min_itr->second);
61   } else {
62     instance.id_to_min_[id_str] = min_val;
63   }
64 
65   // Update the max values.
66   if (auto max_itr = instance.id_to_max_.find(id_str);
67       max_itr != instance.id_to_max_.end()) {
68     max_itr->second = std::max(max_val, max_itr->second);
69   } else {
70     instance.id_to_max_[id_str] = max_val;
71   }
72 }
73 
GetMinMax(absl::string_view id)74 std::optional<std::pair<float, float>> CalibratorSingleton::GetMinMax(
75     absl::string_view id) {
76   absl::MutexLock lock(&lock_);
77 
78   CalibratorSingleton& instance = GetInstance();
79 
80   const std::string id_str{id};
81   const auto min_itr = instance.id_to_min_.find(id_str);
82   const auto max_itr = instance.id_to_max_.find(id_str);
83   if (min_itr == instance.id_to_min_.end() ||
84       max_itr == instance.id_to_max_.end()) {
85     return std::nullopt;
86   }
87 
88   return std::pair<float, float>{min_itr->second, max_itr->second};
89 }
90 
91 }  // namespace calibrator
92 }  // namespace tensorflow
93