• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_SCHEDULE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_SCHEDULE_H_
18 
19 #include <list>
20 #include <memory>
21 #include <unordered_map>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
27 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/types.h"
30 
31 namespace xla {
32 namespace gpu {
33 
34 // Encapsulates in which order and on which streams the thunks are executed. A
35 // schedule contains
36 //
37 // 1. A stream assignment indicating which stream each thunk is executed on.
38 //
39 // 2. A total order of all thunks. If A is ordered before B and they are
40 // assigned to the same stream, then A completes before B starts. If A is
41 // ordered before B and they are on different streams, their actual execution
42 // order is not determined.
43 //
44 // 3. A set of dependency edges. If A and B are scheduled on different streams
45 // and A has to complete before B starts (e.g. A produces an input of B), then B
46 // "depends" on A.
47 class ThunkSchedule {
48  public:
49   // `thunk_to_hlo` is an one-to-one map. Every thunk in this container maps to
50   // an HLO, but not every HLO ever exists produces a Thunk.
51   //
52   // thunk_to_hlo.keys() == set(thunks).
53   ThunkSchedule(
54       std::unique_ptr<ThunkSequence> thunks,
55       std::unique_ptr<StreamAssignment> stream_assignment,
56       absl::flat_hash_map<const Thunk*, const HloInstruction*> thunk_to_hlo);
57 
58   // Single stream, trivial schedule in the ThunkSequence order.
59   explicit ThunkSchedule(std::unique_ptr<ThunkSequence> thunks);
60 
61   // Returns the total order of executing all the thunks.
TotalOrder()62   const std::vector<Thunk*>& TotalOrder() const { return thunk_total_order_; }
63 
64   // Thunks that `thunk` depends on.
65   const std::list<const Thunk*>& DependsOn(const Thunk* thunk) const;
66   // Whether `thunk` is depended by another thunk.
Depended(const Thunk * thunk)67   bool Depended(const Thunk* thunk) const {
68     return depended_by_.contains(thunk);
69   }
70 
71   // Delegates to StreamAssignment.
StreamCount()72   int StreamCount() const {
73     if (stream_assignment_) {
74       return stream_assignment_->StreamCount();
75     }
76     return 1;
77   }
StreamNumberForThunk(const Thunk * thunk)78   int StreamNumberForThunk(const Thunk* thunk) const {
79     if (stream_assignment_) {
80       return stream_assignment_->StreamNumberForHlo(*thunk_to_hlo_.at(thunk));
81     }
82     return 0;
83   }
84 
85   string ToString() const;
86 
87  private:
88   void RemoveRedundantDependencyEdges();
89 
90   // Adds `operand` and its transitive operands to the dependency list of
91   // `thunk`.
92   //
93   // Precondition: `operand` is a non-trivial (i.e. excluding
94   // thunk.hlo_instruction_ itself) transitive operand of
95   // thunk.hlo_instruction_.
96   void AddDependenciesOnTransitiveOperands(
97       const Thunk& thunk, const HloInstruction& operand,
98       const absl::flat_hash_map<const HloInstruction*, Thunk*>& hlo_to_thunk);
99 
100   std::unique_ptr<ThunkSequence> thunks_;
101   std::vector<Thunk*> thunk_total_order_;
102 
103   absl::flat_hash_map<const Thunk*, std::list<const Thunk*>> depends_on_;
104   absl::flat_hash_set<const Thunk*> depended_by_;
105   std::list<const Thunk*> empty_thunk_list_;
106 
107   std::unique_ptr<StreamAssignment> stream_assignment_;
108 
109   absl::flat_hash_map<const Thunk*, const HloInstruction*> thunk_to_hlo_;
110 };
111 
112 }  // namespace gpu
113 }  // namespace xla
114 
115 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_SCHEDULE_H_
116