/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"

#include <regex>  // NOLINT

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_testutils.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
#include "tensorflow/core/public/session.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT

namespace tensorflow {
namespace tensorrt {
namespace convert {

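// A minimal grappler::Cluster stub: it only lets the test inject a DeviceSet
// via SetDeviceSet(); all other overrides are no-ops.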
class FakeCluster : public grappler::Cluster {
 public:
  FakeCluster() : Cluster(0) {}

  void SetDeviceSet(const DeviceSet* device_set) { device_set_ = device_set; }

  const DeviceSet* GetDeviceSet() const override { return device_set_; }

  string type() const override { return ""; }
  Status Provision() override { return Status::OK(); }
  Status Initialize(const grappler::GrapplerItem& item) override {
    return Status::OK();
  }
  Status Run(const GraphDef& graph_def,
             const std::vector<std::pair<string, Tensor>>& feed,
             const std::vector<string>& fetch, RunMetadata* metadata) override {
    return Status::OK();
  }

 private:
  const DeviceSet* device_set_ = nullptr;
};

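// GetDeviceAndAllocator() returns a {device id, allocator} pair describing the
// GPU an engine should run on, or {-1, nullptr} when no suitable GPU device is
// available.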
TEST(GetDeviceAndAllocatorTest, GetDeviceAndAllocator) {
  TRTOptimizationPass::ConversionParams params;
  EngineInfo engine_info;
  {
    // Cluster is not set, and no GPU device is available yet.
    auto result = GetDeviceAndAllocator(nullptr, engine_info);
    EXPECT_EQ(-1, result.first);
    EXPECT_EQ(nullptr, result.second);
  }

  // Create a session with two (virtual) GPU devices so that a GPU and its BFC
  // allocator become available.
  SessionOptions options;
  ConfigProto* config = &options.config;
  GPUOptions* gpu_options = config->mutable_gpu_options();
  auto virtual_devices =
      gpu_options->mutable_experimental()->add_virtual_devices();
  virtual_devices->add_memory_limit_mb(200);
  virtual_devices->add_memory_limit_mb(200);
  std::unique_ptr<Session> session(NewSession(options));

  {
    // Cluster is not set; should find and return the first GPU id and the
    // corresponding allocator.
    auto result = GetDeviceAndAllocator(nullptr, engine_info);
    EXPECT_EQ(0, result.first);
    EXPECT_NE(nullptr, result.second);
    EXPECT_EQ("GPU_0_bfc", result.second->Name());
  }

  FakeCluster cluster;
  {
    // cluster.GetDeviceSet() returns null; should find and return the first
    // GPU id and the corresponding allocator.
    auto result = GetDeviceAndAllocator(&cluster, engine_info);
    EXPECT_EQ(0, result.first);
    EXPECT_NE(nullptr, result.second);
    EXPECT_EQ("GPU_0_bfc", result.second->Name());
  }

  // Build the DeviceSet.
  DeviceSet device_set;
  const DeviceMgr* device_mgr = nullptr;
  TF_ASSERT_OK(session->LocalDeviceManager(&device_mgr));
  for (auto d : device_mgr->ListDevices()) {
    device_set.AddDevice(d);
  }
  cluster.SetDeviceSet(&device_set);
  {
    // engine_info.device is not set; should find and return the first GPU id
    // and the corresponding allocator.
    auto result = GetDeviceAndAllocator(&cluster, engine_info);
    EXPECT_EQ(0, result.first);
    EXPECT_NE(nullptr, result.second);
    EXPECT_EQ("GPU_0_bfc", result.second->Name());
  }

  engine_info.device = "/GPU:1";
  {
    // Set to use the second (virtual) device, which is backed by the same
    // physical GPU.
    auto result = GetDeviceAndAllocator(&cluster, engine_info);
    EXPECT_EQ(0, result.first);
    EXPECT_NE(nullptr, result.second);
    EXPECT_EQ("GPU_1_bfc", result.second->Name());
  }

  engine_info.device = "/GPU:3";
  {
    // Set to use a nonexistent device.
    auto result = GetDeviceAndAllocator(&cluster, engine_info);
    EXPECT_EQ(-1, result.first);
    EXPECT_EQ(nullptr, result.second);
  }
}

class ConvertGraphTest : public ::testing::Test {
 public:
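  // Builds a GraphDef from scope `s`, runs TF-TRT conversion on it with
  // "output" as the only graph output, and stores the converted graph in
  // `output_graph_def`.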
  Status RunConvertGraph(Scope s, GraphDef* output_graph_def,
                         int maximum_batch_size = 1000) {
    // Create GraphProperties.
    grappler::GrapplerItem item;
    TF_EXPECT_OK(s.ToGraphDef(&item.graph));
    grappler::GraphProperties graph_properties(item);
    TF_EXPECT_OK(graph_properties.InferStatically(true));

    // Construct ConversionParams.
    const std::vector<string> input_output_names{"output"};
    TRTOptimizationPass::ConversionParams params;
    params.max_batch_size = maximum_batch_size;
    params.max_workspace_size_bytes = 8 << 20;
    params.minimum_segment_size = 1;
    params.use_calibration = false;
    params.trt_logger_name = "DefaultLogger";
    return ConvertGraph(params, item, input_output_names, nullptr,
                        output_graph_def);
  }
};

TEST_F(ConvertGraphTest, DirectlyConnectedEngines) {
  // Create the graph. There will be two TRTEngineOps after the conversion, and
  // the upstream TRTEngineOp will have two output connections from the same
  // node:port inside the op to the downstream TRTEngineOp. If the conversion
  // adds the downstream TRTEngineOp first, it then has to update the same
  // output connection twice when adding the upstream op. This test ensures the
  // conversion is correct in that situation.
  Scope s = Scope::NewRootScope();
  auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                                ops::Placeholder::Shape({2, 1}));
  // The root node names of the two segments are purposefully chosen so that
  // the downstream segment is processed first; when the conversion later
  // updates the edge between the two TRTEngineOps, it will try to add the same
  // edge multiple times.
  auto segment_root_1 = ops::Identity(s.WithOpName("segment_root_b"), input);
  auto add1 = ops::Add(s.WithOpName("add1"), segment_root_1, segment_root_1);
  // Add incompatible reshapes that change the batch dimension.
  auto incompatible =
      ops::Reshape(s.WithOpName("reshape1"), add1, Input({1, 2}));
  incompatible =
      ops::Reshape(s.WithOpName("reshape2"), incompatible, Input({2, 1}));

  auto add2 = ops::Add(s.WithOpName("add2"), incompatible, add1);
  auto segment_root_2 = ops::Identity(s.WithOpName("segment_root_a"), add1);
  auto add3 = ops::Add(s.WithOpName("add3"), add2, segment_root_2);
  ops::Identity(s.WithOpName("output"), add3);

  GraphDef output_graph_def;
  TF_EXPECT_OK(RunConvertGraph(s, &output_graph_def));

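  // TRTEngineOp node names embed a per-graph sequence number
  // ("TRTEngineOp_<seq>_..."); strip it so the expectations below do not
  // depend on it.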
  auto remove_graph_sequence_number = [](std::string node_name) {
    const std::regex pattern("TRTEngineOp_[0-9]+_");
    return std::regex_replace(node_name, pattern, "TRTEngineOp_");
  };
  int num_trt_ops = 0;
  for (const NodeDef& node : output_graph_def.node()) {
    std::string node_name = node.name();
    if (node.op() != "TRTEngineOp") continue;
    node_name = remove_graph_sequence_number(node_name);
    if (node_name == "TRTEngineOp_001") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("input", node.input(0));
      ++num_trt_ops;
    } else if (node_name == "TRTEngineOp_000") {
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("TRTEngineOp_001", remove_graph_sequence_number(node.input(0)));
      EXPECT_EQ("reshape2", node.input(1));
      ++num_trt_ops;
    }
  }
  EXPECT_EQ(2, num_trt_ops);
}

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT