/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_WORK_SHARDER_H_
#define TENSORFLOW_CORE_UTIL_WORK_SHARDER_H_

#include <functional>

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// DEPRECATED: Prefer threadpool->TransformRangeConcurrently, which allows you
// to directly specify the shard size. Use this function only if you want to
// manually cap parallelism.
//
// Shards the "total" units of work, assuming each unit costs roughly
// "cost_per_unit". Each unit of work is indexed 0, 1, ..., total - 1. Each
// shard contains one or more units of work, and the total cost of each shard
// is roughly the same. The calling thread and the "workers" are used to
// compute each shard by calling work(start, limit). A common configuration is
// that "workers" is a thread pool with at least "max_parallelism" threads.
//
// "cost_per_unit" is an estimate of the number of CPU cycles (or nanoseconds
// if not CPU-bound) needed to complete a unit of work. Overestimating creates
// too many shards, and CPU time will be dominated by per-shard overhead, such
// as Context creation. Underestimating may not fully make use of the
// specified parallelism.
//
// "work" should be a callable taking (int64, int64) arguments.
// work(start, limit) computes the work units in the range [start, limit),
// i.e., [start, limit) is a shard.
//
// Too much parallelism can also cause excessive thread switches, so Shard()
// often limits the maximum parallelism. Each caller can cap it via the first
// argument, max_parallelism. A thread can also call
// SetPerThreadMaxParallelism() so that all subsequent Shard() calls on that
// thread limit their parallelism accordingly.
//
// REQUIRES: max_parallelism >= 0
// REQUIRES: workers != nullptr
// REQUIRES: total >= 0
// REQUIRES: cost_per_unit >= 0
void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
           int64 cost_per_unit, std::function<void(int64, int64)> work);
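
// Example usage (an illustrative sketch, not part of this header's API; the
// thread-pool construction, pool size, and cost estimate below are
// assumptions chosen for the example):
//
//   thread::ThreadPool pool(Env::Default(), "shard_example",
//                           4 /* num_threads */);
//   std::vector<float> data(1 << 20, 1.0f);
//   Shard(4 /* max_parallelism */, &pool, data.size(),
//         20 /* cost_per_unit, ~cycles per element */,
//         [&data](int64 start, int64 limit) {
//           // Each shard doubles its [start, limit) slice of the vector.
//           for (int64 i = start; i < limit; ++i) data[i] *= 2.0f;
//         });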

// Each thread has an associated option to express the desired maximum
// parallelism. Its default is a very large quantity.
//
// Within the TF runtime, per-thread max parallelism affects Shard() and
// intra-op parallelism. E.g., if SetPerThreadMaxParallelism(1) is called on a
// tf_compute thread, Shard() calls and Eigen device assignments made on that
// thread afterwards become single-threaded.
void SetPerThreadMaxParallelism(int max_parallelism);
int GetPerThreadMaxParallelism();

// Helper to set and unset per-thread max parallelism.
class ScopedPerThreadMaxParallelism {
 public:
  ScopedPerThreadMaxParallelism(int max_parallelism)
      : previous_(GetPerThreadMaxParallelism()) {
    SetPerThreadMaxParallelism(max_parallelism);
  }

  ~ScopedPerThreadMaxParallelism() { SetPerThreadMaxParallelism(previous_); }

 private:
  int previous_ = -1;
};

// Implementation details for Shard().
class Sharder {
 public:
  typedef std::function<void()> Closure;
  typedef std::function<void(Closure)> Runner;
  typedef std::function<void(int64, int64)> Work;

  // Refer to Shard()'s comment for the meaning of total, cost_per_unit,
  // work, and max_parallelism. runner is an interface used to schedule a
  // closure; Shard() uses a thread::ThreadPool instead.
  static void Do(int64 total, int64 cost_per_unit, const Work& work,
                 const Runner& runner, int max_parallelism);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_WORK_SHARDER_H_