1 /* 2 * Copyright 2022 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * https://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.google.cloud.automl.v1beta1; 18 19 import com.google.api.gax.core.NoCredentialsProvider; 20 import com.google.api.gax.grpc.GaxGrpcProperties; 21 import com.google.api.gax.grpc.testing.LocalChannelProvider; 22 import com.google.api.gax.grpc.testing.MockGrpcService; 23 import com.google.api.gax.grpc.testing.MockServiceHelper; 24 import com.google.api.gax.rpc.ApiClientHeaderProvider; 25 import com.google.api.gax.rpc.InvalidArgumentException; 26 import com.google.api.gax.rpc.StatusCode; 27 import com.google.longrunning.Operation; 28 import com.google.protobuf.AbstractMessage; 29 import com.google.protobuf.Any; 30 import io.grpc.StatusRuntimeException; 31 import java.io.IOException; 32 import java.util.ArrayList; 33 import java.util.Arrays; 34 import java.util.HashMap; 35 import java.util.List; 36 import java.util.Map; 37 import java.util.UUID; 38 import java.util.concurrent.ExecutionException; 39 import javax.annotation.Generated; 40 import org.junit.After; 41 import org.junit.AfterClass; 42 import org.junit.Assert; 43 import org.junit.Before; 44 import org.junit.BeforeClass; 45 import org.junit.Test; 46 47 @Generated("by gapic-generator-java") 48 public class PredictionServiceClientTest { 49 private static MockPredictionService mockPredictionService; 50 private static MockServiceHelper mockServiceHelper; 51 private LocalChannelProvider channelProvider; 52 private PredictionServiceClient client; 53 54 @BeforeClass startStaticServer()55 public static void startStaticServer() { 56 mockPredictionService = new MockPredictionService(); 57 mockServiceHelper = 58 new MockServiceHelper( 59 UUID.randomUUID().toString(), Arrays.<MockGrpcService>asList(mockPredictionService)); 60 mockServiceHelper.start(); 61 } 62 63 @AfterClass stopServer()64 public static void stopServer() { 65 mockServiceHelper.stop(); 66 } 67 68 @Before setUp()69 public void setUp() throws IOException { 70 mockServiceHelper.reset(); 71 channelProvider = mockServiceHelper.createChannelProvider(); 72 PredictionServiceSettings settings = 73 PredictionServiceSettings.newBuilder() 74 .setTransportChannelProvider(channelProvider) 75 .setCredentialsProvider(NoCredentialsProvider.create()) 76 .build(); 77 client = PredictionServiceClient.create(settings); 78 } 79 80 @After tearDown()81 public void tearDown() throws Exception { 82 client.close(); 83 } 84 85 @Test predictTest()86 public void predictTest() throws Exception { 87 PredictResponse expectedResponse = 88 PredictResponse.newBuilder() 89 .addAllPayload(new ArrayList<AnnotationPayload>()) 90 .setPreprocessedInput(ExamplePayload.newBuilder().build()) 91 .putAllMetadata(new HashMap<String, String>()) 92 .build(); 93 mockPredictionService.addResponse(expectedResponse); 94 95 ModelName name = ModelName.of("[PROJECT]", "[LOCATION]", "[MODEL]"); 96 ExamplePayload payload = ExamplePayload.newBuilder().build(); 97 Map<String, String> params = new HashMap<>(); 98 99 PredictResponse actualResponse = client.predict(name, payload, params); 100 Assert.assertEquals(expectedResponse, actualResponse); 101 102 List<AbstractMessage> actualRequests = mockPredictionService.getRequests(); 103 Assert.assertEquals(1, actualRequests.size()); 104 PredictRequest actualRequest = ((PredictRequest) actualRequests.get(0)); 105 106 Assert.assertEquals(name.toString(), actualRequest.getName()); 107 Assert.assertEquals(payload, actualRequest.getPayload()); 108 Assert.assertEquals(params, actualRequest.getParamsMap()); 109 Assert.assertTrue( 110 channelProvider.isHeaderSent( 111 ApiClientHeaderProvider.getDefaultApiClientHeaderKey(), 112 GaxGrpcProperties.getDefaultApiClientHeaderPattern())); 113 } 114 115 @Test predictExceptionTest()116 public void predictExceptionTest() throws Exception { 117 StatusRuntimeException exception = new StatusRuntimeException(io.grpc.Status.INVALID_ARGUMENT); 118 mockPredictionService.addException(exception); 119 120 try { 121 ModelName name = ModelName.of("[PROJECT]", "[LOCATION]", "[MODEL]"); 122 ExamplePayload payload = ExamplePayload.newBuilder().build(); 123 Map<String, String> params = new HashMap<>(); 124 client.predict(name, payload, params); 125 Assert.fail("No exception raised"); 126 } catch (InvalidArgumentException e) { 127 // Expected exception. 128 } 129 } 130 131 @Test predictTest2()132 public void predictTest2() throws Exception { 133 PredictResponse expectedResponse = 134 PredictResponse.newBuilder() 135 .addAllPayload(new ArrayList<AnnotationPayload>()) 136 .setPreprocessedInput(ExamplePayload.newBuilder().build()) 137 .putAllMetadata(new HashMap<String, String>()) 138 .build(); 139 mockPredictionService.addResponse(expectedResponse); 140 141 String name = "name3373707"; 142 ExamplePayload payload = ExamplePayload.newBuilder().build(); 143 Map<String, String> params = new HashMap<>(); 144 145 PredictResponse actualResponse = client.predict(name, payload, params); 146 Assert.assertEquals(expectedResponse, actualResponse); 147 148 List<AbstractMessage> actualRequests = mockPredictionService.getRequests(); 149 Assert.assertEquals(1, actualRequests.size()); 150 PredictRequest actualRequest = ((PredictRequest) actualRequests.get(0)); 151 152 Assert.assertEquals(name, actualRequest.getName()); 153 Assert.assertEquals(payload, actualRequest.getPayload()); 154 Assert.assertEquals(params, actualRequest.getParamsMap()); 155 Assert.assertTrue( 156 channelProvider.isHeaderSent( 157 ApiClientHeaderProvider.getDefaultApiClientHeaderKey(), 158 GaxGrpcProperties.getDefaultApiClientHeaderPattern())); 159 } 160 161 @Test predictExceptionTest2()162 public void predictExceptionTest2() throws Exception { 163 StatusRuntimeException exception = new StatusRuntimeException(io.grpc.Status.INVALID_ARGUMENT); 164 mockPredictionService.addException(exception); 165 166 try { 167 String name = "name3373707"; 168 ExamplePayload payload = ExamplePayload.newBuilder().build(); 169 Map<String, String> params = new HashMap<>(); 170 client.predict(name, payload, params); 171 Assert.fail("No exception raised"); 172 } catch (InvalidArgumentException e) { 173 // Expected exception. 174 } 175 } 176 177 @Test batchPredictTest()178 public void batchPredictTest() throws Exception { 179 BatchPredictResult expectedResponse = 180 BatchPredictResult.newBuilder().putAllMetadata(new HashMap<String, String>()).build(); 181 Operation resultOperation = 182 Operation.newBuilder() 183 .setName("batchPredictTest") 184 .setDone(true) 185 .setResponse(Any.pack(expectedResponse)) 186 .build(); 187 mockPredictionService.addResponse(resultOperation); 188 189 ModelName name = ModelName.of("[PROJECT]", "[LOCATION]", "[MODEL]"); 190 BatchPredictInputConfig inputConfig = BatchPredictInputConfig.newBuilder().build(); 191 BatchPredictOutputConfig outputConfig = BatchPredictOutputConfig.newBuilder().build(); 192 Map<String, String> params = new HashMap<>(); 193 194 BatchPredictResult actualResponse = 195 client.batchPredictAsync(name, inputConfig, outputConfig, params).get(); 196 Assert.assertEquals(expectedResponse, actualResponse); 197 198 List<AbstractMessage> actualRequests = mockPredictionService.getRequests(); 199 Assert.assertEquals(1, actualRequests.size()); 200 BatchPredictRequest actualRequest = ((BatchPredictRequest) actualRequests.get(0)); 201 202 Assert.assertEquals(name.toString(), actualRequest.getName()); 203 Assert.assertEquals(inputConfig, actualRequest.getInputConfig()); 204 Assert.assertEquals(outputConfig, actualRequest.getOutputConfig()); 205 Assert.assertEquals(params, actualRequest.getParamsMap()); 206 Assert.assertTrue( 207 channelProvider.isHeaderSent( 208 ApiClientHeaderProvider.getDefaultApiClientHeaderKey(), 209 GaxGrpcProperties.getDefaultApiClientHeaderPattern())); 210 } 211 212 @Test batchPredictExceptionTest()213 public void batchPredictExceptionTest() throws Exception { 214 StatusRuntimeException exception = new StatusRuntimeException(io.grpc.Status.INVALID_ARGUMENT); 215 mockPredictionService.addException(exception); 216 217 try { 218 ModelName name = ModelName.of("[PROJECT]", "[LOCATION]", "[MODEL]"); 219 BatchPredictInputConfig inputConfig = BatchPredictInputConfig.newBuilder().build(); 220 BatchPredictOutputConfig outputConfig = BatchPredictOutputConfig.newBuilder().build(); 221 Map<String, String> params = new HashMap<>(); 222 client.batchPredictAsync(name, inputConfig, outputConfig, params).get(); 223 Assert.fail("No exception raised"); 224 } catch (ExecutionException e) { 225 Assert.assertEquals(InvalidArgumentException.class, e.getCause().getClass()); 226 InvalidArgumentException apiException = ((InvalidArgumentException) e.getCause()); 227 Assert.assertEquals(StatusCode.Code.INVALID_ARGUMENT, apiException.getStatusCode().getCode()); 228 } 229 } 230 231 @Test batchPredictTest2()232 public void batchPredictTest2() throws Exception { 233 BatchPredictResult expectedResponse = 234 BatchPredictResult.newBuilder().putAllMetadata(new HashMap<String, String>()).build(); 235 Operation resultOperation = 236 Operation.newBuilder() 237 .setName("batchPredictTest") 238 .setDone(true) 239 .setResponse(Any.pack(expectedResponse)) 240 .build(); 241 mockPredictionService.addResponse(resultOperation); 242 243 String name = "name3373707"; 244 BatchPredictInputConfig inputConfig = BatchPredictInputConfig.newBuilder().build(); 245 BatchPredictOutputConfig outputConfig = BatchPredictOutputConfig.newBuilder().build(); 246 Map<String, String> params = new HashMap<>(); 247 248 BatchPredictResult actualResponse = 249 client.batchPredictAsync(name, inputConfig, outputConfig, params).get(); 250 Assert.assertEquals(expectedResponse, actualResponse); 251 252 List<AbstractMessage> actualRequests = mockPredictionService.getRequests(); 253 Assert.assertEquals(1, actualRequests.size()); 254 BatchPredictRequest actualRequest = ((BatchPredictRequest) actualRequests.get(0)); 255 256 Assert.assertEquals(name, actualRequest.getName()); 257 Assert.assertEquals(inputConfig, actualRequest.getInputConfig()); 258 Assert.assertEquals(outputConfig, actualRequest.getOutputConfig()); 259 Assert.assertEquals(params, actualRequest.getParamsMap()); 260 Assert.assertTrue( 261 channelProvider.isHeaderSent( 262 ApiClientHeaderProvider.getDefaultApiClientHeaderKey(), 263 GaxGrpcProperties.getDefaultApiClientHeaderPattern())); 264 } 265 266 @Test batchPredictExceptionTest2()267 public void batchPredictExceptionTest2() throws Exception { 268 StatusRuntimeException exception = new StatusRuntimeException(io.grpc.Status.INVALID_ARGUMENT); 269 mockPredictionService.addException(exception); 270 271 try { 272 String name = "name3373707"; 273 BatchPredictInputConfig inputConfig = BatchPredictInputConfig.newBuilder().build(); 274 BatchPredictOutputConfig outputConfig = BatchPredictOutputConfig.newBuilder().build(); 275 Map<String, String> params = new HashMap<>(); 276 client.batchPredictAsync(name, inputConfig, outputConfig, params).get(); 277 Assert.fail("No exception raised"); 278 } catch (ExecutionException e) { 279 Assert.assertEquals(InvalidArgumentException.class, e.getCause().getClass()); 280 InvalidArgumentException apiException = ((InvalidArgumentException) e.getCause()); 281 Assert.assertEquals(StatusCode.Code.INVALID_ARGUMENT, apiException.getStatusCode().getCode()); 282 } 283 } 284 } 285