1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ 17 #define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ 18 19 #include <list> 20 #include <sstream> 21 #include <string> 22 #include <thread> 23 #include <unordered_map> 24 #include <vector> 25 26 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" 27 #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" 28 #include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" 29 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" 30 #include "tensorflow/core/framework/op_kernel.h" 31 #include "tensorflow/core/framework/resource_mgr.h" 32 33 #if GOOGLE_CUDA 34 #if GOOGLE_TENSORRT 35 #include "tensorrt/include/NvInfer.h" 36 37 namespace tensorflow { 38 namespace tensorrt { 39 40 class SerializableResourceBase : public ResourceBase { 41 public: 42 virtual Status SerializeToString(string* serialized) = 0; 43 }; 44 45 class TRTCalibrationResource : public SerializableResourceBase { 46 public: 47 ~TRTCalibrationResource() override; 48 49 string DebugString() const override; 50 51 Status SerializeToString(string* serialized) override; 52 53 // Lookup table for temporary staging areas of input tensors for calibration. 54 std::unordered_map<string, std::pair<void*, size_t>> device_buffers_; 55 56 // Temporary staging areas for calibration inputs. 57 std::vector<PersistentTensor> device_tensors_; 58 59 std::unique_ptr<TRTInt8Calibrator> calibrator_; 60 TrtUniquePtrType<nvinfer1::IBuilder> builder_; 61 TrtUniquePtrType<nvinfer1::ICudaEngine> engine_; 62 std::unique_ptr<TRTBaseAllocator> allocator_; 63 Logger logger_; 64 // TODO(sami): Use threadpool threads! 65 std::unique_ptr<std::thread> thr_; 66 }; 67 68 } // namespace tensorrt 69 } // namespace tensorflow 70 71 #endif // GOOGLE_TENSORRT 72 #endif // GOOGLE_CUDA 73 #endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ 74