• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2021 The Pigweed Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not
4// use this file except in compliance with the License. You may obtain a copy of
5// the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12// License for the specific language governing permissions and limitations under
13// the License.
14
15import { Status } from 'pigweedjs/pw_status';
16import { Message } from 'google-protobuf';
17
18import WaitQueue from './queue';
19
20import { PendingCalls, Rpc } from './rpc_classes';
21
/**
 * Handler signature used for a call's onNext, onCompleted and onError hooks.
 * The single argument is whatever the call reports for that event.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type Callback = (arg: any) => any;
23
/**
 * Error describing an RPC call that ended with an error status.
 *
 * The message names the RPC method and the status. For NOT_FOUND and
 * DATA_LOSS, it adds a short explanation of the likely cause.
 */
class RpcError extends Error {
  /** The RPC that failed. */
  readonly rpc: Rpc;
  /** The error status the call ended with. */
  status: Status;

  constructor(rpc: Rpc, status: Status) {
    let message = '';
    if (status === Status.NOT_FOUND) {
      message = ': the RPC server does not support this RPC';
    } else if (status === Status.DATA_LOSS) {
      message = ': an error occurred while decoding the RPC payload';
    }

    super(`${rpc.method.name} failed with error ${Status[status]}${message}`);
    // Without this, the error prints and stringifies as a generic "Error".
    this.name = 'RpcError';
    this.rpc = rpc;
    this.status = status;
  }
}
39
/** Error raised when no response arrives for an RPC within the allowed time. */
class RpcTimeout extends Error {
  /** The RPC that timed out. */
  readonly rpc: Rpc;
  /** How long the caller waited, in milliseconds. */
  readonly timeoutMs: number;

  constructor(rpc: Rpc, timeoutMs: number) {
    super(`${rpc.method.name} timed out after ${timeoutMs} ms`);
    // Without this, the error prints and stringifies as a generic "Error".
    this.name = 'RpcTimeout';
    this.rpc = rpc;
    this.timeoutMs = timeoutMs;
  }
}
50
/**
 * Represents an in-progress or completed RPC call.
 *
 * Incoming events are reported through handleResponse(), handleCompletion()
 * and handleError(). Each event is:
 *   - recorded on the call,
 *   - forwarded to the matching user callback, and
 *   - queued for readers of getResponses().
 */
export class Call {
  // Responses ordered by arrival time. Undefined signifies stream completion.
  private responseQueue = new WaitQueue<Message | undefined>();
  // Every response received so far, in arrival order (never drained).
  protected responses: Message[] = [];

  private rpcs: PendingCalls;
  rpc: Rpc;
  readonly callId: number;

  private onNext: Callback;
  private onCompleted: Callback;
  private onError: Callback;

  // Final status reported through handleCompletion().
  status?: Status;
  // Error status reported through handleError(), or CANCELLED after cancel().
  error?: Status;
  // Most recent Error thrown by a user callback; rethrown to readers.
  callbackException?: Error;

  constructor(
    rpcs: PendingCalls,
    rpc: Rpc,
    onNext: Callback,
    onCompleted: Callback,
    onError: Callback,
  ) {
    this.rpcs = rpcs;
    this.rpc = rpc;

    this.onNext = onNext;
    this.onCompleted = onCompleted;
    this.onError = onError;
    this.callId = rpcs.allocateCallId();
  }

  /**
   * Calls the RPC. This must be called immediately after construction.
   *
   * If sendRequest() hands back a previous call that has not completed yet,
   * that call is failed with CANCELLED.
   *
   * @param request Optional request message to send.
   * @param ignoreErrors Passed through unchanged to sendRequest().
   */
  invoke(request?: Message, ignoreErrors = false): void {
    const previous = this.rpcs.sendRequest(
      this.rpc,
      this,
      ignoreErrors,
      request,
    );

    if (previous !== undefined && !previous.completed) {
      previous.handleError(Status.CANCELLED);
    }
  }

  /** True once the call has either a final status or an error. */
  get completed(): boolean {
    return this.status !== undefined || this.error !== undefined;
  }

  /**
   * Runs a user callback, logging rather than propagating anything it throws.
   * A thrown Error is also saved so readers of getResponses() see it.
   */
  private invokeCallback(func: () => unknown): void {
    try {
      func();
    } catch (err: unknown) {
      // Exactly one log line per failure: only non-Error values fall through
      // to the generic "unexpected item" message.
      if (err instanceof Error) {
        console.error(
          `An exception was raised while invoking a callback: ${err}`,
        );
        this.callbackException = err;
      } else {
        console.error(`Unexpected item thrown while invoking callback: ${err}`);
      }
    }
  }

  /** Records a server response and forwards it to onNext. */
  handleResponse(response: Message): void {
    this.responses.push(response);
    this.responseQueue.push(response);
    this.invokeCallback(() => this.onNext(response));
  }

  /** Records the final status, ends the response stream, and calls onCompleted. */
  handleCompletion(status: Status) {
    this.status = status;
    this.responseQueue.push(undefined);
    this.invokeCallback(() => this.onCompleted(status));
  }

  /** Records an error status, ends the response stream, and calls onError. */
  handleError(error: Status): void {
    this.error = error;
    this.responseQueue.push(undefined);
    this.invokeCallback(() => this.onError(error));
  }

  /**
   * Takes the next queued item, rejecting with RpcTimeout if none arrives
   * within timeoutMs.
   *
   * An item that arrives after the timeout has fired is put back at the front
   * of the queue, so the next reader still receives it.
   */
  private queuePopWithTimeout(timeoutMs: number): Promise<Message | undefined> {
    return new Promise((resolve, reject) => {
      let timeoutExpired = false;
      const timeoutWatcher = setTimeout(() => {
        timeoutExpired = true;
        reject(new RpcTimeout(this.rpc, timeoutMs));
      }, timeoutMs);

      Promise.resolve(this.responseQueue.shift()).then(
        (response) => {
          if (timeoutExpired) {
            this.responseQueue.unshift(response);
            return;
          }
          clearTimeout(timeoutWatcher);
          resolve(response);
        },
        (err: unknown) => {
          // Propagate queue failures instead of leaving them unhandled.
          clearTimeout(timeoutWatcher);
          reject(err);
        },
      );
    });
  }

  /**
   * Yields responses, up to the specified count, as they are added.
   *
   * Throws an error as soon as it is received, even if there are still
   * responses in the queue.
   *
   * Usage
   * ```
   * for await (const response of call.getResponses(5)) {
   *  console.log(response);
   * }
   * ```
   *
   * @param count The number of responses to read before returning. If no
   *    value is specified, getResponses blocks until the stream either ends
   *    or hits an error.
   * @param timeoutMs The number of milliseconds to wait for each response
   *    before throwing an RpcTimeout.
   */
  async *getResponses(
    count?: number,
    timeoutMs?: number,
  ): AsyncGenerator<Message> {
    this.checkErrors();

    if (this.completed && this.responseQueue.length === 0) {
      return;
    }

    let remaining = count ?? Number.POSITIVE_INFINITY;
    while (remaining > 0) {
      const response =
        timeoutMs === undefined
          ? await this.responseQueue.shift()
          : await this.queuePopWithTimeout(timeoutMs);
      this.checkErrors();
      if (response === undefined) {
        return;
      }
      yield response;
      remaining -= 1;
    }
  }

  /**
   * Cancels the call unless it has already completed.
   *
   * @returns false if the call was already complete; otherwise the result of
   *    sendCancel().
   */
  cancel(): boolean {
    if (this.completed) {
      return false;
    }

    this.error = Status.CANCELLED;
    // Wake any reader blocked in getResponses(): it takes this sentinel,
    // runs checkErrors(), and throws the CANCELLED error instead of hanging.
    this.responseQueue.push(undefined);
    return this.rpcs.sendCancel(this.rpc, this.callId);
  }

  /** Throws the saved callback exception, or an RpcError if the call failed. */
  private checkErrors(): void {
    if (this.callbackException !== undefined) {
      throw this.callbackException;
    }
    if (this.error !== undefined) {
      throw new RpcError(this.rpc, this.error);
    }
  }

  /** Waits for exactly one response and the final status. */
  protected async unaryWait(timeoutMs?: number): Promise<[Status, Message]> {
    // Drain the stream; handleResponse() records responses in this.responses.
    for await (const response of this.getResponses(1, timeoutMs)) {
      // Do nothing.
    }
    if (this.status === undefined) {
      throw new Error('Unexpected undefined status at end of stream');
    }
    if (this.responses.length !== 1) {
      throw new Error(
        `Unexpected number of responses: ${this.responses.length}`,
      );
    }
    return [this.status, this.responses[0]];
  }

  /** Waits for the stream to end and returns the status and all responses. */
  protected async streamWait(timeoutMs?: number): Promise<[Status, Message[]]> {
    // Drain the stream; handleResponse() records responses in this.responses.
    for await (const response of this.getResponses(undefined, timeoutMs)) {
      // Do nothing.
    }
    if (this.status === undefined) {
      throw new Error('Unexpected undefined status at end of stream');
    }
    return [this.status, this.responses];
  }

  /**
   * Sends one client-stream message.
   *
   * Throws if the call has failed, or FAILED_PRECONDITION if it already
   * has a final status.
   */
  protected sendClientStream(request: Message) {
    this.checkErrors();
    if (this.status !== undefined) {
      throw new RpcError(this.rpc, Status.FAILED_PRECONDITION);
    }
    this.rpcs.sendClientStream(this.rpc, request, this.callId);
  }

  /**
   * Sends the given requests, then ends the client stream unless the call
   * has already completed.
   */
  protected finishClientStream(requests: Message[]) {
    for (const request of requests) {
      this.sendClientStream(request);
    }

    if (!this.completed) {
      this.rpcs.sendClientStreamEnd(this.rpc, this.callId);
    }
  }
}
259
/** Tracks the state of a unary RPC call. */
export class UnaryCall extends Call {
  /**
   * Waits for the server's single response together with the final status.
   *
   * @param timeoutMs Optional per-response timeout in milliseconds.
   * @returns A `[status, response]` pair.
   */
  complete(timeoutMs?: number): Promise<[Status, Message]> {
    return this.unaryWait(timeoutMs);
  }
}
267
/** Tracks the state of a client streaming RPC call. */
export class ClientStreamingCall extends Call {
  /** The most recent server response, or undefined if none has arrived. */
  get response(): Message | undefined {
    const received = this.responses.length;
    if (received === 0) {
      return undefined;
    }
    return this.responses[received - 1];
  }

  /** Streams one request message to the server. */
  send(request: Message): void {
    this.sendClientStream(request);
  }

  /**
   * Sends any remaining requests, ends the client stream, and waits for the
   * server to finish the call.
   *
   * @param requests Final messages to stream before closing (default none).
   * @param timeoutMs Optional per-response timeout in milliseconds.
   * @returns The completion status and every response received.
   */
  async finishAndWait(
    requests: Message[] = [],
    timeoutMs?: number,
  ): Promise<[Status, Message[]]> {
    this.finishClientStream(requests);
    const outcome = await this.streamWait(timeoutMs);
    return outcome;
  }
}
291
/** Tracks the state of a server streaming RPC call. */
export class ServerStreamingCall extends Call {
  /**
   * Waits until the server ends the stream.
   *
   * @param timeoutMs Optional per-response timeout in milliseconds.
   * @returns The completion status and every response received.
   */
  complete(timeoutMs?: number): Promise<[Status, Message[]]> {
    const finished = this.streamWait(timeoutMs);
    return finished;
  }
}
298
/** Tracks the state of a bidirectional streaming RPC call. */
export class BidirectionalStreamingCall extends Call {
  /** Streams one request message to the server. */
  send(request: Message): void {
    this.sendClientStream(request);
  }

  /**
   * Sends any remaining requests, ends the client stream, and waits for the
   * server to finish the call.
   *
   * @param requests Final messages to stream before closing (default none).
   * @param timeoutMs Optional per-response timeout in milliseconds.
   * @returns The completion status and every response received.
   */
  async finishAndWait(
    requests: Message[] = [],
    timeoutMs?: number,
  ): Promise<[Status, Message[]]> {
    this.finishClientStream(requests);
    const outcome = await this.streamWait(timeoutMs);
    return outcome;
  }
}
315