// NOTE(review): code-browser navigation chrome ("Home / Line# / Scopes /
// Navigate / Raw / Download") removed during extraction; it was not part of
// the original source file.
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/experimental/threadpool_dataset_op.h"
16 
17 #include <memory>
18 
19 #include "tensorflow/core/data/dataset_utils.h"
20 #include "tensorflow/core/framework/dataset.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/resource_mgr.h"
23 #include "tensorflow/core/lib/core/refcount.h"
24 #include "tensorflow/core/lib/core/threadpool.h"
25 #include "tensorflow/core/platform/cpu_info.h"
26 #include "tensorflow/core/platform/stringprintf.h"
27 #include "tensorflow/core/platform/thread_annotations.h"
28 #include "tensorflow/core/util/work_sharder.h"
29 
30 namespace tensorflow {
31 namespace data {
32 namespace experimental {
33 
// Out-of-line definitions for the static constexpr string constants declared
// in the header; pre-C++17, ODR-used static constexpr data members require a
// namespace-scope definition even though the initializer lives in-class.
/* static */ constexpr const char* const
    MaxIntraOpParallelismDatasetOp::kDatasetType;
/* static */ constexpr const char* const
    MaxIntraOpParallelismDatasetOp::kDatasetOp;
/* static */ constexpr const char* const
    PrivateThreadPoolDatasetOp::kDatasetType;
/* static */ constexpr const char* const PrivateThreadPoolDatasetOp::kDatasetOp;
41 
42 class ThreadPoolResource : public ResourceBase {
43  public:
ThreadPoolResource(Env * env,const ThreadOptions & thread_options,const string & name,int num_threads,bool low_latency_hint,int max_intra_op_parallelism)44   ThreadPoolResource(Env* env, const ThreadOptions& thread_options,
45                      const string& name, int num_threads, bool low_latency_hint,
46                      int max_intra_op_parallelism)
47       : thread_pool_(env, thread_options, name, num_threads, low_latency_hint),
48         max_intra_op_parallelism_(max_intra_op_parallelism) {}
49 
50   // Schedules fn() for execution in the pool of threads.
Schedule(std::function<void ()> fn)51   void Schedule(std::function<void()> fn) {
52     if (max_intra_op_parallelism_ < 0) {
53       thread_pool_.Schedule(std::move(fn));
54     } else {
55       thread_pool_.Schedule(std::bind(
56           [this](std::function<void()> bound_fn) {
57             // TODO(mrry): Consider moving this thread-local configuration to
58             // the threads themselves.
59             ScopedPerThreadMaxParallelism scope(max_intra_op_parallelism_);
60             bound_fn();
61           },
62           std::move(fn)));
63     }
64   }
65 
NumThreads()66   int32 NumThreads() { return thread_pool_.NumThreads(); }
67 
DebugString() const68   string DebugString() const override { return "ThreadPoolResource"; }
69 
70  private:
71   thread::ThreadPool thread_pool_;
72   const int max_intra_op_parallelism_;
73 };
74 
// Creates a handle to a ThreadPool resource. Note that we don't use
// ResourceOpKernel here because the ThreadPoolResource constructor requires
// access to `OpKernelContext::env()`, which isn't provided by
// `ResourceOpKernel<T>::CreateResource()`.
class ThreadPoolHandleOp : public OpKernel {
 public:
  explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism",
                                     &max_intra_op_parallelism_));
    OP_REQUIRES(
        ctx, num_threads_ > 0,
        errors::InvalidArgument("`num_threads` must be greater than zero."));
  }

  // The resource is deleted from the resource manager only when it is private
  // to kernel. Ideally the resource should be deleted when it is no longer held
  // by anyone, but it would break backward compatibility.
  ~ThreadPoolHandleOp() override {
    if (cinfo_.resource_is_private_to_kernel()) {
      if (!cinfo_.resource_manager()
               ->Delete<ThreadPoolResource>(cinfo_.container(), cinfo_.name())
               .ok()) {
        // Do nothing; the resource can have been deleted by session resets.
      }
    }
  }

  // Creates the ThreadPoolResource on the first invocation, then emits a
  // resource handle to it on every invocation.
  void Compute(OpKernelContext* ctx) override TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    if (!initialized_) {
      ResourceMgr* mgr = ctx->resource_manager();
      OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
      ThreadPoolResource* resource;
      // LookupOrCreate makes lazy creation idempotent across kernel instances
      // that share the same container/name.
      OP_REQUIRES_OK(ctx, mgr->LookupOrCreate<ThreadPoolResource>(
                              cinfo_.container(), cinfo_.name(), &resource,
                              [this, ctx](ThreadPoolResource** ret)
                                  TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                                    *ret = new ThreadPoolResource(
                                        ctx->env(), {}, display_name_,
                                        num_threads_,
                                        /*low_latency_hint=*/false,
                                        max_intra_op_parallelism_);
                                    return Status::OK();
                                  }));
      initialized_ = true;
    }
    OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
                            ctx, 0, cinfo_.container(), cinfo_.name(),
                            TypeIndex::Make<ThreadPoolResource>()));
  }

 private:
  mutex mu_;
  ContainerInfo cinfo_ TF_GUARDED_BY(mu_);       // Resource container/name info.
  bool initialized_ TF_GUARDED_BY(mu_) = false;  // Whether the resource exists.
  string display_name_;           // Name given to the pool's threads.
  int num_threads_;               // Number of threads in the pool (> 0).
  int max_intra_op_parallelism_;  // Cap on intra-op parallelism; < 0 = no cap.
};
136 
137 class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
138  public:
ThreadPoolDatasetOp(OpKernelConstruction * ctx)139   explicit ThreadPoolDatasetOp(OpKernelConstruction* ctx)
140       : UnaryDatasetOpKernel(ctx) {}
141 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)142   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
143                    DatasetBase** output) override {
144     core::RefCountPtr<ThreadPoolResource> threadpool_resource;
145     OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
146                                        &threadpool_resource));
147     *output = new Dataset(ctx, input, ctx->input(1), threadpool_resource.get());
148   }
149 
150  private:
151   class Dataset : public DatasetBase {
152    public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,const Tensor & resource_handle,ThreadPoolResource * threadpool)153     Dataset(OpKernelContext* ctx, const DatasetBase* input,
154             const Tensor& resource_handle, ThreadPoolResource* threadpool)
155         : DatasetBase(DatasetContext(ctx)),
156           input_(input),
157           resource_handle_(resource_handle),
158           threadpool_(threadpool) {
159       input_->Ref();
160       threadpool_->Ref();
161     }
162 
~Dataset()163     ~Dataset() override {
164       input_->Unref();
165       threadpool_->Unref();
166     }
167 
MakeIteratorInternal(const string & prefix) const168     std::unique_ptr<IteratorBase> MakeIteratorInternal(
169         const string& prefix) const override {
170       return absl::make_unique<Iterator>(
171           Iterator::Params{this, strings::StrCat(prefix, "::ThreadPool")});
172     }
173 
output_dtypes() const174     const DataTypeVector& output_dtypes() const override {
175       return input_->output_dtypes();
176     }
output_shapes() const177     const std::vector<PartialTensorShape>& output_shapes() const override {
178       return input_->output_shapes();
179     }
180 
DebugString() const181     string DebugString() const override {
182       return "ThreadPoolDatasetOp::Dataset";
183     }
184 
Cardinality() const185     int64 Cardinality() const override { return input_->Cardinality(); }
186 
InputDatasets(std::vector<const DatasetBase * > * inputs) const187     Status InputDatasets(
188         std::vector<const DatasetBase*>* inputs) const override {
189       inputs->push_back(input_);
190       return Status::OK();
191     }
192 
CheckExternalState() const193     Status CheckExternalState() const override {
194       return input_->CheckExternalState();
195     }
196 
197    protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const198     Status AsGraphDefInternal(SerializationContext* ctx,
199                               DatasetGraphDefBuilder* b,
200                               Node** output) const override {
201       Node* input_graph_node = nullptr;
202       TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
203       Node* resource_handle_node = nullptr;
204       TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
205       TF_RETURN_IF_ERROR(b->AddDataset(
206           this, {input_graph_node, resource_handle_node}, output));
207       return Status::OK();
208     }
209 
210    private:
211     class Iterator : public DatasetIterator<Dataset> {
212      public:
Iterator(const Params & params)213       explicit Iterator(const Params& params)
214           : DatasetIterator<Dataset>(params) {}
215 
Initialize(IteratorContext * ctx)216       Status Initialize(IteratorContext* ctx) override {
217         return dataset()->input_->MakeIterator(
218             IteratorContext(CreateParams(ctx)), this, prefix(), &input_impl_);
219       }
220 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)221       Status GetNextInternal(IteratorContext* ctx,
222                              std::vector<Tensor>* out_tensors,
223                              bool* end_of_sequence) override {
224         return input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
225                                     out_tensors, end_of_sequence);
226       }
227 
228      protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const229       std::shared_ptr<model::Node> CreateNode(
230           IteratorContext* ctx, model::Node::Args args) const override {
231         return model::MakeKnownRatioNode(std::move(args),
232                                          /*ratio=*/1);
233       }
234 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)235       Status SaveInternal(SerializationContext* ctx,
236                           IteratorStateWriter* writer) override {
237         DCHECK(input_impl_ != nullptr);
238         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
239         return Status::OK();
240       }
241 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)242       Status RestoreInternal(IteratorContext* ctx,
243                              IteratorStateReader* reader) override {
244         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
245         return Status::OK();
246       }
247 
248      private:
CreateParams(IteratorContext * ctx)249       IteratorContext::Params CreateParams(IteratorContext* ctx) {
250         ThreadPoolResource* pool = dataset()->threadpool_;
251         IteratorContext::Params params(ctx);
252         params.runner = [pool](std::function<void()> c) {
253           pool->Schedule(std::move(c));
254         };
255         params.runner_threadpool_size = pool->NumThreads();
256         return params;
257       }
258 
259       std::unique_ptr<IteratorBase> input_impl_;
260     };
261 
262     const DatasetBase* const input_;
263     const Tensor resource_handle_;
264     ThreadPoolResource* const threadpool_;
265   };
266 };
267 
// A dataset that wraps `input` and caps the intra-op parallelism available to
// functions run while iterating over it.
class MaxIntraOpParallelismDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input,
          int64_t max_intra_op_parallelism)
      : Dataset(DatasetContext(ctx), input, max_intra_op_parallelism) {}

  // Delegated-to constructor; also used directly when the dataset is created
  // from options (no originating kernel).
  Dataset(DatasetContext&& ctx, const DatasetBase* input,
          int64_t max_intra_op_parallelism)
      : DatasetBase(std::move(ctx)),
        input_(input),
        max_intra_op_parallelism_(max_intra_op_parallelism),
        traceme_metadata_(
            {{"parallelism",
              strings::Printf("%lld", static_cast<long long>(
                                          max_intra_op_parallelism_))}}) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, strings::StrCat(prefix, "::MaxIntraOpParallelism")});
  }

  // Output signature is inherited unchanged from the input dataset.
  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }
  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    return "MaxIntraOpParallelismDatasetOp::Dataset";
  }

  int64 Cardinality() const override { return input_->Cardinality(); }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->clear();
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* max_intra_op_parallelism_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(max_intra_op_parallelism_,
                                    &max_intra_op_parallelism_node));
    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {input_graph_node, max_intra_op_parallelism_node}, output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      // Forward to the input iterator with a runner wrapped so that work it
      // schedules observes the configured parallelism cap.
      IteratorContext::Params params(ctx);
      auto max_parallelism = dataset()->max_intra_op_parallelism_;
      params.runner = RunnerWithMaxParallelism(*ctx->runner(), max_parallelism);
      return input_impl_->GetNext(IteratorContext{std::move(params)},
                                  out_tensors, end_of_sequence);
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      // One output element per input element.
      return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      DCHECK(input_impl_ != nullptr);
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      return Status::OK();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    std::unique_ptr<IteratorBase> input_impl_;
  };

  const DatasetBase* const input_;
  const int64 max_intra_op_parallelism_;
  const TraceMeMetadata traceme_metadata_;
};
382 
383 /* static */
MakeDatasetFromOptions(OpKernelContext * ctx,DatasetBase * input,int32_t max_intra_op_parallelism,DatasetBase ** output)384 void MaxIntraOpParallelismDatasetOp::MakeDatasetFromOptions(
385     OpKernelContext* ctx, DatasetBase* input, int32_t max_intra_op_parallelism,
386     DatasetBase** output) {
387   OP_REQUIRES(
388       ctx, max_intra_op_parallelism >= 0,
389       errors::InvalidArgument("`max_intra_op_parallelism` must be >= 0"));
390   *output = new Dataset(DatasetContext(DatasetContext::Params(
391                             {MaxIntraOpParallelismDatasetOp::kDatasetType,
392                              MaxIntraOpParallelismDatasetOp::kDatasetOp})),
393                         input, max_intra_op_parallelism);
394 }
395 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)396 void MaxIntraOpParallelismDatasetOp::MakeDataset(OpKernelContext* ctx,
397                                                  DatasetBase* input,
398                                                  DatasetBase** output) {
399   int64_t max_intra_op_parallelism;
400   OP_REQUIRES_OK(ctx,
401                  ParseScalarArgument<int64>(ctx, "max_intra_op_parallelism",
402                                             &max_intra_op_parallelism));
403   OP_REQUIRES(
404       ctx, max_intra_op_parallelism >= 0,
405       errors::InvalidArgument("`max_intra_op_parallelism` must be >= 0"));
406   *output = new Dataset(ctx, input, max_intra_op_parallelism);
407 }
408 
// A dataset that owns a private thread pool and runs all iteration work of
// the wrapped `input` on it.
class PrivateThreadPoolDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input, int num_threads)
      : Dataset(ctx, DatasetContext(ctx), input, num_threads) {}

  // Delegated-to constructor; `num_threads == 0` selects the host's maximum
  // parallelism.
  Dataset(OpKernelContext* ctx, DatasetContext&& dataset_ctx,
          const DatasetBase* input, int num_threads)
      : DatasetBase(std::move(dataset_ctx)),
        input_(input),
        num_threads_(num_threads == 0 ? port::MaxParallelism() : num_threads),
        traceme_metadata_(
            {{"num_threads",
              strings::Printf("%lld", static_cast<long long>(num_threads_))}}) {
    thread_pool_ = absl::make_unique<thread::ThreadPool>(
        ctx->env(), ThreadOptions{}, "data_private_threadpool", num_threads_);
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(
        Iterator::Params{this, strings::StrCat(prefix, "::PrivateThreadPool")});
  }

  // Output signature is inherited unchanged from the input dataset.
  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }
  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    return "PrivateThreadPoolDatasetOp::Dataset";
  }

  int64 Cardinality() const override { return input_->Cardinality(); }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->clear();
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* num_threads_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(num_threads_, &num_threads_node));
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_graph_node, num_threads_node}, output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      // Forward to the input iterator with the runner replaced by one that
      // schedules onto the dataset's private pool.
      thread::ThreadPool* pool = dataset()->thread_pool_.get();
      IteratorContext::Params params(ctx);
      params.runner = [pool](std::function<void()> c) {
        pool->Schedule(std::move(c));
      };
      params.runner_threadpool_size = dataset()->num_threads_;
      return input_impl_->GetNext(IteratorContext{std::move(params)},
                                  out_tensors, end_of_sequence);
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      // One output element per input element.
      return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      DCHECK(input_impl_ != nullptr);
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      return Status::OK();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    std::unique_ptr<IteratorBase> input_impl_;
  };

  const DatasetBase* const input_;
  const int64 num_threads_;  // Resolved thread count (never 0).
  const TraceMeMetadata traceme_metadata_;
  std::unique_ptr<thread::ThreadPool> thread_pool_;
};
526 
527 /* static */
MakeDatasetFromOptions(OpKernelContext * ctx,DatasetBase * input,int32_t num_threads,DatasetBase ** output)528 void PrivateThreadPoolDatasetOp::MakeDatasetFromOptions(OpKernelContext* ctx,
529                                                         DatasetBase* input,
530                                                         int32_t num_threads,
531                                                         DatasetBase** output) {
532   OP_REQUIRES(ctx, num_threads >= 0,
533               errors::InvalidArgument("`num_threads` must be >= 0"));
534   *output = new Dataset(ctx,
535                         DatasetContext(DatasetContext::Params(
536                             {PrivateThreadPoolDatasetOp::kDatasetType,
537                              PrivateThreadPoolDatasetOp::kDatasetOp})),
538                         input, num_threads);
539 }
540 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)541 void PrivateThreadPoolDatasetOp::MakeDataset(OpKernelContext* ctx,
542                                              DatasetBase* input,
543                                              DatasetBase** output) {
544   int64_t num_threads = 0;
545   OP_REQUIRES_OK(ctx,
546                  ParseScalarArgument<int64>(ctx, "num_threads", &num_threads));
547   OP_REQUIRES(ctx, num_threads >= 0,
548               errors::InvalidArgument("`num_threads` must be >= 0"));
549   *output = new Dataset(ctx, input, num_threads);
550 }
551 
namespace {

// Each kernel is registered under both its current op name and its legacy
// "Experimental"-prefixed name for backward compatibility with older graphs.
REGISTER_KERNEL_BUILDER(Name("MaxIntraOpParallelismDataset").Device(DEVICE_CPU),
                        MaxIntraOpParallelismDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalMaxIntraOpParallelismDataset").Device(DEVICE_CPU),
    MaxIntraOpParallelismDatasetOp);

REGISTER_KERNEL_BUILDER(Name("PrivateThreadPoolDataset").Device(DEVICE_CPU),
                        PrivateThreadPoolDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalPrivateThreadPoolDataset").Device(DEVICE_CPU),
    PrivateThreadPoolDatasetOp);

REGISTER_KERNEL_BUILDER(Name("ThreadPoolHandle").Device(DEVICE_CPU),
                        ThreadPoolHandleOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalThreadPoolHandle").Device(DEVICE_CPU),
                        ThreadPoolHandleOp);

REGISTER_KERNEL_BUILDER(Name("ThreadPoolDataset").Device(DEVICE_CPU),
                        ThreadPoolDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalThreadPoolDataset").Device(DEVICE_CPU),
    ThreadPoolDatasetOp);

}  // namespace
578 }  // namespace experimental
579 }  // namespace data
580 }  // namespace tensorflow
581