1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/experimental/threadpool_dataset_op.h"
16
17 #include <memory>
18
19 #include "tensorflow/core/data/dataset_utils.h"
20 #include "tensorflow/core/framework/dataset.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/resource_mgr.h"
23 #include "tensorflow/core/lib/core/refcount.h"
24 #include "tensorflow/core/lib/core/threadpool.h"
25 #include "tensorflow/core/platform/cpu_info.h"
26 #include "tensorflow/core/platform/stringprintf.h"
27 #include "tensorflow/core/platform/thread_annotations.h"
28 #include "tensorflow/core/util/work_sharder.h"
29
30 namespace tensorflow {
31 namespace data {
32 namespace experimental {
33
// Out-of-line definitions for the static constexpr string members declared in
// the header; required to satisfy ODR-uses prior to C++17 inline variables.
/* static */ constexpr const char* const
    MaxIntraOpParallelismDatasetOp::kDatasetType;
/* static */ constexpr const char* const
    MaxIntraOpParallelismDatasetOp::kDatasetOp;
/* static */ constexpr const char* const
    PrivateThreadPoolDatasetOp::kDatasetType;
/* static */ constexpr const char* const PrivateThreadPoolDatasetOp::kDatasetOp;
41
42 class ThreadPoolResource : public ResourceBase {
43 public:
ThreadPoolResource(Env * env,const ThreadOptions & thread_options,const string & name,int num_threads,bool low_latency_hint,int max_intra_op_parallelism)44 ThreadPoolResource(Env* env, const ThreadOptions& thread_options,
45 const string& name, int num_threads, bool low_latency_hint,
46 int max_intra_op_parallelism)
47 : thread_pool_(env, thread_options, name, num_threads, low_latency_hint),
48 max_intra_op_parallelism_(max_intra_op_parallelism) {}
49
50 // Schedules fn() for execution in the pool of threads.
Schedule(std::function<void ()> fn)51 void Schedule(std::function<void()> fn) {
52 if (max_intra_op_parallelism_ < 0) {
53 thread_pool_.Schedule(std::move(fn));
54 } else {
55 thread_pool_.Schedule(std::bind(
56 [this](std::function<void()> bound_fn) {
57 // TODO(mrry): Consider moving this thread-local configuration to
58 // the threads themselves.
59 ScopedPerThreadMaxParallelism scope(max_intra_op_parallelism_);
60 bound_fn();
61 },
62 std::move(fn)));
63 }
64 }
65
NumThreads()66 int32 NumThreads() { return thread_pool_.NumThreads(); }
67
DebugString() const68 string DebugString() const override { return "ThreadPoolResource"; }
69
70 private:
71 thread::ThreadPool thread_pool_;
72 const int max_intra_op_parallelism_;
73 };
74
75 // Creates a handle to a ThreadPool resource. Note that we don't use
76 // ResourceOpKernel here because the ThreadPoolResource constructor requires
77 // access to `OpKernelContext::env()`, which isn't provided by
78 // `ResourceOpKernel<T>::CreateResource()`.
// Creates a handle to a ThreadPool resource. Note that we don't use
// ResourceOpKernel here because the ThreadPoolResource constructor requires
// access to `OpKernelContext::env()`, which isn't provided by
// `ResourceOpKernel<T>::CreateResource()`.
class ThreadPoolHandleOp : public OpKernel {
 public:
  explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    // `display_name` names the pool, `num_threads` sizes it, and
    // `max_intra_op_parallelism` caps per-op parallelism for work scheduled
    // on the pool (negative leaves it uncapped; see
    // ThreadPoolResource::Schedule).
    OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism",
                                     &max_intra_op_parallelism_));
    OP_REQUIRES(
        ctx, num_threads_ > 0,
        errors::InvalidArgument("`num_threads` must be greater than zero."));
  }

  // The resource is deleted from the resource manager only when it is private
  // to kernel. Ideally the resource should be deleted when it is no longer held
  // by anyone, but it would break backward compatibility.
  ~ThreadPoolHandleOp() override {
    if (cinfo_.resource_is_private_to_kernel()) {
      if (!cinfo_.resource_manager()
               ->Delete<ThreadPoolResource>(cinfo_.container(), cinfo_.name())
               .ok()) {
        // Do nothing; the resource can have been deleted by session resets.
      }
    }
  }

  // Lazily creates the ThreadPoolResource on the first invocation, then emits
  // a resource handle referring to it as output 0 on every invocation.
  void Compute(OpKernelContext* ctx) override TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    if (!initialized_) {
      ResourceMgr* mgr = ctx->resource_manager();
      OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
      ThreadPoolResource* resource;
      // LookupOrCreate only invokes the creation lambda if no resource with
      // this container/name already exists.
      OP_REQUIRES_OK(ctx, mgr->LookupOrCreate<ThreadPoolResource>(
                              cinfo_.container(), cinfo_.name(), &resource,
                              [this, ctx](ThreadPoolResource** ret)
                                  TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                                    *ret = new ThreadPoolResource(
                                        ctx->env(), {}, display_name_,
                                        num_threads_,
                                        /*low_latency_hint=*/false,
                                        max_intra_op_parallelism_);
                                    return Status::OK();
                                  }));
      initialized_ = true;
    }
    OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
                            ctx, 0, cinfo_.container(), cinfo_.name(),
                            TypeIndex::Make<ThreadPoolResource>()));
  }

 private:
  mutex mu_;
  // Container/name bookkeeping for the created resource.
  ContainerInfo cinfo_ TF_GUARDED_BY(mu_);
  bool initialized_ TF_GUARDED_BY(mu_) = false;
  string display_name_;
  int num_threads_;
  int max_intra_op_parallelism_;
};
136
// An op that wraps an input dataset so that all computation triggered while
// iterating over it is scheduled onto the ThreadPoolResource identified by
// the resource-handle input.
class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit ThreadPoolDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {}

  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    // Input 1 holds the handle to a ThreadPoolResource (see
    // ThreadPoolHandleOp above).
    core::RefCountPtr<ThreadPoolResource> threadpool_resource;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
                                       &threadpool_resource));
    // The Dataset constructor takes its own reference on the resource.
    *output = new Dataset(ctx, input, ctx->input(1), threadpool_resource.get());
  }

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, const DatasetBase* input,
            const Tensor& resource_handle, ThreadPoolResource* threadpool)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          resource_handle_(resource_handle),
          threadpool_(threadpool) {
      // Hold references so the input dataset and the thread pool outlive
      // this dataset.
      input_->Ref();
      threadpool_->Ref();
    }

    ~Dataset() override {
      input_->Unref();
      threadpool_->Unref();
    }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<Iterator>(
          Iterator::Params{this, strings::StrCat(prefix, "::ThreadPool")});
    }

    // Type, shape, and cardinality queries all delegate to the wrapped input.
    const DataTypeVector& output_dtypes() const override {
      return input_->output_dtypes();
    }
    const std::vector<PartialTensorShape>& output_shapes() const override {
      return input_->output_shapes();
    }

    string DebugString() const override {
      return "ThreadPoolDatasetOp::Dataset";
    }

    int64 Cardinality() const override { return input_->Cardinality(); }

    Status InputDatasets(
        std::vector<const DatasetBase*>* inputs) const override {
      inputs->push_back(input_);
      return Status::OK();
    }

    Status CheckExternalState() const override {
      return input_->CheckExternalState();
    }

   protected:
    // Serializes this dataset as its op node with the input dataset and the
    // resource-handle tensor as inputs.
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_graph_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
      Node* resource_handle_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
      TF_RETURN_IF_ERROR(b->AddDataset(
          this, {input_graph_node, resource_handle_node}, output));
      return Status::OK();
    }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {}

      // The input iterator is created (and later driven) with a context whose
      // runner schedules work on the custom thread pool.
      Status Initialize(IteratorContext* ctx) override {
        return dataset()->input_->MakeIterator(
            IteratorContext(CreateParams(ctx)), this, prefix(), &input_impl_);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        return input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
                                    out_tensors, end_of_sequence);
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        DCHECK(input_impl_ != nullptr);
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        return Status::OK();
      }

     private:
      // Returns a copy of `ctx`'s params with the runner replaced by one that
      // schedules closures onto the dataset's ThreadPoolResource.
      IteratorContext::Params CreateParams(IteratorContext* ctx) {
        ThreadPoolResource* pool = dataset()->threadpool_;
        IteratorContext::Params params(ctx);
        params.runner = [pool](std::function<void()> c) {
          pool->Schedule(std::move(c));
        };
        params.runner_threadpool_size = pool->NumThreads();
        return params;
      }

      std::unique_ptr<IteratorBase> input_impl_;
    };

    const DatasetBase* const input_;
    // The original handle tensor, kept so it can be re-serialized in
    // AsGraphDefInternal.
    const Tensor resource_handle_;
    ThreadPoolResource* const threadpool_;
  };
};
267
// A dataset that wraps `input` and caps the intra-op parallelism observed by
// downstream computation while iterating over it.
class MaxIntraOpParallelismDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input,
          int64_t max_intra_op_parallelism)
      : Dataset(DatasetContext(ctx), input, max_intra_op_parallelism) {}

  Dataset(DatasetContext&& ctx, const DatasetBase* input,
          int64_t max_intra_op_parallelism)
      : DatasetBase(std::move(ctx)),
        input_(input),
        max_intra_op_parallelism_(max_intra_op_parallelism),
        // Exposed via GetTraceMeMetadata for profiling tools.
        traceme_metadata_(
            {{"parallelism",
              strings::Printf("%lld", static_cast<long long>(
                                          max_intra_op_parallelism_))}}) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, strings::StrCat(prefix, "::MaxIntraOpParallelism")});
  }

  // Type, shape, and cardinality queries all delegate to the wrapped input.
  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }
  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    return "MaxIntraOpParallelismDatasetOp::Dataset";
  }

  int64 Cardinality() const override { return input_->Cardinality(); }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->clear();
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  // Serializes this dataset as its op node with the input dataset and the
  // scalar parallelism limit as inputs.
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* max_intra_op_parallelism_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(max_intra_op_parallelism_,
                                    &max_intra_op_parallelism_node));
    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {input_graph_node, max_intra_op_parallelism_node}, output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      // Wrap the context's runner so that work executed downstream observes
      // the configured intra-op parallelism cap.
      IteratorContext::Params params(ctx);
      auto max_parallelism = dataset()->max_intra_op_parallelism_;
      params.runner = RunnerWithMaxParallelism(*ctx->runner(), max_parallelism);
      return input_impl_->GetNext(IteratorContext{std::move(params)},
                                  out_tensors, end_of_sequence);
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      DCHECK(input_impl_ != nullptr);
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      return Status::OK();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    std::unique_ptr<IteratorBase> input_impl_;
  };

  const DatasetBase* const input_;
  const int64 max_intra_op_parallelism_;
  const TraceMeMetadata traceme_metadata_;
};
382
383 /* static */
MakeDatasetFromOptions(OpKernelContext * ctx,DatasetBase * input,int32_t max_intra_op_parallelism,DatasetBase ** output)384 void MaxIntraOpParallelismDatasetOp::MakeDatasetFromOptions(
385 OpKernelContext* ctx, DatasetBase* input, int32_t max_intra_op_parallelism,
386 DatasetBase** output) {
387 OP_REQUIRES(
388 ctx, max_intra_op_parallelism >= 0,
389 errors::InvalidArgument("`max_intra_op_parallelism` must be >= 0"));
390 *output = new Dataset(DatasetContext(DatasetContext::Params(
391 {MaxIntraOpParallelismDatasetOp::kDatasetType,
392 MaxIntraOpParallelismDatasetOp::kDatasetOp})),
393 input, max_intra_op_parallelism);
394 }
395
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)396 void MaxIntraOpParallelismDatasetOp::MakeDataset(OpKernelContext* ctx,
397 DatasetBase* input,
398 DatasetBase** output) {
399 int64_t max_intra_op_parallelism;
400 OP_REQUIRES_OK(ctx,
401 ParseScalarArgument<int64>(ctx, "max_intra_op_parallelism",
402 &max_intra_op_parallelism));
403 OP_REQUIRES(
404 ctx, max_intra_op_parallelism >= 0,
405 errors::InvalidArgument("`max_intra_op_parallelism` must be >= 0"));
406 *output = new Dataset(ctx, input, max_intra_op_parallelism);
407 }
408
// A dataset that runs its input pipeline's computation on a thread pool that
// is private to this dataset instance.
class PrivateThreadPoolDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input, int num_threads)
      : Dataset(ctx, DatasetContext(ctx), input, num_threads) {}

  Dataset(OpKernelContext* ctx, DatasetContext&& dataset_ctx,
          const DatasetBase* input, int num_threads)
      : DatasetBase(std::move(dataset_ctx)),
        input_(input),
        // A requested size of 0 means "use the maximum available parallelism".
        num_threads_(num_threads == 0 ? port::MaxParallelism() : num_threads),
        // Exposed via GetTraceMeMetadata for profiling tools.
        traceme_metadata_(
            {{"num_threads",
              strings::Printf("%lld", static_cast<long long>(num_threads_))}}) {
    // The pool is owned by this dataset and torn down with it.
    thread_pool_ = absl::make_unique<thread::ThreadPool>(
        ctx->env(), ThreadOptions{}, "data_private_threadpool", num_threads_);
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(
        Iterator::Params{this, strings::StrCat(prefix, "::PrivateThreadPool")});
  }

  // Type, shape, and cardinality queries all delegate to the wrapped input.
  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }
  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    return "PrivateThreadPoolDatasetOp::Dataset";
  }

  int64 Cardinality() const override { return input_->Cardinality(); }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->clear();
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  // Serializes this dataset as its op node with the input dataset and the
  // scalar thread count as inputs.
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* num_threads_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(num_threads_, &num_threads_node));
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_graph_node, num_threads_node}, output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      // Replace the context's runner so that downstream work is scheduled on
      // the dataset's private pool.
      thread::ThreadPool* pool = dataset()->thread_pool_.get();
      IteratorContext::Params params(ctx);
      params.runner = [pool](std::function<void()> c) {
        pool->Schedule(std::move(c));
      };
      params.runner_threadpool_size = dataset()->num_threads_;
      return input_impl_->GetNext(IteratorContext{std::move(params)},
                                  out_tensors, end_of_sequence);
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      DCHECK(input_impl_ != nullptr);
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      return Status::OK();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    std::unique_ptr<IteratorBase> input_impl_;
  };

  const DatasetBase* const input_;
  const int64 num_threads_;
  const TraceMeMetadata traceme_metadata_;
  std::unique_ptr<thread::ThreadPool> thread_pool_;
};
526
527 /* static */
MakeDatasetFromOptions(OpKernelContext * ctx,DatasetBase * input,int32_t num_threads,DatasetBase ** output)528 void PrivateThreadPoolDatasetOp::MakeDatasetFromOptions(OpKernelContext* ctx,
529 DatasetBase* input,
530 int32_t num_threads,
531 DatasetBase** output) {
532 OP_REQUIRES(ctx, num_threads >= 0,
533 errors::InvalidArgument("`num_threads` must be >= 0"));
534 *output = new Dataset(ctx,
535 DatasetContext(DatasetContext::Params(
536 {PrivateThreadPoolDatasetOp::kDatasetType,
537 PrivateThreadPoolDatasetOp::kDatasetOp})),
538 input, num_threads);
539 }
540
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)541 void PrivateThreadPoolDatasetOp::MakeDataset(OpKernelContext* ctx,
542 DatasetBase* input,
543 DatasetBase** output) {
544 int64_t num_threads = 0;
545 OP_REQUIRES_OK(ctx,
546 ParseScalarArgument<int64>(ctx, "num_threads", &num_threads));
547 OP_REQUIRES(ctx, num_threads >= 0,
548 errors::InvalidArgument("`num_threads` must be >= 0"));
549 *output = new Dataset(ctx, input, num_threads);
550 }
551
namespace {

// Each kernel is registered under its canonical op name and an
// "Experimental"-prefixed alias — presumably so graphs serialized with the
// older experimental op names continue to load; confirm before removing.
REGISTER_KERNEL_BUILDER(Name("MaxIntraOpParallelismDataset").Device(DEVICE_CPU),
                        MaxIntraOpParallelismDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalMaxIntraOpParallelismDataset").Device(DEVICE_CPU),
    MaxIntraOpParallelismDatasetOp);

REGISTER_KERNEL_BUILDER(Name("PrivateThreadPoolDataset").Device(DEVICE_CPU),
                        PrivateThreadPoolDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalPrivateThreadPoolDataset").Device(DEVICE_CPU),
    PrivateThreadPoolDatasetOp);

REGISTER_KERNEL_BUILDER(Name("ThreadPoolHandle").Device(DEVICE_CPU),
                        ThreadPoolHandleOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalThreadPoolHandle").Device(DEVICE_CPU),
                        ThreadPoolHandleOp);

REGISTER_KERNEL_BUILDER(Name("ThreadPoolDataset").Device(DEVICE_CPU),
                        ThreadPoolDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalThreadPoolDataset").Device(DEVICE_CPU),
    ThreadPoolDatasetOp);

}  // namespace
578 } // namespace experimental
579 } // namespace data
580 } // namespace tensorflow
581