1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/dataset_ops.h"
16
17 // On mobile we do not provide this functionality because not all of its
18 // dependencies are available there.
19 #if !defined(IS_MOBILE_PLATFORM)
20 #include "tensorflow/core/common_runtime/graph_constructor.h"
21 #include "tensorflow/core/common_runtime/graph_runner.h"
22 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
23 #include "tensorflow/core/framework/dataset.h"
24 #include "tensorflow/core/framework/dataset_stateful_op_allowlist.h"
25 #include "tensorflow/core/framework/graph.pb.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/graph/graph_def_builder.h"
28 #include "tensorflow/core/grappler/graph_topology_view.h"
29 #include "tensorflow/core/grappler/utils/traversal.h"
30 #include "tensorflow/core/kernels/data/captured_function.h"
31 #include "tensorflow/core/kernels/data/dataset_utils.h"
32 #include "tensorflow/core/util/device_name_utils.h"
33
34 namespace tensorflow {
35 namespace data {
36
37 /* static */ constexpr const char* const DatasetToGraphOp::kAllowStateful;
38 /* static */ constexpr const char* const
39 DatasetToGraphOp::kStripDeviceAssignment;
40 /* static */ constexpr const char* const DatasetToGraphOp::kExternalStatePolicy;
41 /* static */ constexpr const char* const DatasetToGraphOp::kDatasetToGraph;
42 /* static */ constexpr const char* const DatasetFromGraphOp::kGraphDef;
43 /* static */ constexpr const char* const DatasetFromGraphOp::kHandle;
44
45 namespace {
46 constexpr char kPyFunc[] = "PyFunc";
47 } // namespace
48
49 // See documentation in ../../ops/dataset_ops.cc for a high-level
50 // description of the following op.
DatasetToGraphOp(OpKernelConstruction * ctx)51 DatasetToGraphOp::DatasetToGraphOp(OpKernelConstruction* ctx)
52 : OpKernel(ctx), op_version_(ctx->def().op() == kDatasetToGraph ? 1 : 2) {
53 if (op_version_ == 2) {
54 if (ctx->HasAttr(kExternalStatePolicy)) {
55 int64 state_change_option;
56 OP_REQUIRES_OK(ctx,
57 ctx->GetAttr(kExternalStatePolicy, &state_change_option));
58 external_state_policy_ =
59 SerializationContext::ExternalStatePolicy(state_change_option);
60 }
61 } else {
62 if (ctx->HasAttr(kAllowStateful)) {
63 bool allow_stateful;
64 OP_REQUIRES_OK(ctx, ctx->GetAttr(kAllowStateful, &allow_stateful));
65 if (allow_stateful) {
66 external_state_policy_ =
67 SerializationContext::ExternalStatePolicy::kWarn;
68 } else {
69 external_state_policy_ =
70 SerializationContext::ExternalStatePolicy::kFail;
71 }
72 }
73 }
74
75 if (ctx->HasAttr(kStripDeviceAssignment)) {
76 OP_REQUIRES_OK(
77 ctx, ctx->GetAttr(kStripDeviceAssignment, &strip_device_assignment_));
78 }
79 }
80
Compute(OpKernelContext * ctx)81 void DatasetToGraphOp::Compute(OpKernelContext* ctx) {
82 DatasetBase* dataset;
83 OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
84 SerializationContext::Params params;
85 params.external_state_policy = external_state_policy_;
86
87 GraphDef graph_def;
88 Status s = AsGraphDef(ctx, dataset, SerializationContext(params), &graph_def);
89 if (!s.ok()) {
90 ctx->CtxFailure(errors::FailedPrecondition(
91 "Failed to serialize the input pipeline graph: ", s.error_message()));
92 return;
93 }
94 if (strip_device_assignment_) {
95 auto library = graph_def.mutable_library();
96 for (auto& function : (*library->mutable_function())) {
97 for (auto& node : (*function.mutable_node_def())) {
98 // We do not strip the device assignment from `PyFunc` ops because they
99 // need to be pinned to a host that is known to have Python interpreter.
100 if (!node.device().empty() && node.op() != kPyFunc) {
101 *node.mutable_device() = DeviceNameUtils::LocalName(node.device());
102 }
103 }
104 }
105 }
106
107 Tensor* result;
108 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
109 result->scalar<tstring>()() = graph_def.SerializeAsString();
110 }
111
Compute(OpKernelContext * ctx)112 void DatasetCardinalityOp::Compute(OpKernelContext* ctx) {
113 DatasetBase* dataset;
114 OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
115 Tensor* result;
116 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
117 result->scalar<int64>()() = dataset->Cardinality();
118 }
119
Compute(OpKernelContext * ctx)120 void DatasetFromGraphOp::Compute(OpKernelContext* ctx) {
121 tstring graph_def_string;
122 OP_REQUIRES_OK(ctx,
123 ParseScalarArgument(ctx, kGraphDef, &graph_def_string));
124 GraphDef graph_def;
125 OP_REQUIRES(ctx, graph_def.ParseFromString(graph_def_string),
126 errors::InvalidArgument("Could not parse GraphDef"));
127 string output_node;
128 for (const auto& node : graph_def.node()) {
129 if (node.op() == FunctionLibraryDefinition::kRetOp) {
130 output_node = node.input(0);
131 }
132 }
133 Graph graph(OpRegistry::Global());
134 OP_REQUIRES_OK(ctx, ImportGraphDef({}, graph_def, &graph, nullptr));
135
136 FunctionLibraryRuntime* flr;
137 std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
138 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
139 OP_REQUIRES_OK(ctx,
140 ctx->function_library()->Clone(&flib_def, &pflr, &flr, true));
141
142 // Some function names may be duplicated (for example, if the serialized
143 // graph has an optimized function that retains its original name). We
144 // override functions in flib_def in the event of conflict. It is
145 // safe to assume that any node in the serialized graph is referring to the
146 // serialized function when there is a conflict.
147 OP_REQUIRES_OK(ctx,
148 AddToFunctionLibrary(flib_def.get(), graph_def.library()));
149
150 std::vector<Tensor> outputs;
151 GraphRunner graph_runner(flr->device());
152 OP_REQUIRES_OK(ctx,
153 graph_runner.Run(&graph, flr, {}, {output_node}, &outputs));
154 OP_REQUIRES_OK(ctx, ctx->set_output(kHandle, outputs[0]));
155 }
156
157 REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU),
158 DatasetToGraphOp);
159 REGISTER_KERNEL_BUILDER(Name("DatasetToGraphV2").Device(DEVICE_CPU),
160 DatasetToGraphOp);
161
162 REGISTER_KERNEL_BUILDER(Name("DatasetCardinality").Device(DEVICE_CPU),
163 DatasetCardinalityOp);
164 REGISTER_KERNEL_BUILDER(
165 Name("ExperimentalDatasetCardinality").Device(DEVICE_CPU),
166 DatasetCardinalityOp);
167
168 REGISTER_KERNEL_BUILDER(Name("DatasetFromGraph").Device(DEVICE_CPU),
169 DatasetFromGraphOp);
170
171 } // namespace data
172 } // namespace tensorflow
173 #endif // !IS_MOBILE_PLATFORM
174