/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstdlib>
#include <memory>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h"  // NOLINT
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {
namespace tensorrt {
using ::nvinfer1::IRuntime;

class CreateTRTResourceHandle : public OpKernel {
 public:
  explicit CreateTRTResourceHandle(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("resource_name", &resource_name_));
  }

  void Compute(OpKernelContext* ctx) override {
    {
      mutex_lock l(mutex_);
      if (!initialized_) {
        AllocatorAttributes attr;
        attr.set_on_host(true);
        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
                                               &handle_, attr));

        VLOG(1) << "Creating TRT engine cache resource handle for op "
                << resource_name_ << " on device " << ctx->device()->name();
        handle_.scalar<ResourceHandle>()() =
            MakeResourceHandle<TRTEngineCacheResource>(
                ctx, std::string(kTfTrtContainerName), resource_name_);
        initialized_ = true;
      }
    }
    ctx->set_output(0, handle_);
  }

 private:
  string resource_name_;
  Tensor handle_;
  mutex mutex_;
  bool initialized_ TF_GUARDED_BY(mutex_) = false;

  TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTResourceHandle);
};

REGISTER_KERNEL_BUILDER(Name("CreateTRTResourceHandle")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource_handle"),
                        CreateTRTResourceHandle);

class InitializeTRTResource : public OpKernel {
 public:
  explicit InitializeTRTResource(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("max_cached_engines_count", &max_cached_engines_));
  }

  void Compute(OpKernelContext* ctx) override {
    ResourceHandle handle = HandleFromInput(ctx, 0);
    core::RefCountPtr<TRTEngineCacheResource> resource;
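    // Look up the cache resource for this handle; if no TRTEngineOp has
    // created it yet, the creation functor below installs an empty
    // TRTEngineCacheResource in the TF-TRT resource container.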
    OP_REQUIRES_OK(
        ctx, LookupOrCreateResource<TRTEngineCacheResource>(
                 ctx, handle, &resource,
                 [this, ctx](TRTEngineCacheResource** resource) -> Status {
                   *resource = new TRTEngineCacheResource(
                       ctx, this->max_cached_engines_);
                   return Status::OK();
                 }));

    auto allocator = resource->allocator_.get();
    OP_REQUIRES(ctx, allocator != nullptr,
                errors::Internal("Not able to initialize TRT engine cache when "
                                 "GPU allocator is empty."));
    OP_REQUIRES(ctx, resource->cache_.size() == 0,
                errors::Internal("Expect engine cache to be empty, but got ",
                                 resource->cache_.size(), " entries."));

    // Get the file name.
    const string& filename = ctx->input(1).scalar<tstring>()();
    OP_REQUIRES(ctx, !filename.empty(),
                errors::InvalidArgument("filename cannot be empty."));

    // Parse the serialized engines and add them to the cache.
    std::unique_ptr<RandomAccessFile> file;
    OP_REQUIRES_OK(ctx, ctx->env()->NewRandomAccessFile(filename, &file));
    auto reader = std::make_unique<io::RecordReader>(file.get());

    uint64 offset = 0;
    int num_loaded_engine = 0;
    do {
      tstring record;
      Status status = reader->ReadRecord(&offset, &record);
      if (errors::IsOutOfRange(status)) break;
      OP_REQUIRES_OK(ctx, status);

      TRTEngineInstance engine_instance;
      OP_REQUIRES(ctx, engine_instance.ParseFromString(record),
                  errors::DataLoss("Failed to parse a TRTEngineInstance "
                                   "record from file ",
                                   filename));
      std::vector<TensorShape> engine_input_shapes;
      const auto& input_shapes = engine_instance.input_shapes();
      engine_input_shapes.reserve(input_shapes.size());
      for (const TensorShapeProto& shape : input_shapes) {
        engine_input_shapes.emplace_back(shape);
      }

      TrtUniquePtrType<IRuntime> infer(
          nvinfer1::createInferRuntime(TRTEngineCacheResource::GetLogger()));
      infer->setGpuAllocator(allocator);
      TrtUniquePtrType<nvinfer1::ICudaEngine> engine(
          infer->deserializeCudaEngine(
              engine_instance.serialized_engine().c_str(),
              engine_instance.serialized_engine().size(), nullptr));
      auto raw_engine = engine.get();
      std::vector<ExecutionContext> ctx_vec;
      if (num_loaded_engine == 0) {
        // Restore profiles if there are any. Currently only one engine is
        // allowed in dynamic mode, so we call this only for the 0th engine;
        // it is a no-op in implicit batch mode.
        OP_REQUIRES_OK(ctx, resource->profiles_.RestoreProfiles(
                                raw_engine, engine_input_shapes.size()));
        OP_REQUIRES_OK(ctx, resource->profiles_.CreateExecutionContexts(
                                raw_engine, &ctx_vec));
      } else {
        // Multiple engines are only available in static mode. For each engine
        // we have only a single execution context.
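        // ExecutionContext::Create wraps a single IExecutionContext created
        // from the deserialized engine.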
        ctx_vec.push_back(ExecutionContext::Create(raw_engine));
      }
      resource->cache_.emplace(engine_input_shapes,
                               std::make_unique<EngineContext>(
                                   std::move(engine), std::move(ctx_vec)));
      ++num_loaded_engine;
    } while (true);
    VLOG(1) << "Loaded " << num_loaded_engine << " TRT engines for op "
            << handle.name() << " on device " << ctx->device()->name()
            << " from file " << filename;
  }

 private:
  // Maximum number of cached engines.
  int max_cached_engines_;

  TF_DISALLOW_COPY_AND_ASSIGN(InitializeTRTResource);
};

REGISTER_KERNEL_BUILDER(Name("InitializeTRTResource")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource_handle"),
                        InitializeTRTResource);

class SerializeTRTResource : public OpKernel {
 public:
  explicit SerializeTRTResource(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("delete_resource", &delete_resource_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("save_gpu_specific_engines",
                                     &save_gpu_specific_engines_));
  }

  void Compute(OpKernelContext* ctx) override {
    const string& resource_name = ctx->input(0).scalar<tstring>()();
    const string& filename = ctx->input(1).scalar<tstring>()();
    OP_REQUIRES(ctx, !filename.empty(),
                errors::InvalidArgument("filename cannot be empty."));

    // Look up the engine cache resource.
    TRTEngineCacheResource* resource = nullptr;
    OP_REQUIRES(
        ctx,
        ctx->resource_manager()
            ->Lookup(std::string(kTfTrtContainerName), resource_name,
                     &resource)
            .ok(),
        errors::NotFound("TRTEngineCacheResource not yet created"));
    core::ScopedUnref unref_me(resource);

    // Terminate the calibration, if any.
    if (resource->calib_ctx_) resource->calib_ctx_->TerminateCalibration();

    // Serialize the engines and write them to the file.
    std::unique_ptr<WritableFile> file;
    OP_REQUIRES_OK(ctx, ctx->env()->NewWritableFile(filename, &file));
    auto writer = std::make_unique<io::RecordWriter>(file.get());

    int num_serialized_engines = 0;
    if (save_gpu_specific_engines_) {
      // If the user requests TRT engine export, recursively create the
      // requisite directories.
      const char* export_trt_engines_env =
          getenv("TF_TRT_EXPORT_TRT_ENGINES_PATH");
      if (export_trt_engines_env) {
        VLOG(1) << "Exporting TRT engines to directory: "
                << export_trt_engines_env;
        OP_REQUIRES_OK(
            ctx, ctx->env()->RecursivelyCreateDir(export_trt_engines_env));
      }

      for (const auto& pair : resource->cache_) {
        // Ignore engines that failed to build.
        const std::unique_ptr<EngineContext>& engine = pair.second;
        if (!engine || !engine->GetCudaEngine()) continue;

        TRTEngineInstance engine_instance;
        // Add input shapes.
        const std::vector<TensorShape>& engine_input_shapes = pair.first;
        for (const TensorShape& shape : engine_input_shapes) {
          shape.AsProto(engine_instance.add_input_shapes());
        }
        // Add the serialized engine.
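        // ICudaEngine::serialize() returns an IHostMemory blob owned by the
        // caller; the TrtUniquePtrType wrapper destroys it once the record
        // (and any exported engine files) have been written.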
        TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(
            engine->GetCudaEngine()->serialize());
        engine_instance.set_serialized_engine(engine_data->data(),
                                              engine_data->size());

        if (export_trt_engines_env) {
          const std::string engine_filename =
              std::string(export_trt_engines_env) + "/" + resource_name;
          std::unique_ptr<WritableFile> engine_file;
          OP_REQUIRES_OK(
              ctx, ctx->env()->NewWritableFile(engine_filename, &engine_file));
          OP_REQUIRES_OK(ctx, engine_file->Append(StringPiece(
                                  static_cast<char*>(engine_data->data()),
                                  engine_data->size())));

          const std::string dims_filename =
              std::string(export_trt_engines_env) + "/dims-" + resource_name;
          std::unique_ptr<WritableFile> dims_file;
          OP_REQUIRES_OK(
              ctx, ctx->env()->NewWritableFile(dims_filename, &dims_file));

          for (const TensorShape& shape : engine_input_shapes) {
            OP_REQUIRES_OK(
                ctx, dims_file->Append(StringPiece(shape.DebugString())));
          }
        }

        OP_REQUIRES_OK(
            ctx, writer->WriteRecord(engine_instance.SerializeAsString()));
        ++num_serialized_engines;
      }
    } else {
      VLOG(1) << "TRT engines are not serialized for op: " << resource_name;
    }
    VLOG(1) << "Serialized " << num_serialized_engines
            << " TRT engines for op " << resource_name << " on device "
            << ctx->device()->name() << " to file " << filename;

    if (delete_resource_) {
      VLOG(1) << "Destroying TRT engine cache resource for op "
              << resource_name << " on device " << ctx->device()->name();
      OP_REQUIRES_OK(ctx,
                     ctx->resource_manager()->Delete<TRTEngineCacheResource>(
                         std::string(kTfTrtContainerName), resource_name));
    }
  }

 private:
  bool delete_resource_ = false;
  bool save_gpu_specific_engines_ = true;

  TF_DISALLOW_COPY_AND_ASSIGN(SerializeTRTResource);
};

REGISTER_KERNEL_BUILDER(Name("SerializeTRTResource").Device(DEVICE_GPU),
                        SerializeTRTResource);

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
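
// For illustration only: a minimal sketch (hypothetical snippet, abbreviated
// error handling, made-up path) of how the record file written by
// SerializeTRTResource could be read back offline, using the same
// RecordReader and TRTEngineInstance types that InitializeTRTResource uses
// above:
//
//   std::unique_ptr<tensorflow::RandomAccessFile> file;
//   TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(
//       "/tmp/trt_engine_cache", &file));
//   tensorflow::io::RecordReader reader(file.get());
//   tensorflow::uint64 offset = 0;
//   tensorflow::tstring record;
//   while (reader.ReadRecord(&offset, &record).ok()) {
//     tensorflow::tensorrt::TRTEngineInstance instance;
//     if (!instance.ParseFromString(record)) break;
//     LOG(INFO) << instance.input_shapes_size() << " input shape(s), "
//               << instance.serialized_engine().size()
//               << " bytes of serialized engine data";
//   }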