1 /* 2 * Copyright 2016 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.okhttp; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; 21 import static org.junit.Assert.assertEquals; 22 import static org.junit.Assert.assertTrue; 23 import static org.mockito.Matchers.eq; 24 import static org.mockito.Matchers.isA; 25 import static org.mockito.Mockito.times; 26 import static org.mockito.Mockito.verify; 27 import static org.mockito.Mockito.verifyNoMoreInteractions; 28 29 import com.google.common.io.BaseEncoding; 30 import io.grpc.Metadata; 31 import io.grpc.MethodDescriptor; 32 import io.grpc.MethodDescriptor.MethodType; 33 import io.grpc.Status; 34 import io.grpc.internal.GrpcUtil; 35 import io.grpc.internal.NoopClientStreamListener; 36 import io.grpc.internal.StatsTraceContext; 37 import io.grpc.internal.TransportTracer; 38 import io.grpc.okhttp.internal.framed.ErrorCode; 39 import io.grpc.okhttp.internal.framed.Header; 40 import java.io.ByteArrayInputStream; 41 import java.nio.charset.Charset; 42 import java.util.List; 43 import java.util.concurrent.atomic.AtomicReference; 44 import org.junit.Before; 45 import org.junit.Test; 46 import org.junit.runner.RunWith; 47 import org.junit.runners.JUnit4; 48 import org.mockito.ArgumentCaptor; 49 import org.mockito.Captor; 50 import org.mockito.Mock; 51 import org.mockito.Mockito; 52 import org.mockito.MockitoAnnotations; 53 import org.mockito.invocation.InvocationOnMock; 54 import org.mockito.stubbing.Answer; 55 56 @RunWith(JUnit4.class) 57 public class OkHttpClientStreamTest { 58 private static final int MAX_MESSAGE_SIZE = 100; 59 60 @Mock private MethodDescriptor.Marshaller<Void> marshaller; 61 @Mock private AsyncFrameWriter frameWriter; 62 @Mock private OkHttpClientTransport transport; 63 @Mock private OutboundFlowController flowController; 64 @Captor private ArgumentCaptor<List<Header>> headersCaptor; 65 66 private final Object lock = new Object(); 67 private final TransportTracer transportTracer = new TransportTracer(); 68 69 private MethodDescriptor<?, ?> methodDescriptor; 70 private OkHttpClientStream stream; 71 72 @Before setUp()73 public void setUp() { 74 MockitoAnnotations.initMocks(this); 75 methodDescriptor = MethodDescriptor.<Void, Void>newBuilder() 76 .setType(MethodDescriptor.MethodType.UNARY) 77 .setFullMethodName("testService/test") 78 .setRequestMarshaller(marshaller) 79 .setResponseMarshaller(marshaller) 80 .build(); 81 82 stream = new OkHttpClientStream( 83 methodDescriptor, 84 new Metadata(), 85 frameWriter, 86 transport, 87 flowController, 88 lock, 89 MAX_MESSAGE_SIZE, 90 "localhost", 91 "userAgent", 92 StatsTraceContext.NOOP, 93 transportTracer); 94 } 95 96 @Test getType()97 public void getType() { 98 assertEquals(MethodType.UNARY, stream.getType()); 99 } 100 101 @Test cancel_notStarted()102 public void cancel_notStarted() { 103 final AtomicReference<Status> statusRef = new AtomicReference<Status>(); 104 stream.start(new BaseClientStreamListener() { 105 @Override 106 public void closed( 107 Status status, RpcProgress rpcProgress, Metadata trailers) { 108 statusRef.set(status); 109 assertTrue(Thread.holdsLock(lock)); 110 } 111 }); 112 113 stream.cancel(Status.CANCELLED); 114 115 assertEquals(Status.Code.CANCELLED, statusRef.get().getCode()); 116 } 117 118 @Test cancel_started()119 public void cancel_started() { 120 stream.start(new BaseClientStreamListener()); 121 stream.transportState().start(1234); 122 Mockito.doAnswer(new Answer<Void>() { 123 @Override 124 public Void answer(InvocationOnMock invocation) throws Throwable { 125 assertTrue(Thread.holdsLock(lock)); 126 return null; 127 } 128 }).when(transport).finishStream( 129 1234, Status.CANCELLED, PROCESSED, true, ErrorCode.CANCEL, null); 130 131 stream.cancel(Status.CANCELLED); 132 133 verify(transport).finishStream(1234, Status.CANCELLED, PROCESSED,true, ErrorCode.CANCEL, null); 134 } 135 136 @Test start_alreadyCancelled()137 public void start_alreadyCancelled() { 138 stream.start(new BaseClientStreamListener()); 139 stream.cancel(Status.CANCELLED); 140 141 stream.transportState().start(1234); 142 143 verifyNoMoreInteractions(frameWriter); 144 } 145 146 @Test start_userAgentRemoved()147 public void start_userAgentRemoved() { 148 Metadata metaData = new Metadata(); 149 metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application"); 150 stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport, 151 flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application", 152 StatsTraceContext.NOOP, transportTracer); 153 stream.start(new BaseClientStreamListener()); 154 stream.transportState().start(3); 155 156 verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); 157 assertThat(headersCaptor.getValue()) 158 .contains(new Header(GrpcUtil.USER_AGENT_KEY.name(), "good-application")); 159 } 160 161 @Test start_headerFieldOrder()162 public void start_headerFieldOrder() { 163 Metadata metaData = new Metadata(); 164 metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application"); 165 stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport, 166 flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application", 167 StatsTraceContext.NOOP, transportTracer); 168 stream.start(new BaseClientStreamListener()); 169 stream.transportState().start(3); 170 171 verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); 172 assertThat(headersCaptor.getValue()).containsExactly( 173 Headers.SCHEME_HEADER, 174 Headers.METHOD_HEADER, 175 new Header(Header.TARGET_AUTHORITY, "localhost"), 176 new Header(Header.TARGET_PATH, "/" + methodDescriptor.getFullMethodName()), 177 new Header(GrpcUtil.USER_AGENT_KEY.name(), "good-application"), 178 Headers.CONTENT_TYPE_HEADER, 179 Headers.TE_HEADER) 180 .inOrder(); 181 } 182 183 @Test getUnaryRequest()184 public void getUnaryRequest() { 185 MethodDescriptor<?, ?> getMethod = MethodDescriptor.<Void, Void>newBuilder() 186 .setType(MethodDescriptor.MethodType.UNARY) 187 .setFullMethodName("service/method") 188 .setIdempotent(true) 189 .setSafe(true) 190 .setRequestMarshaller(marshaller) 191 .setResponseMarshaller(marshaller) 192 .build(); 193 stream = new OkHttpClientStream(getMethod, new Metadata(), frameWriter, transport, 194 flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application", 195 StatsTraceContext.NOOP, transportTracer); 196 stream.start(new BaseClientStreamListener()); 197 198 // GET streams send headers after halfClose is called. 199 verify(frameWriter, times(0)).synStream( 200 eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); 201 verify(transport, times(0)).streamReadyToStart(isA(OkHttpClientStream.class)); 202 203 byte[] msg = "request".getBytes(Charset.forName("UTF-8")); 204 stream.writeMessage(new ByteArrayInputStream(msg)); 205 stream.halfClose(); 206 verify(transport).streamReadyToStart(eq(stream)); 207 stream.transportState().start(3); 208 209 verify(frameWriter) 210 .synStream(eq(true), eq(false), eq(3), eq(0), headersCaptor.capture()); 211 assertThat(headersCaptor.getValue()).contains(Headers.METHOD_GET_HEADER); 212 assertThat(headersCaptor.getValue()).contains( 213 new Header(Header.TARGET_PATH, "/" + getMethod.getFullMethodName() + "?" 214 + BaseEncoding.base64().encode(msg))); 215 } 216 217 // TODO(carl-mastrangelo): extract this out into a testing/ directory and remove other definitions 218 // of it. 219 private static class BaseClientStreamListener extends NoopClientStreamListener { 220 221 @Override messagesAvailable(MessageProducer producer)222 public void messagesAvailable(MessageProducer producer) { 223 while (producer.next() != null) {} 224 } 225 } 226 } 227 228