/*
 * Copyright 2019 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.netty;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import io.grpc.Status;
import io.grpc.Status.Code;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultEventLoop;
import io.netty.channel.EventLoop;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import java.net.ConnectException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.After;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.DisableOnDebug; 52 import org.junit.rules.TestRule; 53 import org.junit.rules.Timeout; 54 import org.junit.runner.RunWith; 55 import org.junit.runners.JUnit4; 56 57 /** 58 * Tests for {@link WriteBufferingAndExceptionHandler}. 59 */ 60 @RunWith(JUnit4.class) 61 public class WriteBufferingAndExceptionHandlerTest { 62 63 private static final long TIMEOUT_SECONDS = 10; 64 65 @Rule 66 public final TestRule timeout = new DisableOnDebug(Timeout.seconds(TIMEOUT_SECONDS)); 67 68 private final EventLoop group = new DefaultEventLoop(); 69 private Channel chan; 70 private Channel server; 71 72 @After tearDown()73 public void tearDown() throws InterruptedException { 74 if (server != null) { 75 server.close().sync(); 76 } 77 if (chan != null) { 78 chan.close().sync(); 79 } 80 group.shutdownGracefully(0, 10, TimeUnit.SECONDS).sync(); 81 } 82 83 @Test connectionFailuresPropagated()84 public void connectionFailuresPropagated() throws Exception { 85 WriteBufferingAndExceptionHandler handler = 86 new WriteBufferingAndExceptionHandler(new ChannelHandlerAdapter() {}); 87 ChannelFuture cf = new Bootstrap() 88 .channel(LocalChannel.class) 89 .handler(handler) 90 .group(group) 91 .register(); 92 chan = cf.channel(); 93 cf.sync(); 94 // Write before connect. In the event connect fails, the pipeline is torn down and the handler 95 // won't be able to fail the writes with the correct exception. 
96 ChannelFuture wf = chan.writeAndFlush(new Object()); 97 chan.connect(new LocalAddress("bogus")); 98 99 try { 100 wf.sync(); 101 fail(); 102 } catch (Exception e) { 103 assertThat(e).isInstanceOf(ConnectException.class); 104 assertThat(e).hasMessageThat().contains("connection refused"); 105 } 106 } 107 108 @Test channelInactiveFailuresPropagated()109 public void channelInactiveFailuresPropagated() throws Exception { 110 WriteBufferingAndExceptionHandler handler = 111 new WriteBufferingAndExceptionHandler(new ChannelHandlerAdapter() {}); 112 LocalAddress addr = new LocalAddress("local"); 113 ChannelFuture cf = new Bootstrap() 114 .channel(LocalChannel.class) 115 .handler(handler) 116 .group(group) 117 .register(); 118 chan = cf.channel(); 119 cf.sync(); 120 ChannelFuture sf = new ServerBootstrap() 121 .channel(LocalServerChannel.class) 122 .childHandler(new ChannelHandlerAdapter() {}) 123 .group(group) 124 .bind(addr); 125 server = sf.channel(); 126 sf.sync(); 127 128 ChannelFuture wf = chan.writeAndFlush(new Object()); 129 chan.connect(addr); 130 chan.pipeline().fireChannelInactive(); 131 132 try { 133 wf.sync(); 134 fail(); 135 } catch (Exception e) { 136 Status status = Status.fromThrowable(e); 137 assertThat(status.getCode()).isEqualTo(Code.UNAVAILABLE); 138 assertThat(status.getDescription()) 139 .contains("Connection closed while performing protocol negotiation"); 140 } 141 } 142 143 @Test channelCloseFailuresPropagated()144 public void channelCloseFailuresPropagated() throws Exception { 145 WriteBufferingAndExceptionHandler handler = 146 new WriteBufferingAndExceptionHandler(new ChannelHandlerAdapter() {}); 147 LocalAddress addr = new LocalAddress("local"); 148 ChannelFuture cf = new Bootstrap() 149 .channel(LocalChannel.class) 150 .handler(handler) 151 .group(group) 152 .register(); 153 chan = cf.channel(); 154 cf.sync(); 155 ChannelFuture sf = new ServerBootstrap() 156 .channel(LocalServerChannel.class) 157 .childHandler(new ChannelHandlerAdapter() {}) 
158 .group(group) 159 .bind(addr); 160 server = sf.channel(); 161 sf.sync(); 162 163 ChannelFuture wf = chan.writeAndFlush(new Object()); 164 chan.connect(addr); 165 chan.close(); 166 167 try { 168 wf.sync(); 169 fail(); 170 } catch (Exception e) { 171 Status status = Status.fromThrowable(e); 172 assertThat(status.getCode()).isEqualTo(Code.UNAVAILABLE); 173 assertThat(status.getDescription()) 174 .contains("Connection closing while performing protocol negotiation"); 175 } 176 } 177 178 @Test uncaughtExceptionFailuresPropagated()179 public void uncaughtExceptionFailuresPropagated() throws Exception { 180 WriteBufferingAndExceptionHandler handler = 181 new WriteBufferingAndExceptionHandler(new ChannelHandlerAdapter() {}); 182 LocalAddress addr = new LocalAddress("local"); 183 ChannelFuture cf = new Bootstrap() 184 .channel(LocalChannel.class) 185 .handler(handler) 186 .group(group) 187 .register(); 188 chan = cf.channel(); 189 cf.sync(); 190 ChannelFuture sf = new ServerBootstrap() 191 .channel(LocalServerChannel.class) 192 .childHandler(new ChannelHandlerAdapter() {}) 193 .group(group) 194 .bind(addr); 195 server = sf.channel(); 196 sf.sync(); 197 198 ChannelFuture wf = chan.writeAndFlush(new Object()); 199 chan.connect(addr); 200 chan.pipeline().fireExceptionCaught(Status.ABORTED.withDescription("zap").asRuntimeException()); 201 202 try { 203 wf.sync(); 204 fail(); 205 } catch (Exception e) { 206 Status status = Status.fromThrowable(e); 207 assertThat(status.getCode()).isEqualTo(Code.ABORTED); 208 assertThat(status.getDescription()).contains("zap"); 209 } 210 } 211 212 @Test uncaughtException_closeAtMostOnce()213 public void uncaughtException_closeAtMostOnce() throws Exception { 214 final AtomicInteger closes = new AtomicInteger(); 215 WriteBufferingAndExceptionHandler handler = 216 new WriteBufferingAndExceptionHandler(new ChannelDuplexHandler() { 217 @Override 218 public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { 219 
closes.getAndIncrement(); 220 // Simulates a loop between this handler and the WriteBufferingAndExceptionHandler. 221 ctx.fireExceptionCaught(Status.ABORTED.withDescription("zap").asRuntimeException()); 222 super.close(ctx, promise); 223 } 224 }); 225 LocalAddress addr = new LocalAddress("local"); 226 ChannelFuture cf = new Bootstrap() 227 .channel(LocalChannel.class) 228 .handler(handler) 229 .group(group) 230 .register(); 231 chan = cf.channel(); 232 cf.sync(); 233 ChannelFuture sf = new ServerBootstrap() 234 .channel(LocalServerChannel.class) 235 .childHandler(new ChannelHandlerAdapter() {}) 236 .group(group) 237 .bind(addr); 238 server = sf.channel(); 239 sf.sync(); 240 241 chan.connect(addr).sync(); 242 chan.close().sync(); 243 assertEquals(1, closes.get()); 244 } 245 246 @Test handlerRemovedFailuresPropagated()247 public void handlerRemovedFailuresPropagated() throws Exception { 248 WriteBufferingAndExceptionHandler handler = 249 new WriteBufferingAndExceptionHandler(new ChannelHandlerAdapter() { 250 @Override 251 public void handlerRemoved(ChannelHandlerContext ctx) { 252 ctx.pipeline().remove( 253 ctx.pipeline().context(WriteBufferingAndExceptionHandler.class).name()); 254 } 255 }); 256 LocalAddress addr = new LocalAddress("local"); 257 ChannelFuture cf = new Bootstrap() 258 .channel(LocalChannel.class) 259 .handler(handler) 260 .group(group) 261 .register(); 262 chan = cf.channel(); 263 cf.sync(); 264 ChannelFuture sf = new ServerBootstrap() 265 .channel(LocalServerChannel.class) 266 .childHandler(new ChannelHandlerAdapter() {}) 267 .group(group) 268 .bind(addr); 269 server = sf.channel(); 270 sf.sync(); 271 272 chan.connect(addr); 273 ChannelFuture wf = chan.writeAndFlush(new Object()); 274 chan.pipeline().removeFirst(); 275 276 try { 277 wf.sync(); 278 fail(); 279 } catch (Exception e) { 280 Status status = Status.fromThrowable(e); 281 assertThat(status.getCode()).isEqualTo(Code.INTERNAL); 282 assertThat(status.getDescription()).contains("Buffer 
removed"); 283 } 284 } 285 286 @Test writesBuffered()287 public void writesBuffered() throws Exception { 288 final AtomicBoolean handlerAdded = new AtomicBoolean(); 289 final AtomicBoolean flush = new AtomicBoolean(); 290 final AtomicReference<Object> write = new AtomicReference<>(); 291 final WriteBufferingAndExceptionHandler handler = 292 new WriteBufferingAndExceptionHandler(new ChannelOutboundHandlerAdapter() { 293 @Override 294 public void handlerAdded(ChannelHandlerContext ctx) throws Exception { 295 assertFalse(handlerAdded.getAndSet(true)); 296 super.handlerAdded(ctx); 297 } 298 299 @Override 300 public void flush(ChannelHandlerContext ctx) throws Exception { 301 assertFalse(flush.getAndSet(true)); 302 super.flush(ctx); 303 } 304 305 @Override 306 public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { 307 assertNull(write.getAndSet(msg)); 308 promise.setSuccess(); 309 } 310 }); 311 LocalAddress addr = new LocalAddress("local"); 312 ChannelFuture cf = new Bootstrap() 313 .channel(LocalChannel.class) 314 .handler(handler) 315 .group(group) 316 .register(); 317 chan = cf.channel(); 318 cf.sync(); 319 ChannelFuture sf = new ServerBootstrap() 320 .channel(LocalServerChannel.class) 321 .childHandler(new ChannelHandlerAdapter() {}) 322 .group(group) 323 .bind(addr); 324 server = sf.channel(); 325 sf.sync(); 326 327 assertTrue(handlerAdded.get()); 328 329 chan.write(new Object()); 330 chan.connect(addr).sync(); 331 assertNull(write.get()); 332 333 chan.flush(); 334 assertNull(write.get()); 335 assertFalse(flush.get()); 336 337 assertThat(chan.pipeline().context(handler)).isNotNull(); 338 chan.eventLoop().submit(new Runnable() { 339 @Override 340 public void run() { 341 handler.writeBufferedAndRemove(chan.pipeline().context(handler)); 342 } 343 }).sync(); 344 345 assertThat(chan.pipeline().context(handler)).isNull(); 346 assertThat(write.get().getClass()).isSameInstanceAs(Object.class); 347 assertTrue(flush.get()); 348 
assertThat(chan.pipeline().toMap().values()).doesNotContain(handler); 349 } 350 351 @Test uncaughtReadFails()352 public void uncaughtReadFails() throws Exception { 353 WriteBufferingAndExceptionHandler handler = 354 new WriteBufferingAndExceptionHandler(new ChannelHandlerAdapter() {}); 355 LocalAddress addr = new LocalAddress("local"); 356 ChannelFuture cf = new Bootstrap() 357 .channel(LocalChannel.class) 358 .handler(handler) 359 .group(group) 360 .register(); 361 chan = cf.channel(); 362 cf.sync(); 363 ChannelFuture sf = new ServerBootstrap() 364 .channel(LocalServerChannel.class) 365 .childHandler(new ChannelHandlerAdapter() {}) 366 .group(group) 367 .bind(addr); 368 server = sf.channel(); 369 sf.sync(); 370 371 ChannelFuture wf = chan.writeAndFlush(new Object()); 372 chan.connect(addr); 373 chan.pipeline().fireChannelRead(Unpooled.copiedBuffer(new byte[] {'a'})); 374 375 try { 376 wf.sync(); 377 fail(); 378 } catch (Exception e) { 379 Status status = Status.fromThrowable(e); 380 assertThat(status.getCode()).isEqualTo(Code.INTERNAL); 381 assertThat(status.getDescription()).contains("channelRead() missed"); 382 } 383 } 384 } 385