/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_breadth_assignment.h"

#include <algorithm>
#include <cstddef>
#include <set>
#include <vector>

#include "absl/status/status.h"
#include "tensorflow/lite/delegates/gpu/common/memory_management/internal.h"
#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"

namespace tflite {
namespace gpu {
namespace {

// Set of usage records for all tensors assigned to the shared object, ordered
// by first_task.
using SharedObjectSchedule = std::set<TensorUsageRecord<size_t>>;
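// Invariant maintained by the assignment loop below: within one schedule, the
// usage intervals [first_task, last_task] of the stored records are pairwise
// disjoint, which is what makes the neighbor-only overlap test in
// GreedyByBreadthAssignment valid.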

struct TaskBreadthWithId {
  size_t breadth;
  TaskId task_id;

  TaskBreadthWithId(size_t breadth, size_t task_id)
      : breadth(breadth), task_id(task_id) {}

  // The default order of TaskBreadthWithId is increasing order of breadth.
  bool operator<(const TaskBreadthWithId& other) const {
    return breadth < other.breadth;
  }
};

}  // namespace

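// A minimal usage sketch (illustrative only: the sizes and task intervals are
// made-up values, not taken from a real model):
//
//   std::vector<TensorUsageRecord<size_t>> records = {
//       {/*tensor_size=*/16, /*first_task=*/0, /*last_task=*/1},
//       {/*tensor_size=*/8, /*first_task=*/1, /*last_task=*/2},
//       {/*tensor_size=*/64, /*first_task=*/2, /*last_task=*/3}};
//   ObjectsAssignment<size_t> assignment;
//   absl::Status status = GreedyByBreadthAssignment(records, &assignment);
//   // On success, assignment.object_ids[i] is the shared object id for
//   // records[i], and assignment.object_sizes[id] is the size of object id.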
absl::Status GreedyByBreadthAssignment(
    const std::vector<TensorUsageRecord<size_t>>& usage_records,
    ObjectsAssignment<size_t>* assignment) {
  std::vector<TaskProfile> task_profiles = CalculateTaskProfiles(usage_records);
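  // task_profiles[t] holds the usage records of all tensors that must be
  // alive during task t, sorted by non-increasing tensor_size; the greedy
  // loop below relies on that order.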

  // The breadth of a task is the sum of the sizes of all tensors in its
  // TaskProfile.
  std::vector<TaskBreadthWithId> task_breadth;
  for (size_t task_id = 0; task_id < task_profiles.size(); ++task_id) {
    size_t breadth = 0;
    for (const auto& tensor_info : task_profiles[task_id]) {
      breadth += tensor_info.usage_record->tensor_size;
    }
    task_breadth.emplace_back(breadth, task_id);
  }
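  // For example, if tensors of sizes 64, 16 and 8 must be alive during some
  // task, the breadth of that task is 88.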

  assignment->object_sizes.clear();
  assignment->object_ids.assign(usage_records.size(), kNotAssigned);
  std::vector<SharedObjectSchedule> obj_schedules;

  // Iterate through all tasks in non-increasing order of their breadth.
  std::sort(task_breadth.rbegin(), task_breadth.rend());
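  // Heuristic behind the order: tasks with the largest breadth dominate peak
  // memory usage, so their tensors get to pick shared objects first.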
  for (const auto& task : task_breadth) {
    // Iterate through all tensors that must be allocated during the execution
    // of this task, in non-increasing order of their tensor_size.
    for (const auto& tensor_info : task_profiles[task.task_id]) {
      if (assignment->object_ids[tensor_info.idx] != kNotAssigned) {
        continue;
      }
      const auto& rec = *tensor_info.usage_record;
      const size_t num_objects = obj_schedules.size();
      size_t best_object = num_objects;
      for (size_t obj_id = 0; obj_id < num_objects; ++obj_id) {
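        // Selection criterion encoded by the checks below: prefer the smallest
        // existing object whose size is at least rec.tensor_size; if no object
        // fits, prefer the largest object smaller than rec.tensor_size, so
        // that it has to grow as little as possible.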
        // If the size of the current object is worse than the size of the
        // best object found so far, we can skip it.
        if (best_object != num_objects) {
          const size_t best_size = assignment->object_sizes[best_object];
          const size_t cur_size = assignment->object_sizes[obj_id];
          if (best_size < rec.tensor_size) {
            if (cur_size <= best_size) {
              // best_size is smaller than tensor_size, but cur_size is even
              // smaller.
              continue;
            }
          } else if (cur_size < rec.tensor_size || cur_size >= best_size) {
            // best_size is larger than or equal to tensor_size, and cur_size
            // is either smaller than tensor_size or too large.
            continue;
          }
        }
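        // Records in the schedule are pairwise disjoint and ordered by
        // first_task, so only the record at lower_bound(rec) and its
        // predecessor can intersect [rec.first_task, rec.last_task].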
        const auto& schedule = obj_schedules[obj_id];
        auto it = schedule.lower_bound(rec);
        bool update_best_object = true;
        if (it != schedule.end() && it->first_task <= rec.last_task) {
          // Some tensor whose usage interval intersects with the current one
          // is already assigned to this object.
          update_best_object = false;
        }
        if (update_best_object && it != schedule.begin()) {
          it--;
          if (it->last_task >= rec.first_task) {
            // Some tensor whose usage interval intersects with the current
            // one is already assigned to this object.
            update_best_object = false;
          }
        }
        if (update_best_object) {
          best_object = obj_id;
        }
      }
      if (best_object == num_objects) {
        // Create a new shared object and assign the current tensor to it.
        obj_schedules.push_back({rec});
        assignment->object_sizes.push_back(rec.tensor_size);
      } else {
        // Assign the current tensor to best_object.
        obj_schedules[best_object].insert(rec);
        // The size of best_object is increased if it is smaller than
        // tensor_size.
        assignment->object_sizes[best_object] =
            std::max(assignment->object_sizes[best_object], rec.tensor_size);
      }
      assignment->object_ids[tensor_info.idx] = best_object;
    }
  }
  // By this point, every tensor must have been assigned to some object.
  for (const auto& obj_id : assignment->object_ids) {
    if (obj_id == kNotAssigned) {
      return absl::InternalError("Error while calculating the assignment.");
    }
  }
  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite