• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.google.cloud.automl.v1beta1;
18 
19 import com.google.api.gax.core.NoCredentialsProvider;
20 import com.google.api.gax.grpc.GaxGrpcProperties;
21 import com.google.api.gax.grpc.testing.LocalChannelProvider;
22 import com.google.api.gax.grpc.testing.MockGrpcService;
23 import com.google.api.gax.grpc.testing.MockServiceHelper;
24 import com.google.api.gax.rpc.ApiClientHeaderProvider;
25 import com.google.api.gax.rpc.InvalidArgumentException;
26 import com.google.api.gax.rpc.StatusCode;
27 import com.google.longrunning.Operation;
28 import com.google.protobuf.AbstractMessage;
29 import com.google.protobuf.Any;
30 import io.grpc.StatusRuntimeException;
31 import java.io.IOException;
32 import java.util.ArrayList;
33 import java.util.Arrays;
34 import java.util.HashMap;
35 import java.util.List;
36 import java.util.Map;
37 import java.util.UUID;
38 import java.util.concurrent.ExecutionException;
39 import javax.annotation.Generated;
40 import org.junit.After;
41 import org.junit.AfterClass;
42 import org.junit.Assert;
43 import org.junit.Before;
44 import org.junit.BeforeClass;
45 import org.junit.Test;
46 
47 @Generated("by gapic-generator-java")
48 public class PredictionServiceClientTest {
49   private static MockPredictionService mockPredictionService;
50   private static MockServiceHelper mockServiceHelper;
51   private LocalChannelProvider channelProvider;
52   private PredictionServiceClient client;
53 
54   @BeforeClass
startStaticServer()55   public static void startStaticServer() {
56     mockPredictionService = new MockPredictionService();
57     mockServiceHelper =
58         new MockServiceHelper(
59             UUID.randomUUID().toString(), Arrays.<MockGrpcService>asList(mockPredictionService));
60     mockServiceHelper.start();
61   }
62 
63   @AfterClass
stopServer()64   public static void stopServer() {
65     mockServiceHelper.stop();
66   }
67 
68   @Before
setUp()69   public void setUp() throws IOException {
70     mockServiceHelper.reset();
71     channelProvider = mockServiceHelper.createChannelProvider();
72     PredictionServiceSettings settings =
73         PredictionServiceSettings.newBuilder()
74             .setTransportChannelProvider(channelProvider)
75             .setCredentialsProvider(NoCredentialsProvider.create())
76             .build();
77     client = PredictionServiceClient.create(settings);
78   }
79 
80   @After
tearDown()81   public void tearDown() throws Exception {
82     client.close();
83   }
84 
85   @Test
predictTest()86   public void predictTest() throws Exception {
87     PredictResponse expectedResponse =
88         PredictResponse.newBuilder()
89             .addAllPayload(new ArrayList<AnnotationPayload>())
90             .setPreprocessedInput(ExamplePayload.newBuilder().build())
91             .putAllMetadata(new HashMap<String, String>())
92             .build();
93     mockPredictionService.addResponse(expectedResponse);
94 
95     ModelName name = ModelName.of("[PROJECT]", "[LOCATION]", "[MODEL]");
96     ExamplePayload payload = ExamplePayload.newBuilder().build();
97     Map<String, String> params = new HashMap<>();
98 
99     PredictResponse actualResponse = client.predict(name, payload, params);
100     Assert.assertEquals(expectedResponse, actualResponse);
101 
102     List<AbstractMessage> actualRequests = mockPredictionService.getRequests();
103     Assert.assertEquals(1, actualRequests.size());
104     PredictRequest actualRequest = ((PredictRequest) actualRequests.get(0));
105 
106     Assert.assertEquals(name.toString(), actualRequest.getName());
107     Assert.assertEquals(payload, actualRequest.getPayload());
108     Assert.assertEquals(params, actualRequest.getParamsMap());
109     Assert.assertTrue(
110         channelProvider.isHeaderSent(
111             ApiClientHeaderProvider.getDefaultApiClientHeaderKey(),
112             GaxGrpcProperties.getDefaultApiClientHeaderPattern()));
113   }
114 
115   @Test
predictExceptionTest()116   public void predictExceptionTest() throws Exception {
117     StatusRuntimeException exception = new StatusRuntimeException(io.grpc.Status.INVALID_ARGUMENT);
118     mockPredictionService.addException(exception);
119 
120     try {
121       ModelName name = ModelName.of("[PROJECT]", "[LOCATION]", "[MODEL]");
122       ExamplePayload payload = ExamplePayload.newBuilder().build();
123       Map<String, String> params = new HashMap<>();
124       client.predict(name, payload, params);
125       Assert.fail("No exception raised");
126     } catch (InvalidArgumentException e) {
127       // Expected exception.
128     }
129   }
130 
131   @Test
predictTest2()132   public void predictTest2() throws Exception {
133     PredictResponse expectedResponse =
134         PredictResponse.newBuilder()
135             .addAllPayload(new ArrayList<AnnotationPayload>())
136             .setPreprocessedInput(ExamplePayload.newBuilder().build())
137             .putAllMetadata(new HashMap<String, String>())
138             .build();
139     mockPredictionService.addResponse(expectedResponse);
140 
141     String name = "name3373707";
142     ExamplePayload payload = ExamplePayload.newBuilder().build();
143     Map<String, String> params = new HashMap<>();
144 
145     PredictResponse actualResponse = client.predict(name, payload, params);
146     Assert.assertEquals(expectedResponse, actualResponse);
147 
148     List<AbstractMessage> actualRequests = mockPredictionService.getRequests();
149     Assert.assertEquals(1, actualRequests.size());
150     PredictRequest actualRequest = ((PredictRequest) actualRequests.get(0));
151 
152     Assert.assertEquals(name, actualRequest.getName());
153     Assert.assertEquals(payload, actualRequest.getPayload());
154     Assert.assertEquals(params, actualRequest.getParamsMap());
155     Assert.assertTrue(
156         channelProvider.isHeaderSent(
157             ApiClientHeaderProvider.getDefaultApiClientHeaderKey(),
158             GaxGrpcProperties.getDefaultApiClientHeaderPattern()));
159   }
160 
161   @Test
predictExceptionTest2()162   public void predictExceptionTest2() throws Exception {
163     StatusRuntimeException exception = new StatusRuntimeException(io.grpc.Status.INVALID_ARGUMENT);
164     mockPredictionService.addException(exception);
165 
166     try {
167       String name = "name3373707";
168       ExamplePayload payload = ExamplePayload.newBuilder().build();
169       Map<String, String> params = new HashMap<>();
170       client.predict(name, payload, params);
171       Assert.fail("No exception raised");
172     } catch (InvalidArgumentException e) {
173       // Expected exception.
174     }
175   }
176 
177   @Test
batchPredictTest()178   public void batchPredictTest() throws Exception {
179     BatchPredictResult expectedResponse =
180         BatchPredictResult.newBuilder().putAllMetadata(new HashMap<String, String>()).build();
181     Operation resultOperation =
182         Operation.newBuilder()
183             .setName("batchPredictTest")
184             .setDone(true)
185             .setResponse(Any.pack(expectedResponse))
186             .build();
187     mockPredictionService.addResponse(resultOperation);
188 
189     ModelName name = ModelName.of("[PROJECT]", "[LOCATION]", "[MODEL]");
190     BatchPredictInputConfig inputConfig = BatchPredictInputConfig.newBuilder().build();
191     BatchPredictOutputConfig outputConfig = BatchPredictOutputConfig.newBuilder().build();
192     Map<String, String> params = new HashMap<>();
193 
194     BatchPredictResult actualResponse =
195         client.batchPredictAsync(name, inputConfig, outputConfig, params).get();
196     Assert.assertEquals(expectedResponse, actualResponse);
197 
198     List<AbstractMessage> actualRequests = mockPredictionService.getRequests();
199     Assert.assertEquals(1, actualRequests.size());
200     BatchPredictRequest actualRequest = ((BatchPredictRequest) actualRequests.get(0));
201 
202     Assert.assertEquals(name.toString(), actualRequest.getName());
203     Assert.assertEquals(inputConfig, actualRequest.getInputConfig());
204     Assert.assertEquals(outputConfig, actualRequest.getOutputConfig());
205     Assert.assertEquals(params, actualRequest.getParamsMap());
206     Assert.assertTrue(
207         channelProvider.isHeaderSent(
208             ApiClientHeaderProvider.getDefaultApiClientHeaderKey(),
209             GaxGrpcProperties.getDefaultApiClientHeaderPattern()));
210   }
211 
212   @Test
batchPredictExceptionTest()213   public void batchPredictExceptionTest() throws Exception {
214     StatusRuntimeException exception = new StatusRuntimeException(io.grpc.Status.INVALID_ARGUMENT);
215     mockPredictionService.addException(exception);
216 
217     try {
218       ModelName name = ModelName.of("[PROJECT]", "[LOCATION]", "[MODEL]");
219       BatchPredictInputConfig inputConfig = BatchPredictInputConfig.newBuilder().build();
220       BatchPredictOutputConfig outputConfig = BatchPredictOutputConfig.newBuilder().build();
221       Map<String, String> params = new HashMap<>();
222       client.batchPredictAsync(name, inputConfig, outputConfig, params).get();
223       Assert.fail("No exception raised");
224     } catch (ExecutionException e) {
225       Assert.assertEquals(InvalidArgumentException.class, e.getCause().getClass());
226       InvalidArgumentException apiException = ((InvalidArgumentException) e.getCause());
227       Assert.assertEquals(StatusCode.Code.INVALID_ARGUMENT, apiException.getStatusCode().getCode());
228     }
229   }
230 
231   @Test
batchPredictTest2()232   public void batchPredictTest2() throws Exception {
233     BatchPredictResult expectedResponse =
234         BatchPredictResult.newBuilder().putAllMetadata(new HashMap<String, String>()).build();
235     Operation resultOperation =
236         Operation.newBuilder()
237             .setName("batchPredictTest")
238             .setDone(true)
239             .setResponse(Any.pack(expectedResponse))
240             .build();
241     mockPredictionService.addResponse(resultOperation);
242 
243     String name = "name3373707";
244     BatchPredictInputConfig inputConfig = BatchPredictInputConfig.newBuilder().build();
245     BatchPredictOutputConfig outputConfig = BatchPredictOutputConfig.newBuilder().build();
246     Map<String, String> params = new HashMap<>();
247 
248     BatchPredictResult actualResponse =
249         client.batchPredictAsync(name, inputConfig, outputConfig, params).get();
250     Assert.assertEquals(expectedResponse, actualResponse);
251 
252     List<AbstractMessage> actualRequests = mockPredictionService.getRequests();
253     Assert.assertEquals(1, actualRequests.size());
254     BatchPredictRequest actualRequest = ((BatchPredictRequest) actualRequests.get(0));
255 
256     Assert.assertEquals(name, actualRequest.getName());
257     Assert.assertEquals(inputConfig, actualRequest.getInputConfig());
258     Assert.assertEquals(outputConfig, actualRequest.getOutputConfig());
259     Assert.assertEquals(params, actualRequest.getParamsMap());
260     Assert.assertTrue(
261         channelProvider.isHeaderSent(
262             ApiClientHeaderProvider.getDefaultApiClientHeaderKey(),
263             GaxGrpcProperties.getDefaultApiClientHeaderPattern()));
264   }
265 
266   @Test
batchPredictExceptionTest2()267   public void batchPredictExceptionTest2() throws Exception {
268     StatusRuntimeException exception = new StatusRuntimeException(io.grpc.Status.INVALID_ARGUMENT);
269     mockPredictionService.addException(exception);
270 
271     try {
272       String name = "name3373707";
273       BatchPredictInputConfig inputConfig = BatchPredictInputConfig.newBuilder().build();
274       BatchPredictOutputConfig outputConfig = BatchPredictOutputConfig.newBuilder().build();
275       Map<String, String> params = new HashMap<>();
276       client.batchPredictAsync(name, inputConfig, outputConfig, params).get();
277       Assert.fail("No exception raised");
278     } catch (ExecutionException e) {
279       Assert.assertEquals(InvalidArgumentException.class, e.getCause().getClass());
280       InvalidArgumentException apiException = ((InvalidArgumentException) e.getCause());
281       Assert.assertEquals(StatusCode.Code.INVALID_ARGUMENT, apiException.getStatusCode().getCode());
282     }
283   }
284 }
285