• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RING_ALG_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_RING_ALG_H_
17 
18 #include <deque>
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 #include "tensorflow/core/common_runtime/base_collective_executor.h"
24 #include "tensorflow/core/framework/collective.h"
25 
26 namespace tensorflow {
27 class Device;
28 
29 // Basic ring-algorithm implementation to be further specialized
30 // for specific collective functions.
31 class RingAlg : public CollectiveImplementationInterface {
32  public:
33   explicit RingAlg(CollectiveType type, const string& name);
~RingAlg()34   ~RingAlg() override {}
35 
36   // Establishes the requested number of subdivision permutations based on the
37   // ring order implicit in the device order.
38   Status InitializeCollectiveParams(CollectiveParams* col_params) override;
39 
40   // Initializes members of CollectiveContext not yet initialized, i.e. device
41   // and device_locality.  Also saves the CollectiveContext in this object.
42   Status InitializeCollectiveContext(
43       std::shared_ptr<CollectiveContext> col_ctx) override;
44 
45   // No-op for ring alg.
InitializeCollectiveGroupRuntimeDetails(CollGroupRuntimeDetails *)46   Status InitializeCollectiveGroupRuntimeDetails(
47       CollGroupRuntimeDetails*) override {
48     return Status::OK();
49   }
50 
51  protected:
52   // Called when a bad status is received that implies we should terminate
53   // execution and return a bad status.
54   void StartAbort(const Status& s);
55   void Finish(bool ok);
56 
57   // Current status of a RingField
58   enum RingFieldAction {
59     RF_INIT = 0,    // Just initialized for a pass
60     RF_RECV,        // Recv pending
61     RF_REDUCE,      // Reduce pending
62     RF_FINALIZE,    // FinalOp pending
63     RF_SEND_READY,  // Ready to send
64     RF_SEND,        // Send pending
65     RF_DONE,        // No more work
66   };
67 
68   // Tracks progress of actions on a single subfield of the entire tensor.
69   struct RingField {
70     int16 chunk_idx;     // major division index
71     int16 subdiv_idx;    // minor division index
72     int16 sc_idx;        // subchunk index
73     int16 rank;          // rank within subdiv permutation
74     int16 recv_dev_idx;  // dev from which value should be recv'd
75     RingFieldAction action;
76     bool second_pass;
77     bool recv_is_remote = false;
78     bool send_is_remote = false;
79     bool do_send = false;   // is the value sent in this pass?
80     bool do_recv = false;   // is the value recv'd in this pass?
81     bool is_final = false;  // is the last field in the pass for this rank
82     Tensor chunk;           // alias to field values
83     Tensor tmp_chunk;
84     Status status;
85     string DebugString() const;
86   };
87   virtual void InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
88                              int field_idx);
89   void AdvanceToSecondPass(RingField* rf);
90   void DispatchSend(RingField* rf, const StatusCallback& done);
91   void DispatchRecv(RingField* rf, const StatusCallback& done);
92 
93   // For constructing log messages for debugging.
94   string FieldState();
95   string TensorDebugString(const Tensor& tensor);
96 
97   // Producer/Consumer Queue of RingField structs.
98   class PCQueue {
99    public:
100     void Enqueue(RingField* rf);
101     RingField* Dequeue();
102 
103    private:
104     mutex pcq_mu_;
105     condition_variable cv_;
106     int waiter_count_ TF_GUARDED_BY(pcq_mu_) = 0;
107     std::deque<RingField*> deque_ TF_GUARDED_BY(pcq_mu_);
108   };
109 
110   const CollectiveType type_;
111   const string name_;
112   std::shared_ptr<CollectiveContext> col_ctx_;
113   const CollectiveParams* col_params_;  // Not owned
114   StatusCallback done_;
115   int group_size_;
116   int num_subdivs_;
117   Tensor group_size_tensor_;
118   Notification group_size_tensor_ready_;
119   std::unique_ptr<CollectiveAdapter> ca_;
120   mutex status_mu_;
121   Status status_ TF_GUARDED_BY(status_mu_);
122   std::vector<RingField> rfv_;
123 };
124 
125 }  // namespace tensorflow
126 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_RING_ALG_H_
127