• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.okhttp;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED;
21 import static org.junit.Assert.assertEquals;
22 import static org.junit.Assert.assertTrue;
23 import static org.mockito.Matchers.eq;
24 import static org.mockito.Matchers.isA;
25 import static org.mockito.Mockito.times;
26 import static org.mockito.Mockito.verify;
27 import static org.mockito.Mockito.verifyNoMoreInteractions;
28 
29 import com.google.common.io.BaseEncoding;
30 import io.grpc.Metadata;
31 import io.grpc.MethodDescriptor;
32 import io.grpc.MethodDescriptor.MethodType;
33 import io.grpc.Status;
34 import io.grpc.internal.GrpcUtil;
35 import io.grpc.internal.NoopClientStreamListener;
36 import io.grpc.internal.StatsTraceContext;
37 import io.grpc.internal.TransportTracer;
38 import io.grpc.okhttp.internal.framed.ErrorCode;
39 import io.grpc.okhttp.internal.framed.Header;
40 import java.io.ByteArrayInputStream;
41 import java.nio.charset.Charset;
42 import java.util.List;
43 import java.util.concurrent.atomic.AtomicReference;
44 import org.junit.Before;
45 import org.junit.Test;
46 import org.junit.runner.RunWith;
47 import org.junit.runners.JUnit4;
48 import org.mockito.ArgumentCaptor;
49 import org.mockito.Captor;
50 import org.mockito.Mock;
51 import org.mockito.Mockito;
52 import org.mockito.MockitoAnnotations;
53 import org.mockito.invocation.InvocationOnMock;
54 import org.mockito.stubbing.Answer;
55 
56 @RunWith(JUnit4.class)
57 public class OkHttpClientStreamTest {
58   private static final int MAX_MESSAGE_SIZE = 100;
59 
60   @Mock private MethodDescriptor.Marshaller<Void> marshaller;
61   @Mock private AsyncFrameWriter frameWriter;
62   @Mock private OkHttpClientTransport transport;
63   @Mock private OutboundFlowController flowController;
64   @Captor private ArgumentCaptor<List<Header>> headersCaptor;
65 
66   private final Object lock = new Object();
67   private final TransportTracer transportTracer = new TransportTracer();
68 
69   private MethodDescriptor<?, ?> methodDescriptor;
70   private OkHttpClientStream stream;
71 
72   @Before
setUp()73   public void setUp() {
74     MockitoAnnotations.initMocks(this);
75     methodDescriptor = MethodDescriptor.<Void, Void>newBuilder()
76         .setType(MethodDescriptor.MethodType.UNARY)
77         .setFullMethodName("testService/test")
78         .setRequestMarshaller(marshaller)
79         .setResponseMarshaller(marshaller)
80         .build();
81 
82     stream = new OkHttpClientStream(
83         methodDescriptor,
84         new Metadata(),
85         frameWriter,
86         transport,
87         flowController,
88         lock,
89         MAX_MESSAGE_SIZE,
90         "localhost",
91         "userAgent",
92         StatsTraceContext.NOOP,
93         transportTracer);
94   }
95 
96   @Test
getType()97   public void getType() {
98     assertEquals(MethodType.UNARY, stream.getType());
99   }
100 
101   @Test
cancel_notStarted()102   public void cancel_notStarted() {
103     final AtomicReference<Status> statusRef = new AtomicReference<Status>();
104     stream.start(new BaseClientStreamListener() {
105       @Override
106       public void closed(
107           Status status, RpcProgress rpcProgress, Metadata trailers) {
108         statusRef.set(status);
109         assertTrue(Thread.holdsLock(lock));
110       }
111     });
112 
113     stream.cancel(Status.CANCELLED);
114 
115     assertEquals(Status.Code.CANCELLED, statusRef.get().getCode());
116   }
117 
118   @Test
cancel_started()119   public void cancel_started() {
120     stream.start(new BaseClientStreamListener());
121     stream.transportState().start(1234);
122     Mockito.doAnswer(new Answer<Void>() {
123       @Override
124       public Void answer(InvocationOnMock invocation) throws Throwable {
125         assertTrue(Thread.holdsLock(lock));
126         return null;
127       }
128     }).when(transport).finishStream(
129         1234, Status.CANCELLED, PROCESSED, true, ErrorCode.CANCEL, null);
130 
131     stream.cancel(Status.CANCELLED);
132 
133     verify(transport).finishStream(1234, Status.CANCELLED, PROCESSED,true, ErrorCode.CANCEL, null);
134   }
135 
136   @Test
start_alreadyCancelled()137   public void start_alreadyCancelled() {
138     stream.start(new BaseClientStreamListener());
139     stream.cancel(Status.CANCELLED);
140 
141     stream.transportState().start(1234);
142 
143     verifyNoMoreInteractions(frameWriter);
144   }
145 
146   @Test
start_userAgentRemoved()147   public void start_userAgentRemoved() {
148     Metadata metaData = new Metadata();
149     metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application");
150     stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport,
151         flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application",
152         StatsTraceContext.NOOP, transportTracer);
153     stream.start(new BaseClientStreamListener());
154     stream.transportState().start(3);
155 
156     verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture());
157     assertThat(headersCaptor.getValue())
158         .contains(new Header(GrpcUtil.USER_AGENT_KEY.name(), "good-application"));
159   }
160 
161   @Test
start_headerFieldOrder()162   public void start_headerFieldOrder() {
163     Metadata metaData = new Metadata();
164     metaData.put(GrpcUtil.USER_AGENT_KEY, "misbehaving-application");
165     stream = new OkHttpClientStream(methodDescriptor, metaData, frameWriter, transport,
166         flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application",
167         StatsTraceContext.NOOP, transportTracer);
168     stream.start(new BaseClientStreamListener());
169     stream.transportState().start(3);
170 
171     verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture());
172     assertThat(headersCaptor.getValue()).containsExactly(
173         Headers.SCHEME_HEADER,
174         Headers.METHOD_HEADER,
175         new Header(Header.TARGET_AUTHORITY, "localhost"),
176         new Header(Header.TARGET_PATH, "/" + methodDescriptor.getFullMethodName()),
177         new Header(GrpcUtil.USER_AGENT_KEY.name(), "good-application"),
178         Headers.CONTENT_TYPE_HEADER,
179         Headers.TE_HEADER)
180             .inOrder();
181   }
182 
183   @Test
getUnaryRequest()184   public void getUnaryRequest() {
185     MethodDescriptor<?, ?> getMethod = MethodDescriptor.<Void, Void>newBuilder()
186         .setType(MethodDescriptor.MethodType.UNARY)
187         .setFullMethodName("service/method")
188         .setIdempotent(true)
189         .setSafe(true)
190         .setRequestMarshaller(marshaller)
191         .setResponseMarshaller(marshaller)
192         .build();
193     stream = new OkHttpClientStream(getMethod, new Metadata(), frameWriter, transport,
194         flowController, lock, MAX_MESSAGE_SIZE, "localhost", "good-application",
195         StatsTraceContext.NOOP, transportTracer);
196     stream.start(new BaseClientStreamListener());
197 
198     // GET streams send headers after halfClose is called.
199     verify(frameWriter, times(0)).synStream(
200         eq(false), eq(false), eq(3), eq(0), headersCaptor.capture());
201     verify(transport, times(0)).streamReadyToStart(isA(OkHttpClientStream.class));
202 
203     byte[] msg = "request".getBytes(Charset.forName("UTF-8"));
204     stream.writeMessage(new ByteArrayInputStream(msg));
205     stream.halfClose();
206     verify(transport).streamReadyToStart(eq(stream));
207     stream.transportState().start(3);
208 
209     verify(frameWriter)
210         .synStream(eq(true), eq(false), eq(3), eq(0), headersCaptor.capture());
211     assertThat(headersCaptor.getValue()).contains(Headers.METHOD_GET_HEADER);
212     assertThat(headersCaptor.getValue()).contains(
213         new Header(Header.TARGET_PATH, "/" + getMethod.getFullMethodName() + "?"
214             + BaseEncoding.base64().encode(msg)));
215   }
216 
217   // TODO(carl-mastrangelo): extract this out into a testing/ directory and remove other definitions
218   // of it.
219   private static class BaseClientStreamListener extends NoopClientStreamListener {
220 
221     @Override
messagesAvailable(MessageProducer producer)222     public void messagesAvailable(MessageProducer producer) {
223       while (producer.next() != null) {}
224     }
225   }
226 }
227 
228