• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"
17 
#include <atomic>
#include <unordered_map>
#include <utility>
20 
21 #include "tensorflow/core/platform/logging.h"
22 
23 #if GOOGLE_CUDA
24 #if GOOGLE_TENSORRT
25 #include "cuda/include/cuda_runtime_api.h"
26 
27 namespace tensorflow {
28 namespace tensorrt {
29 
30 // set the batch size before constructing the thread to execute engine
getBatchSize() const31 int TRTInt8Calibrator::getBatchSize() const { return batch_size_; }
32 
TRTInt8Calibrator(const std::unordered_map<string,std::pair<void *,size_t>> & dev_buffers,int batch_size,string engine_name)33 TRTInt8Calibrator::TRTInt8Calibrator(
34     const std::unordered_map<string, std::pair<void*, size_t>>& dev_buffers,
35     int batch_size, string engine_name)
36     : batch_size_(batch_size),
37       done_(false),
38       dev_buffers_(dev_buffers),
39       // Make sure setBatch() waits until getBatch() is called (the first time).
40       calib_running_(true),
41       batch_is_set_(false),
42       engine_name_(engine_name) {}
43 
TRTInt8Calibrator(const string & calib_data)44 TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data)
45     : batch_size_(0),
46       done_(true),
47       calib_running_(false),
48       batch_is_set_(false),
49       calibration_table_(calib_data) {}
50 
setBatch(const std::unordered_map<string,void * > & data,const cudaStream_t stream)51 bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
52                                  const cudaStream_t stream) {
53   mutex_lock lock(cond_mtx_);
54 
55   // Wait while the queue is full or calibration is running.
56   while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock);
57   if (done_) return false;
58   CHECK(!calib_running_ && !batch_is_set_);
59   VLOG(1) << "Set Batch Waiting finished";
60 
61   // Sets the batch.
62   for (const auto it : data) {
63     auto devptr = dev_buffers_.find(it.first);
64     if (devptr == dev_buffers_.end()) {
65       LOG(FATAL) << "FATAL " << engine_name_ << " input name '" << it.first
66                  << "' does not match with the buffer names";
67     }
68     const auto& d = devptr->second;
69 
70     // TODO(sami,aaroey): Need to figure out a way to ensure synchronization
71     // between stream, perhaps using a tensor?
72     auto status = cudaMemcpyAsync(d.first, it.second, d.second,
73                                   cudaMemcpyDeviceToDevice, stream);
74     if (status != cudaSuccess) {
75       LOG(FATAL) << "cudaMemcpy " << engine_name_ << " for '" << it.first
76                  << "' failed with " << status;
77     }
78   }
79 
80   // TODO(Sami, aaorey): Find an alternative way!
81   // we have to wait for the stream before returning!
82   cudaStreamSynchronize(stream);
83   batch_is_set_ = true;
84   cond_.notify_all();
85   return true;
86 }
87 
getBatch(void ** bindings,const char ** names,int num_bindings)88 bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
89                                  int num_bindings) {
90   mutex_lock lock(cond_mtx_);
91   // Notify finish of last round of calibration.
92   calib_running_ = false;
93   cond_.notify_all();
94 
95   // Wait until new batch arrives
96   while ((!batch_is_set_ && !done_)) cond_.wait(lock);
97   if (done_) return false;
98 
99   // Gets the batch
100   for (int i = 0; i < num_bindings; i++) {
101     auto it = dev_buffers_.find(names[i]);
102     if (it == dev_buffers_.end()) {
103       LOG(FATAL) << "Calibration engine asked for unknown tensor name '"
104                  << names[i] << "' at position " << i;
105     }
106     bindings[i] = it->second.first;
107   }
108   batch_is_set_ = false;
109   calib_running_ = true;
110   return true;
111 }
112 
waitAndSetDone()113 void TRTInt8Calibrator::waitAndSetDone() {
114   mutex_lock lock(cond_mtx_);
115   // Wait while the queue is full or calibration is running, so we don't miss
116   // the last batch.
117   while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock);
118   if (!done_) {
119     done_ = true;
120     cond_.notify_all();
121   }
122 }
123 
readCalibrationCache(std::size_t & length)124 const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) {
125   if (calibration_table_.empty()) return nullptr;
126   length = calibration_table_.size();
127   return calibration_table_.data();
128 }
129 
setDone()130 void TRTInt8Calibrator::setDone() {
131   mutex_lock lock(cond_mtx_);
132   done_ = true;
133   cond_.notify_all();
134 }
135 
writeCalibrationCache(const void * ptr,std::size_t length)136 void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
137                                               std::size_t length) {
138   calibration_table_ = string(static_cast<const char*>(ptr), length);
139   VLOG(1) << "Got calibration data for " << engine_name_ << " @" << ptr
140           << " length=" << length;
141 }
~TRTInt8Calibrator()142 TRTInt8Calibrator::~TRTInt8Calibrator() {
143   VLOG(1) << "Destroying calibrator for " << engine_name_;
144 }
145 
146 }  // namespace tensorrt
147 }  // namespace tensorflow
148 
149 #endif
150 #endif
151