1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/dataset_ops.h"
16
17 // On mobile we do not provide this functionality because not all of its
18 // dependencies are available there.
19 #if !defined(IS_MOBILE_PLATFORM)
20 #include "tensorflow/core/common_runtime/graph_constructor.h"
21 #include "tensorflow/core/common_runtime/graph_runner.h"
22 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
23 #include "tensorflow/core/data/captured_function.h"
24 #include "tensorflow/core/data/dataset_utils.h"
25 #include "tensorflow/core/data/serialization_utils.h"
26 #include "tensorflow/core/framework/dataset.h"
27 #include "tensorflow/core/framework/dataset_stateful_op_allowlist.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/op_kernel.h"
30 #include "tensorflow/core/graph/graph_def_builder.h"
31 #include "tensorflow/core/grappler/graph_topology_view.h"
32 #include "tensorflow/core/grappler/utils/traversal.h"
33 #include "tensorflow/core/util/device_name_utils.h"
34
35 namespace tensorflow {
36 namespace data {
37
38 /* static */ constexpr const char* const DatasetToGraphOp::kAllowStateful;
39 /* static */ constexpr const char* const
40 DatasetToGraphOp::kStripDeviceAssignment;
41 /* static */ constexpr const char* const DatasetToGraphOp::kExternalStatePolicy;
42 /* static */ constexpr const char* const DatasetToGraphOp::kDatasetToGraph;
43 /* static */ constexpr const char* const DatasetFromGraphOp::kGraphDef;
44 /* static */ constexpr const char* const DatasetFromGraphOp::kHandle;
45
46 namespace {
47 constexpr char kPyFunc[] = "PyFunc";
48 } // namespace
49
50 // See documentation in ../../ops/dataset_ops.cc for a high-level
51 // description of the following op.
DatasetToGraphOp(OpKernelConstruction * ctx)52 DatasetToGraphOp::DatasetToGraphOp(OpKernelConstruction* ctx)
53 : OpKernel(ctx), op_version_(ctx->def().op() == kDatasetToGraph ? 1 : 2) {
54 if (op_version_ == 2) {
55 if (ctx->HasAttr(kExternalStatePolicy)) {
56 int64_t state_change_option;
57 OP_REQUIRES_OK(ctx,
58 ctx->GetAttr(kExternalStatePolicy, &state_change_option));
59 external_state_policy_ =
60 SerializationContext::ExternalStatePolicy(state_change_option);
61 }
62 } else {
63 if (ctx->HasAttr(kAllowStateful)) {
64 bool allow_stateful;
65 OP_REQUIRES_OK(ctx, ctx->GetAttr(kAllowStateful, &allow_stateful));
66 if (allow_stateful) {
67 external_state_policy_ =
68 SerializationContext::ExternalStatePolicy::kWarn;
69 } else {
70 external_state_policy_ =
71 SerializationContext::ExternalStatePolicy::kFail;
72 }
73 }
74 }
75
76 if (ctx->HasAttr(kStripDeviceAssignment)) {
77 OP_REQUIRES_OK(
78 ctx, ctx->GetAttr(kStripDeviceAssignment, &strip_device_assignment_));
79 }
80 }
81
Compute(OpKernelContext * ctx)82 void DatasetToGraphOp::Compute(OpKernelContext* ctx) {
83 DatasetBase* dataset;
84 OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
85 if (dataset->options().optional_external_state_policy_case() ==
86 Options::kExternalStatePolicy) {
87 switch (dataset->options().external_state_policy()) {
88 case ExternalStatePolicy::POLICY_WARN:
89 external_state_policy_ =
90 SerializationContext::ExternalStatePolicy::kWarn;
91 break;
92 case ExternalStatePolicy::POLICY_IGNORE:
93 external_state_policy_ =
94 SerializationContext::ExternalStatePolicy::kIgnore;
95 break;
96 case ExternalStatePolicy::POLICY_FAIL:
97 external_state_policy_ =
98 SerializationContext::ExternalStatePolicy::kFail;
99 break;
100 default: {
101 LOG(ERROR) << "Dataset " << dataset->type_string()
102 << " has an unknown external_state_policy enum value: "
103 << dataset->options().external_state_policy();
104 }
105 }
106 }
107 SerializationContext::Params params(ctx);
108 params.external_state_policy = external_state_policy_;
109
110 GraphDef graph_def;
111 Status s = AsGraphDef(dataset, SerializationContext(params), &graph_def);
112 if (!s.ok()) {
113 ctx->CtxFailure(errors::FailedPrecondition(
114 "Failed to serialize the input pipeline graph: ", s.error_message()));
115 return;
116 }
117 if (strip_device_assignment_) {
118 auto library = graph_def.mutable_library();
119 for (auto& function : (*library->mutable_function())) {
120 for (auto& node : (*function.mutable_node_def())) {
121 // We do not strip the device assignment from `PyFunc` ops because they
122 // need to be pinned to a host that is known to have Python interpreter.
123 if (!node.device().empty() && node.op() != kPyFunc) {
124 *node.mutable_device() = DeviceNameUtils::LocalName(node.device());
125 }
126 }
127 }
128 }
129
130 Tensor* result;
131 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
132 result->scalar<tstring>()() = graph_def.SerializeAsString();
133 }
134
Compute(OpKernelContext * ctx)135 void DatasetCardinalityOp::Compute(OpKernelContext* ctx) {
136 DatasetBase* dataset;
137 OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
138 Tensor* result;
139 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
140 result->scalar<int64_t>()() = dataset->Cardinality();
141 }
142
Compute(OpKernelContext * ctx)143 void DatasetFromGraphOp::Compute(OpKernelContext* ctx) {
144 tstring graph_def_string;
145 OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kGraphDef, &graph_def_string));
146 GraphDef graph_def;
147 OP_REQUIRES(ctx, graph_def.ParseFromString(graph_def_string),
148 errors::InvalidArgument("Could not parse GraphDef"));
149 string output_node;
150 for (const auto& node : graph_def.node()) {
151 if (node.op() == FunctionLibraryDefinition::kRetOp) {
152 output_node = node.input(0);
153 }
154 }
155 Graph graph(OpRegistry::Global());
156 OP_REQUIRES_OK(ctx, ImportGraphDef({}, graph_def, &graph, nullptr));
157
158 FunctionLibraryRuntime* flr;
159 std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
160 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
161 OP_REQUIRES_OK(ctx,
162 ctx->function_library()->Clone(&flib_def, &pflr, &flr, true));
163
164 // Some function names may be duplicated (for example, if the serialized
165 // graph has an optimized function that retains its original name). We
166 // override functions in flib_def in the event of conflict. It is
167 // safe to assume that any node in the serialized graph is referring to the
168 // serialized function when there is a conflict.
169 OP_REQUIRES_OK(ctx,
170 AddToFunctionLibrary(flib_def.get(), graph_def.library()));
171
172 std::vector<Tensor> outputs;
173 GraphRunner graph_runner(flr->device());
174 OP_REQUIRES_OK(ctx,
175 graph_runner.Run(&graph, flr, {}, {output_node}, &outputs));
176 OP_REQUIRES_OK(ctx, ctx->set_output(kHandle, outputs[0]));
177 }
178
179 REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU),
180 DatasetToGraphOp);
181 REGISTER_KERNEL_BUILDER(Name("DatasetToGraphV2").Device(DEVICE_CPU),
182 DatasetToGraphOp);
183
184 REGISTER_KERNEL_BUILDER(Name("DatasetCardinality").Device(DEVICE_CPU),
185 DatasetCardinalityOp);
186 REGISTER_KERNEL_BUILDER(
187 Name("ExperimentalDatasetCardinality").Device(DEVICE_CPU),
188 DatasetCardinalityOp);
189
190 REGISTER_KERNEL_BUILDER(Name("DatasetFromGraph").Device(DEVICE_CPU),
191 DatasetFromGraphOp);
192
193 } // namespace data
194 } // namespace tensorflow
195 #endif // !IS_MOBILE_PLATFORM
196