1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9
10 #include "net/websockets/websocket_inflater.h"
11
12 #include <string>
13 #include <vector>
14
15 #include "net/base/io_buffer.h"
16 #include "net/websockets/websocket_deflater.h"
17 #include "net/websockets/websocket_test_util.h"
18 #include "testing/gtest/include/gtest/gtest.h"
19
20 namespace net {
21
22 namespace {
23
ToString(IOBufferWithSize * buffer)24 std::string ToString(IOBufferWithSize* buffer) {
25 return std::string(buffer->data(), buffer->size());
26 }
27
// A newly initialized inflater has not produced any output yet.
TEST(WebSocketInflaterTest, Construct) {
  WebSocketInflater fresh_inflater;
  ASSERT_TRUE(fresh_inflater.Initialize(15));

  EXPECT_EQ(0u, fresh_inflater.CurrentOutputSize());
}
34
// Inflates two messages with the same inflater. The second compressed message
// is only 5 bytes and decodes to "Hello" by reusing data from the first
// message, so this checks that the decompression context survives Finish().
TEST(WebSocketInflaterTest, InflateHelloTakeOverContext) {
  WebSocketInflater inflater;
  ASSERT_TRUE(inflater.Initialize(15));

  // First message: a self-contained compressed "Hello".
  ASSERT_TRUE(inflater.AddBytes("\xf2\x48\xcd\xc9\xc9\x07\x00", 7));
  ASSERT_TRUE(inflater.Finish());
  const scoped_refptr<IOBufferWithSize> first_message =
      inflater.GetOutput(inflater.CurrentOutputSize());
  ASSERT_TRUE(first_message.get());
  EXPECT_EQ("Hello", ToString(first_message.get()));
  EXPECT_EQ(0u, inflater.CurrentOutputSize());

  // Second message: depends on the history left behind by the first one.
  ASSERT_TRUE(inflater.AddBytes("\xf2\x00\x11\x00\x00", 5));
  ASSERT_TRUE(inflater.Finish());
  const scoped_refptr<IOBufferWithSize> second_message =
      inflater.GetOutput(inflater.CurrentOutputSize());
  ASSERT_TRUE(second_message.get());
  EXPECT_EQ("Hello", ToString(second_message.get()));
  EXPECT_EQ(0u, inflater.CurrentOutputSize());
}
54
// With an inflater constructed with tiny capacities (1, 1), the decompressed
// "Hello" is exposed one byte at a time and must be drained with GetOutput(1).
TEST(WebSocketInflaterTest, InflateHelloSmallCapacity) {
  WebSocketInflater inflater(1, 1);
  ASSERT_TRUE(inflater.Initialize(15));

  ASSERT_TRUE(inflater.AddBytes("\xf2\x48\xcd\xc9\xc9\x07\x00", 7));
  ASSERT_TRUE(inflater.Finish());

  std::string collected;
  for (int chunk = 0; chunk < 5; ++chunk) {
    // Exactly one byte of output is available before each read.
    ASSERT_EQ(1u, inflater.CurrentOutputSize());
    const scoped_refptr<IOBufferWithSize> piece = inflater.GetOutput(1);
    ASSERT_TRUE(piece.get());
    ASSERT_EQ(1, piece->size());
    collected.append(ToString(piece.get()));
  }
  EXPECT_EQ("Hello", collected);
  EXPECT_EQ(0u, inflater.CurrentOutputSize());
}
72
// With tiny capacities (1, 1) only one byte of output is buffered after
// Finish(), yet a single large GetOutput() request must still return the
// entire decompressed "Hello" and leave nothing buffered.
TEST(WebSocketInflaterTest, InflateHelloSmallCapacityGetTotalOutput) {
  WebSocketInflater inflater(1, 1);
  ASSERT_TRUE(inflater.Initialize(15));

  ASSERT_TRUE(inflater.AddBytes("\xf2\x48\xcd\xc9\xc9\x07\x00", 7));
  ASSERT_TRUE(inflater.Finish());
  ASSERT_EQ(1u, inflater.CurrentOutputSize());

  const scoped_refptr<IOBufferWithSize> actual = inflater.GetOutput(1024);
  // GetOutput() can return null (see the invalid-data tests); fail the test
  // cleanly instead of dereferencing a null buffer inside ToString().
  ASSERT_TRUE(actual.get());
  EXPECT_EQ("Hello", ToString(actual.get()));
  EXPECT_EQ(0u, inflater.CurrentOutputSize());
}
85
// With the default capacities, AddBytes() itself reports corrupt input.
TEST(WebSocketInflaterTest, InflateInvalidData) {
  WebSocketInflater inflater;
  ASSERT_TRUE(inflater.Initialize(15));

  // Starts like the compressed "Hello" stream, then turns into garbage
  // (16 bytes in total).
  const char kCorruptInput[] = "\xf2\x48\xcd\xc9INVALID DATA";
  EXPECT_FALSE(inflater.AddBytes(kCorruptInput, 16));
}
91
// With tiny capacities (1, 1), corrupt input is accepted by AddBytes() and
// Finish(): only one byte of output is produced up front. The error surfaces
// later, when GetOutput() returns a null buffer.
TEST(WebSocketInflaterTest, ChokedInvalidData) {
  WebSocketInflater choked_inflater(1, 1);
  ASSERT_TRUE(choked_inflater.Initialize(15));

  const char kCorruptInput[] = "\xf2\x48\xcd\xc9INVALID DATA";
  EXPECT_TRUE(choked_inflater.AddBytes(kCorruptInput, 16));
  EXPECT_TRUE(choked_inflater.Finish());
  EXPECT_EQ(1u, choked_inflater.CurrentOutputSize());
  EXPECT_FALSE(choked_inflater.GetOutput(1024).get());
}
101
// Feeding the compressed stream one byte per AddBytes() call still decodes to
// "Hello", and the output is fully drained afterwards.
TEST(WebSocketInflaterTest, MultipleAddBytesCalls) {
  WebSocketInflater inflater;
  ASSERT_TRUE(inflater.Initialize(15));
  const std::string input("\xf2\x48\xcd\xc9\xc9\x07\x00", 7);

  for (const char& byte : input) {
    ASSERT_TRUE(inflater.AddBytes(&byte, 1));
  }
  ASSERT_TRUE(inflater.Finish());

  const scoped_refptr<IOBufferWithSize> actual = inflater.GetOutput(5);
  ASSERT_TRUE(actual.get());
  EXPECT_EQ("Hello", ToString(actual.get()));
  // All five decompressed bytes were read, so nothing may remain buffered.
  EXPECT_EQ(0u, inflater.CurrentOutputSize());
}
116
// A stream ended by an empty final block can be followed by a new,
// self-contained stream on the same inflater.
TEST(WebSocketInflaterTest, Reset) {
  WebSocketInflater inflater;
  ASSERT_TRUE(inflater.Initialize(15));

  // Inflate one complete "Hello" message.
  ASSERT_TRUE(inflater.AddBytes("\xf2\x48\xcd\xc9\xc9\x07\x00", 7));
  ASSERT_TRUE(inflater.Finish());
  const scoped_refptr<IOBufferWithSize> before_reset =
      inflater.GetOutput(inflater.CurrentOutputSize());
  ASSERT_TRUE(before_reset.get());
  EXPECT_EQ("Hello", ToString(before_reset.get()));
  EXPECT_EQ(0u, inflater.CurrentOutputSize());

  // Reset the stream with an empty final stored block
  // [BFINAL = 1, BTYPE = 00, LEN = 0]; it yields no output.
  ASSERT_TRUE(inflater.AddBytes("\x01", 1));
  ASSERT_TRUE(inflater.Finish());
  ASSERT_EQ(0u, inflater.CurrentOutputSize());

  // The same self-contained message must inflate again after the reset.
  ASSERT_TRUE(inflater.AddBytes("\xf2\x48\xcd\xc9\xc9\x07\x00", 7));
  ASSERT_TRUE(inflater.Finish());
  const scoped_refptr<IOBufferWithSize> after_reset =
      inflater.GetOutput(inflater.CurrentOutputSize());
  ASSERT_TRUE(after_reset.get());
  EXPECT_EQ("Hello", ToString(after_reset.get()));
  EXPECT_EQ(0u, inflater.CurrentOutputSize());
}
141
// After the stream is reset by an empty final block, the history of the
// previous message is gone. A message that depends on that history must
// therefore be rejected.
TEST(WebSocketInflaterTest, ResetAndLostContext) {
  WebSocketInflater inflater;
  ASSERT_TRUE(inflater.Initialize(15));

  ASSERT_TRUE(inflater.AddBytes("\xf2\x48\xcd\xc9\xc9\x07\x00", 7));
  ASSERT_TRUE(inflater.Finish());
  const scoped_refptr<IOBufferWithSize> actual =
      inflater.GetOutput(inflater.CurrentOutputSize());
  ASSERT_TRUE(actual.get());
  EXPECT_EQ("Hello", ToString(actual.get()));
  EXPECT_EQ(0u, inflater.CurrentOutputSize());

  // Reset the stream with a block [BFINAL = 1, BTYPE = 00, LEN = 0]
  ASSERT_TRUE(inflater.AddBytes("\x01", 1));
  ASSERT_TRUE(inflater.Finish());
  ASSERT_EQ(0u, inflater.CurrentOutputSize());

  // The context is already reset, so this message's references to the
  // previous message's data cannot be resolved.
  ASSERT_FALSE(inflater.AddBytes("\xf2\x00\x11\x00\x00", 5));
}
162
// Destroys an inflater whose decompressed output was never read.
TEST(WebSocketInflaterTest, CallAddBytesAndFinishWithoutGetOutput) {
  WebSocketInflater inflater;
  ASSERT_TRUE(inflater.Initialize(15));

  ASSERT_TRUE(inflater.AddBytes("\xf2\x48\xcd\xc9\xc9\x07\x00", 7));
  ASSERT_TRUE(inflater.Finish());
  EXPECT_EQ(5u, inflater.CurrentOutputSize());

  // This is a test for memory leak detectors: |inflater| goes out of scope
  // while still holding buffered output.
}
174
// Destroys a tiny-capacity (1, 1) inflater whose output was never read.
TEST(WebSocketInflaterTest, CallAddBytesAndFinishWithoutGetOutputChoked) {
  WebSocketInflater inflater(1, 1);
  ASSERT_TRUE(inflater.Initialize(15));

  ASSERT_TRUE(inflater.AddBytes("\xf2\x48\xcd\xc9\xc9\x07\x00", 7));
  ASSERT_TRUE(inflater.Finish());
  EXPECT_EQ(1u, inflater.CurrentOutputSize());

  // This is a test for memory leak detectors: |inflater| goes out of scope
  // while still holding buffered output (and, presumably, input it has not
  // yet inflated because of the tiny capacity).
}
186
// Round-trips 64 * 1024 pseudo-random bytes through WebSocketDeflater and a
// WebSocketInflater constructed with small capacities (256, 256). The
// inflater is drained in chunks, and the original bytes must be recovered.
TEST(WebSocketInflaterTest, LargeRandomDeflateInflate) {
  constexpr size_t kSize = 64 * 1024;
  LinearCongruentialGenerator generator(133);

  WebSocketDeflater deflater(WebSocketDeflater::TAKE_OVER_CONTEXT);
  ASSERT_TRUE(deflater.Initialize(8));
  WebSocketInflater inflater(256, 256);
  ASSERT_TRUE(inflater.Initialize(8));

  // The input size is known up front, so allocate once.
  std::vector<char> input;
  input.reserve(kSize);
  for (size_t i = 0; i < kSize; ++i)
    input.push_back(static_cast<char>(generator.Generate()));

  ASSERT_TRUE(deflater.AddBytes(input.data(), input.size()));
  ASSERT_TRUE(deflater.Finish());

  const scoped_refptr<IOBufferWithSize> compressed =
      deflater.GetOutput(deflater.CurrentOutputSize());
  ASSERT_TRUE(compressed.get());
  ASSERT_EQ(0u, deflater.CurrentOutputSize());

  ASSERT_TRUE(inflater.AddBytes(compressed->data(), compressed->size()));
  ASSERT_TRUE(inflater.Finish());

  // Drain the inflater; each GetOutput() call may return only part of the
  // data, so keep reading until nothing is buffered.
  std::vector<char> output;
  output.reserve(kSize);
  while (inflater.CurrentOutputSize() > 0) {
    const scoped_refptr<IOBufferWithSize> uncompressed =
        inflater.GetOutput(inflater.CurrentOutputSize());
    ASSERT_TRUE(uncompressed.get());
    output.insert(output.end(), uncompressed->data(),
                  uncompressed->data() + uncompressed->size());
  }

  EXPECT_EQ(output, input);
}
224
225 } // unnamed namespace
226
227 } // namespace net
228