// Copyright 2021 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

import { Status } from 'pigweedjs/pw_status';
import { Message } from 'google-protobuf';

import WaitQueue from './queue';

import { PendingCalls, Rpc } from './rpc_classes';

/** User callback invoked for responses, completion, or errors. */
export type Callback = (a: any) => any;

/** Error raised when a call carries an error status. */
class RpcError extends Error {
  status: Status;

  constructor(rpc: Rpc, status: Status) {
    // Append a human-readable hint for statuses with a well-known meaning.
    let message = '';
    if (status === Status.NOT_FOUND) {
      message = ': the RPC server does not support this RPC';
    } else if (status === Status.DATA_LOSS) {
      message = ': an error occurred while decoding the RPC payload';
    }

    super(`${rpc.method.name} failed with error ${Status[status]}${message}`);
    this.status = status;
  }
}

/** Error raised when no response arrives within the requested timeout. */
class RpcTimeout extends Error {
  readonly rpc: Rpc;
  readonly timeoutMs: number;

  constructor(rpc: Rpc, timeoutMs: number) {
    super(`${rpc.method.name} timed out after ${timeoutMs} ms`);
    this.rpc = rpc;
    this.timeoutMs = timeoutMs;
  }
}

/** Represent an in-progress or completed RPC call. */
export class Call {
  // Responses ordered by arrival time. Undefined signifies stream completion.
  private responseQueue = new WaitQueue<Message | undefined>();
  // Every response received so far; returned by unaryWait/streamWait.
  protected responses: Message[] = [];

  private rpcs: PendingCalls;
  rpc: Rpc;
  readonly callId: number;

  private onNext: Callback;
  private onCompleted: Callback;
  private onError: Callback;

  // Set by handleCompletion.
  status?: Status;
  // Set by handleError or cancel.
  error?: Status;
  // Most recent Error thrown by a user callback; rethrown by checkErrors.
  callbackException?: Error;

  constructor(
    rpcs: PendingCalls,
    rpc: Rpc,
    onNext: Callback,
    onCompleted: Callback,
    onError: Callback,
  ) {
    this.rpcs = rpcs;
    this.rpc = rpc;

    this.onNext = onNext;
    this.onCompleted = onCompleted;
    this.onError = onError;
    this.callId = rpcs.allocateCallId();
  }

  /**
   * Calls the RPC. This must be called immediately after construction.
   *
   * If `sendRequest` hands back a previously pending call that has not yet
   * completed, that call is failed with `Status.CANCELLED`.
   *
   * @param request Optional request message to send.
   * @param ignoreErrors Forwarded unchanged to `PendingCalls.sendRequest`.
   */
  invoke(request?: Message, ignoreErrors = false): void {
    const previous = this.rpcs.sendRequest(
      this.rpc,
      this,
      ignoreErrors,
      request,
    );

    if (previous !== undefined && !previous.completed) {
      previous.handleError(Status.CANCELLED);
    }
  }

  /** True once the call has either a final status or an error. */
  get completed(): boolean {
    return this.status !== undefined || this.error !== undefined;
  }

  /**
   * Runs a user callback without letting an exception escape.
   *
   * A thrown `Error` is logged and stored in `callbackException`, so that the
   * next `checkErrors` (and therefore the next wait) rethrows it. Any other
   * thrown value is only logged.
   */
  private invokeCallback(func: () => unknown): void {
    try {
      func();
    } catch (err: unknown) {
      if (err instanceof Error) {
        console.error(
          `An exception was raised while invoking a callback: ${err}`,
        );
        this.callbackException = err;
      } else {
        console.error(`Unexpected item thrown while invoking callback: ${err}`);
      }
    }
  }

  /** Records a response, queues it for readers, and invokes `onNext`. */
  handleResponse(response: Message): void {
    this.responses.push(response);
    this.responseQueue.push(response);
    this.invokeCallback(() => this.onNext(response));
  }

  /** Records the final status, ends the response stream, and invokes `onCompleted`. */
  handleCompletion(status: Status): void {
    this.status = status;
    this.responseQueue.push(undefined);
    this.invokeCallback(() => this.onCompleted(status));
  }

  /** Records an error status, ends the response stream, and invokes `onError`. */
  handleError(error: Status): void {
    this.error = error;
    this.responseQueue.push(undefined);
    this.invokeCallback(() => this.onError(error));
  }

  /**
   * Pops the next queue item, rejecting with `RpcTimeout` after `timeoutMs`.
   *
   * The underlying `shift()` cannot be aborted. If the timeout wins, the item
   * that later arrives is pushed back to the front of the queue so that the
   * next reader still receives it.
   */
  private queuePopWithTimeout(
    timeoutMs: number,
  ): Promise<Message | undefined> {
    return new Promise((resolve, reject) => {
      let timeoutExpired = false;
      const timeoutWatcher = setTimeout(() => {
        timeoutExpired = true;
        reject(new RpcTimeout(this.rpc, timeoutMs));
      }, timeoutMs);

      Promise.resolve(this.responseQueue.shift()).then(
        (response: Message | undefined) => {
          if (timeoutExpired) {
            // The caller already got RpcTimeout; don't lose this item.
            this.responseQueue.unshift(response);
            return;
          }
          clearTimeout(timeoutWatcher);
          resolve(response);
        },
        (err: unknown) => {
          // Propagate queue failures instead of leaving them unhandled.
          clearTimeout(timeoutWatcher);
          reject(err);
        },
      );
    });
  }

  /**
   * Yields responses, up to the specified count, as they are added.
   *
   * Throws an error as soon as it is received, even if there are still
   * responses in the queue.
   *
   * Usage
   * ```
   * for await (const response of call.getResponses(5)) {
   *   console.log(response);
   * }
   * ```
   *
   * @param count The number of responses to read before returning.
   *     If no value is specified, getResponses will block until the stream
   *     either ends or hits an error.
   * @param timeoutMs The number of milliseconds to wait for each response
   *     before throwing an `RpcTimeout`. If omitted, waits indefinitely.
   * @throws RpcError if the call has an error status (including CANCELLED).
   * @throws The stored `callbackException` if a user callback threw an Error.
   */
  async *getResponses(
    count?: number,
    timeoutMs?: number,
  ): AsyncGenerator<Message> {
    this.checkErrors();

    if (this.completed && this.responseQueue.length === 0) {
      return;
    }

    let remaining = count ?? Number.POSITIVE_INFINITY;
    while (remaining > 0) {
      const response =
        timeoutMs === undefined
          ? await this.responseQueue.shift()
          : await this.queuePopWithTimeout(timeoutMs);
      this.checkErrors();
      // `undefined` is the end-of-stream marker pushed on completion/error.
      if (response === undefined) {
        return;
      }
      yield response;
      remaining -= 1;
    }
  }

  /**
   * Cancels the call if it has not already completed.
   *
   * Marks the call with a CANCELLED error and pushes the same end-of-stream
   * marker that `handleError` pushes. A reader already blocked in
   * `getResponses` therefore wakes up and throws `RpcError(CANCELLED)`.
   *
   * @returns false if the call was already completed; otherwise the result of
   *     `PendingCalls.sendCancel`.
   */
  cancel(): boolean {
    if (this.completed) {
      return false;
    }

    this.error = Status.CANCELLED;
    this.responseQueue.push(undefined);
    return this.rpcs.sendCancel(this.rpc, this.callId);
  }

  /** Rethrows a stored callback exception, or throws RpcError for an error status. */
  private checkErrors(): void {
    if (this.callbackException !== undefined) {
      throw this.callbackException;
    }
    if (this.error !== undefined) {
      throw new RpcError(this.rpc, this.error);
    }
  }

  /**
   * Waits for a single response and returns it with the final status.
   *
   * @throws Error if no status has been recorded after the wait, or if the
   *     number of received responses is not exactly one.
   */
  protected async unaryWait(timeoutMs?: number): Promise<[Status, Message]> {
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    for await (const _response of this.getResponses(1, timeoutMs)) {
      // Drain only; results are read from this.responses below.
    }
    if (this.status === undefined) {
      throw Error('Unexpected undefined status at end of stream');
    }
    if (this.responses.length !== 1) {
      throw Error(`Unexpected number of responses: ${this.responses.length}`);
    }
    return [this.status, this.responses[0]];
  }

  /**
   * Waits for the stream to end and returns the final status with all responses.
   *
   * @throws Error if no status has been recorded when the stream ends.
   */
  protected async streamWait(timeoutMs?: number): Promise<[Status, Message[]]> {
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    for await (const _response of this.getResponses(undefined, timeoutMs)) {
      // Drain only; results are read from this.responses below.
    }
    if (this.status === undefined) {
      throw Error('Unexpected undefined status at end of stream');
    }
    return [this.status, this.responses];
  }

  /**
   * Sends one client-stream message.
   *
   * @throws RpcError(FAILED_PRECONDITION) if the call already has a final status.
   */
  protected sendClientStream(request: Message) {
    this.checkErrors();
    if (this.status !== undefined) {
      throw new RpcError(this.rpc, Status.FAILED_PRECONDITION);
    }
    this.rpcs.sendClientStream(this.rpc, request, this.callId);
  }

  /** Sends any remaining requests, then ends the client stream if still active. */
  protected finishClientStream(requests: Message[]) {
    for (const request of requests) {
      this.sendClientStream(request);
    }

    if (!this.completed) {
      this.rpcs.sendClientStreamEnd(this.rpc, this.callId);
    }
  }
}

/** Tracks the state of a unary RPC call. */
export class UnaryCall extends Call {
  /** Awaits the server response */
  async complete(timeoutMs?: number): Promise<[Status, Message]> {
    return await this.unaryWait(timeoutMs);
  }
}

/** Tracks the state of a client streaming RPC call. */
export class ClientStreamingCall extends Call {
  /** Gets the last server message, if it exists */
  get response(): Message | undefined {
    return this.responses.length > 0
      ? this.responses[this.responses.length - 1]
      : undefined;
  }

  /** Sends a message from the client. */
  send(request: Message) {
    this.sendClientStream(request);
  }

  /** Ends the client stream and waits for the RPC to complete. */
  async finishAndWait(
    requests: Message[] = [],
    timeoutMs?: number,
  ): Promise<[Status, Message[]]> {
    this.finishClientStream(requests);
    return await this.streamWait(timeoutMs);
  }
}

/** Tracks the state of a server streaming RPC call. */
export class ServerStreamingCall extends Call {
  /** Waits for the stream to end; resolves to the final status and all responses. */
  complete(timeoutMs?: number): Promise<[Status, Message[]]> {
    return this.streamWait(timeoutMs);
  }
}

/** Tracks the state of a bidirectional streaming RPC call. */
export class BidirectionalStreamingCall extends Call {
  /** Sends a message from the client. */
  send(request: Message) {
    this.sendClientStream(request);
  }

  /** Ends the client stream and waits for the RPC to complete. */
  async finishAndWait(
    requests: Array<Message> = [],
    timeoutMs?: number,
  ): Promise<[Status, Array<Message>]> {
    this.finishClientStream(requests);
    return await this.streamWait(timeoutMs);
  }
}