/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
#include "tensorflow/core/public/session.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT

namespace tensorflow {
namespace tensorrt {
namespace convert {

// TODO(laigd): put this into some test utils file.
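// Checks that `status` carries the expected error code and, when `substr` is
// non-null, that its error message contains `substr`.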
void ExpectStatus(Status status, error::Code code = error::OK,
                  const char* substr = nullptr) {
  EXPECT_EQ(code, status.code())
      << status << " vs expected error code \"" << error::Code_Name(code)
      << "\" and message \"" << substr << "\"";
  if (substr) {
    EXPECT_THAT(status.error_message(), ::testing::HasSubstr(substr)) << status;
  }
}

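// A minimal grappler::Cluster whose DeviceSet can be injected by tests; all
// other Cluster methods are stubbed out to return success.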
class FakeCluster : public grappler::Cluster {
 public:
  FakeCluster() : Cluster(0) {}

  void SetDeviceSet(const DeviceSet* device_set) { device_set_ = device_set; }

  const DeviceSet* GetDeviceSet() const override { return device_set_; }

  string type() const override { return ""; }
  Status Provision() override { return Status::OK(); }
  Status Initialize(const grappler::GrapplerItem& item) override {
    return Status::OK();
  }
  Status Run(const GraphDef& graph_def,
             const std::vector<std::pair<string, Tensor>>& feed,
             const std::vector<string>& fetch, RunMetadata* metadata) override {
    return Status::OK();
  }

 private:
  // Not owned. Stays null until SetDeviceSet() is called.
  const DeviceSet* device_set_ = nullptr;
};

TEST(ConvertGraphTest, GetDeviceAndAllocator) {
  ConversionParams params;
  EngineInfo engine_info;
  {
    // params.cluster is not set, and no GPU device is available.
    auto result = GetDeviceAndAllocator(params, engine_info);
    EXPECT_EQ(-1, result.first);
    EXPECT_EQ(nullptr, result.second);
  }

  // Create a session with two (virtual) GPU devices.
  SessionOptions options;
  ConfigProto* config = &options.config;
  GPUOptions* gpu_options = config->mutable_gpu_options();
  auto virtual_devices =
      gpu_options->mutable_experimental()->add_virtual_devices();
  virtual_devices->add_memory_limit_mb(200);
  virtual_devices->add_memory_limit_mb(200);
  std::unique_ptr<Session> session(NewSession(options));

  {
    // params.cluster is not set; it should find and return the first GPU id
    // and the corresponding allocator.
    auto result = GetDeviceAndAllocator(params, engine_info);
    EXPECT_EQ(0, result.first);
    EXPECT_NE(nullptr, result.second);
    EXPECT_EQ("GPU_0_bfc", result.second->Name());
  }

  FakeCluster cluster;
  params.cluster = &cluster;
  {
    // params.cluster->GetDeviceSet() returns null; it should find and return
    // the first GPU id and the corresponding allocator.
    auto result = GetDeviceAndAllocator(params, engine_info);
    EXPECT_EQ(0, result.first);
    EXPECT_NE(nullptr, result.second);
    EXPECT_EQ("GPU_0_bfc", result.second->Name());
  }

  // Build the DeviceSet from the session's local devices.
  DeviceSet device_set;
  const DeviceMgr* device_mgr = nullptr;
  TF_ASSERT_OK(session->LocalDeviceManager(&device_mgr));
  for (auto d : device_mgr->ListDevices()) {
    device_set.AddDevice(d);
  }
  cluster.SetDeviceSet(&device_set);
  {
    // engine_info.device is not set; it should find and return the first GPU
    // id and the corresponding allocator.
    auto result = GetDeviceAndAllocator(params, engine_info);
    EXPECT_EQ(0, result.first);
    EXPECT_NE(nullptr, result.second);
    EXPECT_EQ("GPU_0_bfc", result.second->Name());
  }

  engine_info.device = "/GPU:1";
  {
    // Set to use the second virtual device. Both virtual devices share the
    // same physical GPU, so the returned device id is still 0, but the
    // allocator is the one for the second virtual device.
    auto result = GetDeviceAndAllocator(params, engine_info);
    EXPECT_EQ(0, result.first);
    EXPECT_NE(nullptr, result.second);
    EXPECT_EQ("GPU_1_bfc", result.second->Name());
  }

  engine_info.device = "/GPU:3";
  {
    // Set to use a nonexistent device.
    auto result = GetDeviceAndAllocator(params, engine_info);
    EXPECT_EQ(-1, result.first);
    EXPECT_EQ(nullptr, result.second);
  }
}

class ConvertAfterShapesTest : public ::testing::Test {
 public:
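  // Builds GraphProperties for the graph in `s`, fills in a ConversionParams,
  // and runs ConvertAfterShapes(), writing the converted graph to
  // `output_graph_def`.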
  Status RunConvertAfterShape(Scope s, GraphDef* output_graph_def) {
    // Create GraphProperties.
    grappler::GrapplerItem item;
    TF_EXPECT_OK(s.ToGraphDef(&item.graph));
    grappler::GraphProperties graph_properties(item);
    TF_EXPECT_OK(graph_properties.InferStatically(true));

    // Construct ConversionParams.
    const std::vector<string> output_names{"output"};
    ConversionParams params;
    params.input_graph_def = &item.graph;
    params.output_names = &output_names;
    params.max_workspace_size_bytes = 8 << 20;
    params.output_graph_def = output_graph_def;
    params.minimum_segment_size = 1;
    params.graph_properties = &graph_properties;
    params.use_calibration = false;
    params.trt_logger_name = "DefaultLogger";

    return ConvertAfterShapes(params);
  }
};

TEST_F(ConvertAfterShapesTest, DirectlyConnectedEngines) {
  // Create the graph. There will be two TRTEngineOps after the conversion, and
  // the upstream TRTEngineOp will have two output connections from the same
  // node:port inside the op to the downstream TRTEngineOp. Then, if the
  // conversion adds the downstream TRTEngineOp first, it will need to update
  // the same output connection twice when adding the upstream op. This test
  // ensures the correctness of the conversion under this condition.
  Scope s = Scope::NewRootScope();
  auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                                ops::Placeholder::Shape({2, 1}));
  // We purposefully choose the names of the root nodes of the segments so that
  // the downstream segment is processed first; then, when the conversion tries
  // to update the edge between the two TRTEngineOps, it will try to add the
  // same edge multiple times.
  auto segment_root_1 = ops::Identity(s.WithOpName("segment_root_b"), input);
  auto add1 = ops::Add(s.WithOpName("add1"), segment_root_1, segment_root_1);
  // Add incompatible reshapes that change the batch dimension.
  auto incompatible =
      ops::Reshape(s.WithOpName("reshape1"), add1, Input({1, 2}));
  incompatible =
      ops::Reshape(s.WithOpName("reshape2"), incompatible, Input({2, 1}));

  auto add2 = ops::Add(s.WithOpName("add2"), incompatible, add1);
  auto segment_root_2 = ops::Identity(s.WithOpName("segment_root_a"), add1);
  auto add3 = ops::Add(s.WithOpName("add3"), add2, segment_root_2);
  ops::Identity(s.WithOpName("output"), add3);

  GraphDef output_graph_def;
  TF_EXPECT_OK(RunConvertAfterShape(s, &output_graph_def));

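  // Check that both TRTEngineOps are present and that the upstream engine
  // (TRTEngineOp_1) feeds the downstream one (TRTEngineOp_0) as expected.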
  int num_trt_ops = 0;
  for (const NodeDef& node : output_graph_def.node()) {
    if (node.name() == "TRTEngineOp_1") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("input", node.input(0));
      ++num_trt_ops;
    } else if (node.name() == "TRTEngineOp_0") {
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("TRTEngineOp_1", node.input(0));
      EXPECT_EQ("reshape2", node.input(1));
      ++num_trt_ops;
    }
  }
  EXPECT_EQ(2, num_trt_ops);
}

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA