1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RING_ALG_H_ 16 #define TENSORFLOW_CORE_COMMON_RUNTIME_RING_ALG_H_ 17 18 #include <deque> 19 #include <memory> 20 #include <string> 21 #include <vector> 22 23 #include "tensorflow/core/common_runtime/base_collective_executor.h" 24 #include "tensorflow/core/framework/collective.h" 25 26 namespace tensorflow { 27 class Device; 28 29 // Basic ring-algorithm implementation to be further specialized 30 // for specific collective functions. 31 class RingAlg : public CollectiveImplementationInterface { 32 public: 33 explicit RingAlg(CollectiveType type, const string& name); ~RingAlg()34 ~RingAlg() override {} 35 36 // Establishes the requested number of subdivision permutations based on the 37 // ring order implicit in the device order. 38 Status InitializeCollectiveParams(CollectiveParams* col_params) override; 39 40 // Initializes members of CollectiveContext not yet initialized, i.e. device 41 // and device_locality. Also saves the CollectiveContext in this object. 42 Status InitializeCollectiveContext(CollectiveContext* col_ctx) override; 43 44 // No-op for ring alg. InitializeInstanceBeforeGroupDiscovery(CollectiveParams *)45 Status InitializeInstanceBeforeGroupDiscovery(CollectiveParams*) override { 46 return Status::OK(); 47 } 48 49 protected: 50 // Called when a bad status is received that implies we should terminate 51 // execution and return a bad status. 52 void StartAbort(const Status& s); 53 void Finish(bool ok); 54 55 // Current status of a RingField 56 enum RingFieldAction { 57 RF_INIT = 0, // Just initialized for a pass 58 RF_RECV, // Recv pending 59 RF_REDUCE, // Reduce pending 60 RF_FINALIZE, // FinalOp pending 61 RF_SEND_READY, // Ready to send 62 RF_SEND, // Send pending 63 RF_DONE, // No more work 64 }; 65 66 // Tracks progress of actions on a single subfield of the entire tensor. 67 struct RingField { 68 int16 chunk_idx; // major division index 69 int16 subdiv_idx; // minor division index 70 int16 sc_idx; // subchunk index 71 int16 rank; // rank within subdiv permutation 72 int16 recv_dev_idx; // dev from which value should be recv'd 73 RingFieldAction action; 74 bool second_pass; 75 bool recv_is_remote = false; 76 bool send_is_remote = false; 77 bool do_send = false; // is the value sent in this pass? 78 bool do_recv = false; // is the value recv'd in this pass? 79 bool is_final = false; // is the last field in the pass for this rank 80 Tensor chunk; // alias to field values 81 Tensor tmp_chunk; 82 Status status; 83 string DebugString() const; 84 }; 85 virtual void InitRingField(RingField* rf, int chunk_idx, int subdiv_idx, 86 int field_idx); 87 void AdvanceToSecondPass(RingField* rf); 88 void DispatchSend(RingField* rf, const StatusCallback& done); 89 void DispatchRecv(RingField* rf, const StatusCallback& done); 90 91 // For constructing log messages for debugging. 92 string FieldState(); 93 string TensorDebugString(const Tensor& tensor); 94 95 // Producer/Consumer Queue of RingField structs. 96 class PCQueue { 97 public: 98 void Enqueue(RingField* rf); 99 RingField* Dequeue(); 100 101 private: 102 mutex pcq_mu_; 103 condition_variable cv_; 104 int waiter_count_ GUARDED_BY(pcq_mu_) = 0; 105 std::deque<RingField*> deque_ GUARDED_BY(pcq_mu_); 106 }; 107 108 const CollectiveType type_; 109 const string name_; 110 CollectiveContext* col_ctx_; // Not owned 111 const CollectiveParams* col_params_; // Not owned 112 StatusCallback done_; 113 int group_size_; 114 int num_subdivs_; 115 Tensor group_size_tensor_; 116 Notification group_size_tensor_ready_; 117 std::unique_ptr<CollectiveAdapter> ca_; 118 mutex status_mu_; 119 Status status_ GUARDED_BY(status_mu_); 120 std::vector<RingField> rfv_; 121 }; 122 123 } // namespace tensorflow 124 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_RING_ALG_H_ 125