• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2022 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
14 
15 package dev.pigweed.pw_transfer;
16 
17 import com.google.common.util.concurrent.ListenableFuture;
18 import dev.pigweed.pw_log.Logger;
19 import dev.pigweed.pw_rpc.Call;
20 import dev.pigweed.pw_rpc.ChannelOutputException;
21 import dev.pigweed.pw_rpc.MethodClient;
22 import dev.pigweed.pw_rpc.Status;
23 import dev.pigweed.pw_rpc.StreamObserver;
24 import java.time.Instant;
25 import java.time.temporal.ChronoUnit;
26 import java.time.temporal.TemporalUnit;
27 import java.util.Comparator;
28 import java.util.HashMap;
29 import java.util.Map;
30 import java.util.Optional;
31 import java.util.concurrent.BlockingQueue;
32 import java.util.concurrent.LinkedBlockingQueue;
33 import java.util.concurrent.Semaphore;
34 import java.util.concurrent.TimeUnit;
35 import java.util.function.BooleanSupplier;
36 import java.util.function.Consumer;
37 import javax.annotation.Nullable;
38 
39 /** Manages the active transfers and dispatches events to them. */
40 class TransferEventHandler {
41   private static final Logger logger = Logger.forClass(TransferEventHandler.class);
42 
43   // Instant and BlockingQueue use different time unit types.
44   private static final TemporalUnit TIME_UNIT = ChronoUnit.MICROS;
45   private static final TimeUnit POLL_TIME_UNIT = TimeUnit.MICROSECONDS;
46 
47   private final MethodClient readMethod;
48   private final MethodClient writeMethod;
49 
50   private final BlockingQueue<Event> events = new LinkedBlockingQueue<>();
51 
52   // Map session ID to transfer.
53   private final Map<Integer, Transfer<?>> sessionIdToTransfer = new HashMap<>();
54   // Legacy transfers only use the resource ID. The client assigns an arbitrary session ID that
55   // legacy servers ignore. The client then maps from the legacy ID to its local session ID.
56   private final Map<Integer, Integer> legacyIdToSessionId = new HashMap<>();
57 
58   @Nullable private Call.ClientStreaming<Chunk> readStream = null;
59   @Nullable private Call.ClientStreaming<Chunk> writeStream = null;
60   private boolean processEvents = true;
61 
62   private int nextSessionId = 1;
63 
TransferEventHandler(MethodClient readMethod, MethodClient writeMethod)64   TransferEventHandler(MethodClient readMethod, MethodClient writeMethod) {
65     this.readMethod = readMethod;
66     this.writeMethod = writeMethod;
67   }
68 
startWriteTransferAsClient(int resourceId, ProtocolVersion desiredProtocolVersion, TransferTimeoutSettings settings, byte[] data, Consumer<TransferProgress> progressCallback, BooleanSupplier shouldAbortCallback, int initialOffset)69   ListenableFuture<Void> startWriteTransferAsClient(int resourceId,
70       ProtocolVersion desiredProtocolVersion,
71       TransferTimeoutSettings settings,
72       byte[] data,
73       Consumer<TransferProgress> progressCallback,
74       BooleanSupplier shouldAbortCallback,
75       int initialOffset) {
76     WriteTransfer transfer = new WriteTransfer(
77         resourceId, assignSessionId(), desiredProtocolVersion, new TransferInterface() {
78           @Override
79           Call.ClientStreaming<Chunk> getStream() throws ChannelOutputException {
80             if (writeStream == null) {
81               writeStream = writeMethod.invokeBidirectionalStreaming(new ChunkHandler() {
82                 @Override
83                 void resetStream() {
84                   writeStream = null;
85                 }
86               });
87             }
88             return writeStream;
89           }
90         }, settings, data, progressCallback, shouldAbortCallback, initialOffset);
91     startTransferAsClient(transfer);
92     return transfer;
93   }
94 
startReadTransferAsClient(int resourceId, ProtocolVersion desiredProtocolVersion, TransferTimeoutSettings settings, TransferParameters parameters, Consumer<TransferProgress> progressCallback, BooleanSupplier shouldAbortCallback, int initialOffset)95   ListenableFuture<byte[]> startReadTransferAsClient(int resourceId,
96       ProtocolVersion desiredProtocolVersion,
97       TransferTimeoutSettings settings,
98       TransferParameters parameters,
99       Consumer<TransferProgress> progressCallback,
100       BooleanSupplier shouldAbortCallback,
101       int initialOffset) {
102     ReadTransfer transfer = new ReadTransfer(
103         resourceId, assignSessionId(), desiredProtocolVersion, new TransferInterface() {
104           @Override
105           Call.ClientStreaming<Chunk> getStream() throws ChannelOutputException {
106             if (readStream == null) {
107               readStream = readMethod.invokeBidirectionalStreaming(new ChunkHandler() {
108                 @Override
109                 void resetStream() {
110                   readStream = null;
111                 }
112               });
113             }
114             return readStream;
115           }
116         }, settings, parameters, progressCallback, shouldAbortCallback, initialOffset);
117     startTransferAsClient(transfer);
118     return transfer;
119   }
120 
startTransferAsClient(Transfer<?> transfer)121   private void startTransferAsClient(Transfer<?> transfer) {
122     enqueueEvent(() -> {
123       if (sessionIdToTransfer.containsKey(transfer.getSessionId())) {
124         throw new AssertionError("Duplicate session ID " + transfer.getSessionId());
125       }
126 
127       if (transfer.getDesiredProtocolVersion() == ProtocolVersion.LEGACY
128           && transfer.getOffset() != 0) {
129         throw new AssertionError("Cannot start non-zero offset transfer with legacy version");
130       }
131 
132       // The v2 protocol supports multiple transfers for a single resource. For simplicity while
133       // supporting both protocols, only support a single transfer per resource.
134       if (legacyIdToSessionId.containsKey(transfer.getResourceId())) {
135         transfer.terminate(
136             new TransferError("A transfer for resource ID " + transfer.getResourceId()
137                     + " is already in progress! Only one read/write transfer per resource is "
138                     + "supported at a time",
139                 Status.ALREADY_EXISTS));
140         return;
141       }
142       sessionIdToTransfer.put(transfer.getSessionId(), transfer);
143       legacyIdToSessionId.put(transfer.getResourceId(), transfer.getSessionId());
144       transfer.start();
145     });
146   }
147 
148   /** Handles events until stop() is called. */
run()149   void run() {
150     while (processEvents) {
151       handleNextEvent();
152       handleTimeouts();
153     }
154   }
155 
156   /**
157    * Test version of run() that processes all enqueued events before checking for timeouts.
158    *
159    * Tests that need to time out should process all enqueued events first to prevent flaky failures.
160    * If handling one of several queued packets takes longer than the timeout (which must be short
161    * for a unit test), then the test may fail spuriously.
162    *
163    * This run function is not used outside of tests because processing all incoming packets before
164    * checking for timeouts could delay the transfer client's outgoing write packets if there are
165    * lots of inbound packets. This could delay transfers and cause unnecessary timeouts.
166    */
runForTestsThatMustTimeOut()167   void runForTestsThatMustTimeOut() {
168     while (processEvents) {
169       while (!events.isEmpty()) {
170         handleNextEvent();
171       }
172       handleTimeouts();
173     }
174   }
175 
176   /** Stops the transfer event handler from processing events. */
stop()177   void stop() {
178     enqueueEvent(() -> {
179       logger.atFine().log("Terminating TransferEventHandler");
180       sessionIdToTransfer.values().forEach(Transfer::handleTermination);
181       processEvents = false;
182     });
183   }
184 
185   /** Blocks until all events currently in the queue are processed; for test use only. */
waitUntilEventsAreProcessedForTest()186   void waitUntilEventsAreProcessedForTest() {
187     Semaphore semaphore = new Semaphore(0);
188     enqueueEvent(semaphore::release);
189     try {
190       semaphore.acquire();
191     } catch (InterruptedException e) {
192       throw new AssertionError("Unexpectedly interrupted", e);
193     }
194   }
195 
196   /** Generates the session ID to use for the next transfer. */
assignSessionId()197   private int assignSessionId() {
198     return nextSessionId++;
199   }
200 
201   /** Returns the session ID that will be used for the next transfer. */
getNextSessionIdForTest()202   final int getNextSessionIdForTest() {
203     return nextSessionId;
204   }
205 
enqueueEvent(Event event)206   private void enqueueEvent(Event event) {
207     while (true) {
208       try {
209         events.put(event);
210         return;
211       } catch (InterruptedException e) {
212         // Ignore and keep trying.
213       }
214     }
215   }
216 
handleNextEvent()217   private void handleNextEvent() {
218     final long sleepFor = TIME_UNIT.between(Instant.now(), getNextTimeout());
219     try {
220       Event event = events.poll(sleepFor, POLL_TIME_UNIT);
221       if (event != null) {
222         event.handle();
223       }
224     } catch (InterruptedException e) {
225       // If interrupted, continue around the loop.
226     }
227   }
228 
handleTimeouts()229   private void handleTimeouts() {
230     // Copy to array since transfers may remove themselves from sessionIdToTransfer while iterating.
231     for (Transfer<?> transfer : sessionIdToTransfer.values().toArray(Transfer<?>[] ::new)) {
232       transfer.handleTimeoutIfDeadlineExceeded();
233     }
234   }
235 
getNextTimeout()236   private Instant getNextTimeout() {
237     Optional<Transfer<?>> transfer =
238         sessionIdToTransfer.values().stream().min(Comparator.comparing(Transfer::getDeadline));
239     return transfer.isPresent() ? transfer.get().getDeadline() : Transfer.NO_TIMEOUT;
240   }
241 
242   /** This interface gives a Transfer access to the TransferEventHandler. */
243   abstract class TransferInterface {
TransferInterface()244     private TransferInterface() {}
245 
246     /**
247      *  Sends the provided transfer chunk.
248      *
249      *  Must be called on the transfer thread.
250      */
sendChunk(Chunk chunk)251     void sendChunk(Chunk chunk) throws TransferError {
252       try {
253         getStream().write(chunk);
254       } catch (ChannelOutputException e) {
255         throw new TransferError("Failed to send chunk for write transfer", e);
256       }
257     }
258 
259     /**
260      *  Removes this transfer from the list of active transfers.
261      *
262      *  Must be called on the transfer thread.
263      */
264     // TODO(frolv): Investigate why this is occurring -- there shouldn't be any
265     // futures here.
266     @SuppressWarnings("FutureReturnValueIgnored")
unregisterTransfer(Transfer<?> transfer)267     void unregisterTransfer(Transfer<?> transfer) {
268       sessionIdToTransfer.remove(transfer.getSessionId());
269       legacyIdToSessionId.remove(transfer.getResourceId());
270     }
271 
272     /**
273      * Initiates the cancellation process for the provided transfer.
274      *
275      * May be called from any thread.
276      */
cancelTransfer(Transfer<?> transfer)277     void cancelTransfer(Transfer<?> transfer) {
278       enqueueEvent(transfer::handleCancellation);
279     }
280 
281     /** Gets either the read or write stream. */
getStream()282     abstract Call.ClientStreaming<Chunk> getStream() throws ChannelOutputException;
283   }
284 
285   /** Handles responses on the pw_transfer RPCs. */
286   private abstract class ChunkHandler implements StreamObserver<Chunk> {
287     @Override
onNext(Chunk chunkProto)288     public final void onNext(Chunk chunkProto) {
289       VersionedChunk chunk = VersionedChunk.fromMessage(chunkProto, legacyIdToSessionId);
290 
291       enqueueEvent(() -> {
292         Transfer<?> transfer;
293         if (chunk.sessionId() == VersionedChunk.UNKNOWN_SESSION_ID
294             || (transfer = sessionIdToTransfer.get(chunk.sessionId())) == null) {
295           logger.atInfo().log("Ignoring unrecognized transfer chunk: %s", chunk);
296           return;
297         }
298 
299         logger.atFinest().log("%s received chunk: %s", transfer, chunk);
300         transfer.handleChunk(chunk);
301       });
302     }
303 
304     @Override
onCompleted(Status status)305     public final void onCompleted(Status status) {
306       onError(Status.INTERNAL); // This RPC should never complete: treat as an internal error.
307     }
308 
309     @Override
onError(Status status)310     public final void onError(Status status) {
311       enqueueEvent(() -> {
312         resetStream();
313 
314         TransferError error = new TransferError(
315             "Transfer stream RPC closed unexpectedly with status " + status, Status.INTERNAL);
316 
317         // The transfers remove themselves from the Map during cleanup; iterate over a copied list.
318         for (Transfer<?> transfer : sessionIdToTransfer.values().toArray(Transfer<?>[] ::new)) {
319           transfer.terminate(error);
320         }
321       });
322     }
323 
resetStream()324     abstract void resetStream();
325   }
326 
327   // Represents an event that occurs during a transfer
328   private interface Event {
handle()329     void handle();
330   }
331 }
332