1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_H_ 17 #define TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_H_ 18 19 #include <memory> 20 21 #include "tensorflow/compiler/xla/client/xla_builder.h" 22 #include "tensorflow/compiler/xla/literal.h" 23 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" 24 #include "tensorflow/compiler/xla/shape.h" 25 #include "tensorflow/compiler/xla/statusor.h" 26 27 namespace xla { 28 29 class OutfeedReceiverImpl; 30 31 // Implements a multithreaded receiver of outfeeds from devices. 32 class OutfeedReceiver { 33 public: 34 // A callback takes: device, consumer id, received. 35 using Callback = 36 std::function<void(PjRtDevice*, uint32_t, std::shared_ptr<Literal>)>; 37 38 // Constructs the receiver for the given clients and callback function. 39 // 40 // Args: 41 // callback: a function to be called when an outfeed is ready for 42 // processing. 43 // clients: the clients for whose devices to listen. 44 // max_callback_queue_size_bytes: the maximum number of bytes for all 45 // received outfeeds queued to be processed. When this limit is reached 46 // we pause receiving outfeeds from devices. 47 OutfeedReceiver(Callback callback, absl::Span<PjRtClient* const> clients, 48 ssize_t max_callback_queue_size_bytes); 49 50 OutfeedReceiver(const OutfeedReceiver&) = delete; 51 OutfeedReceiver& operator=(const OutfeedReceiver&) = delete; 52 53 // Blocks until all data has been received from devices and all data 54 // in the queue has been passed to Python. 55 ~OutfeedReceiver(); 56 57 // Starts the listener threads and the callback thread. 58 void Start(); 59 60 // Adds to the computation builder the outfeed of the arrays. 61 // Has the side-effect of registering the sent shape for the consumer_id. 62 // Returns error status if the outfeed shape is different than the 63 // previously used shape for the same consumer_id or the consumer id is 64 // invalid. 65 StatusOr<XlaOp> AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token, 66 uint32_t consumer_id, 67 std::vector<XlaOp> arrays); 68 69 private: 70 std::unique_ptr<OutfeedReceiverImpl> p_impl_; 71 }; 72 73 } // namespace xla 74 75 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_H_ 76