• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/dataset_ops.h"
16 
17 // On mobile we do not provide this functionality because not all of its
18 // dependencies are available there.
19 #if !defined(IS_MOBILE_PLATFORM)
20 #include "tensorflow/core/common_runtime/graph_constructor.h"
21 #include "tensorflow/core/common_runtime/graph_runner.h"
22 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
23 #include "tensorflow/core/data/captured_function.h"
24 #include "tensorflow/core/data/dataset_utils.h"
25 #include "tensorflow/core/data/serialization_utils.h"
26 #include "tensorflow/core/framework/dataset.h"
27 #include "tensorflow/core/framework/dataset_stateful_op_allowlist.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/op_kernel.h"
30 #include "tensorflow/core/graph/graph_def_builder.h"
31 #include "tensorflow/core/grappler/graph_topology_view.h"
32 #include "tensorflow/core/grappler/utils/traversal.h"
33 #include "tensorflow/core/util/device_name_utils.h"
34 
35 namespace tensorflow {
36 namespace data {
37 
// Namespace-scope definitions of the static string constants used below as
// attr, input, output and op names. No initializers appear here; the values
// are presumably supplied at the in-class declarations (not visible in this
// file — confirm against dataset_ops.h).
38 /* static */ constexpr const char* const DatasetToGraphOp::kAllowStateful;
39 /* static */ constexpr const char* const
40     DatasetToGraphOp::kStripDeviceAssignment;
41 /* static */ constexpr const char* const DatasetToGraphOp::kExternalStatePolicy;
42 /* static */ constexpr const char* const DatasetToGraphOp::kDatasetToGraph;
43 /* static */ constexpr const char* const DatasetFromGraphOp::kGraphDef;
44 /* static */ constexpr const char* const DatasetFromGraphOp::kHandle;
45 
46 namespace {
// Op name of nodes that `DatasetToGraphOp::Compute` leaves untouched when it
// rewrites device assignments inside the serialized function library.
47 constexpr char kPyFunc[] = "PyFunc";
48 }  // namespace
49 
50 // See documentation in ../../ops/dataset_ops.cc for a high-level
51 // description of the following op.
DatasetToGraphOp(OpKernelConstruction * ctx)52 DatasetToGraphOp::DatasetToGraphOp(OpKernelConstruction* ctx)
53     : OpKernel(ctx), op_version_(ctx->def().op() == kDatasetToGraph ? 1 : 2) {
54   if (op_version_ == 2) {
55     if (ctx->HasAttr(kExternalStatePolicy)) {
56       int64_t state_change_option;
57       OP_REQUIRES_OK(ctx,
58                      ctx->GetAttr(kExternalStatePolicy, &state_change_option));
59       external_state_policy_ =
60           SerializationContext::ExternalStatePolicy(state_change_option);
61     }
62   } else {
63     if (ctx->HasAttr(kAllowStateful)) {
64       bool allow_stateful;
65       OP_REQUIRES_OK(ctx, ctx->GetAttr(kAllowStateful, &allow_stateful));
66       if (allow_stateful) {
67         external_state_policy_ =
68             SerializationContext::ExternalStatePolicy::kWarn;
69       } else {
70         external_state_policy_ =
71             SerializationContext::ExternalStatePolicy::kFail;
72       }
73     }
74   }
75 
76   if (ctx->HasAttr(kStripDeviceAssignment)) {
77     OP_REQUIRES_OK(
78         ctx, ctx->GetAttr(kStripDeviceAssignment, &strip_device_assignment_));
79   }
80 }
81 
Compute(OpKernelContext * ctx)82 void DatasetToGraphOp::Compute(OpKernelContext* ctx) {
83   DatasetBase* dataset;
84   OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
85   if (dataset->options().optional_external_state_policy_case() ==
86       Options::kExternalStatePolicy) {
87     switch (dataset->options().external_state_policy()) {
88       case ExternalStatePolicy::POLICY_WARN:
89         external_state_policy_ =
90             SerializationContext::ExternalStatePolicy::kWarn;
91         break;
92       case ExternalStatePolicy::POLICY_IGNORE:
93         external_state_policy_ =
94             SerializationContext::ExternalStatePolicy::kIgnore;
95         break;
96       case ExternalStatePolicy::POLICY_FAIL:
97         external_state_policy_ =
98             SerializationContext::ExternalStatePolicy::kFail;
99         break;
100       default: {
101         LOG(ERROR) << "Dataset " << dataset->type_string()
102                    << " has an unknown external_state_policy enum value: "
103                    << dataset->options().external_state_policy();
104       }
105     }
106   }
107   SerializationContext::Params params(ctx);
108   params.external_state_policy = external_state_policy_;
109 
110   GraphDef graph_def;
111   Status s = AsGraphDef(dataset, SerializationContext(params), &graph_def);
112   if (!s.ok()) {
113     ctx->CtxFailure(errors::FailedPrecondition(
114         "Failed to serialize the input pipeline graph: ", s.error_message()));
115     return;
116   }
117   if (strip_device_assignment_) {
118     auto library = graph_def.mutable_library();
119     for (auto& function : (*library->mutable_function())) {
120       for (auto& node : (*function.mutable_node_def())) {
121         // We do not strip the device assignment from `PyFunc` ops because they
122         // need to be pinned to a host that is known to have Python interpreter.
123         if (!node.device().empty() && node.op() != kPyFunc) {
124           *node.mutable_device() = DeviceNameUtils::LocalName(node.device());
125         }
126       }
127     }
128   }
129 
130   Tensor* result;
131   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
132   result->scalar<tstring>()() = graph_def.SerializeAsString();
133 }
134 
Compute(OpKernelContext * ctx)135 void DatasetCardinalityOp::Compute(OpKernelContext* ctx) {
136   DatasetBase* dataset;
137   OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
138   Tensor* result;
139   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
140   result->scalar<int64_t>()() = dataset->Cardinality();
141 }
142 
Compute(OpKernelContext * ctx)143 void DatasetFromGraphOp::Compute(OpKernelContext* ctx) {
144   tstring graph_def_string;
145   OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kGraphDef, &graph_def_string));
146   GraphDef graph_def;
147   OP_REQUIRES(ctx, graph_def.ParseFromString(graph_def_string),
148               errors::InvalidArgument("Could not parse GraphDef"));
149   string output_node;
150   for (const auto& node : graph_def.node()) {
151     if (node.op() == FunctionLibraryDefinition::kRetOp) {
152       output_node = node.input(0);
153     }
154   }
155   Graph graph(OpRegistry::Global());
156   OP_REQUIRES_OK(ctx, ImportGraphDef({}, graph_def, &graph, nullptr));
157 
158   FunctionLibraryRuntime* flr;
159   std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
160   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
161   OP_REQUIRES_OK(ctx,
162                  ctx->function_library()->Clone(&flib_def, &pflr, &flr, true));
163 
164   // Some function names may be duplicated (for example, if the serialized
165   // graph has an optimized function that retains its original name). We
166   // override functions in flib_def in the event of conflict. It is
167   // safe to assume that any node in the serialized graph is referring to the
168   // serialized function when there is a conflict.
169   OP_REQUIRES_OK(ctx,
170                  AddToFunctionLibrary(flib_def.get(), graph_def.library()));
171 
172   std::vector<Tensor> outputs;
173   GraphRunner graph_runner(flr->device());
174   OP_REQUIRES_OK(ctx,
175                  graph_runner.Run(&graph, flr, {}, {output_node}, &outputs));
176   OP_REQUIRES_OK(ctx, ctx->set_output(kHandle, outputs[0]));
177 }
178 
// Both op names are served by the same kernel class on CPU; the constructor
// tells them apart by comparing the node's op name against `kDatasetToGraph`.
179 REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU),
180                         DatasetToGraphOp);
181 REGISTER_KERNEL_BUILDER(Name("DatasetToGraphV2").Device(DEVICE_CPU),
182                         DatasetToGraphOp);
183 
// The cardinality kernel is registered on CPU under two op names that share
// the same implementation.
184 REGISTER_KERNEL_BUILDER(Name("DatasetCardinality").Device(DEVICE_CPU),
185                         DatasetCardinalityOp);
186 REGISTER_KERNEL_BUILDER(
187     Name("ExperimentalDatasetCardinality").Device(DEVICE_CPU),
188     DatasetCardinalityOp);
189 
// CPU registration for the kernel that rebuilds a dataset from a serialized
// GraphDef.
190 REGISTER_KERNEL_BUILDER(Name("DatasetFromGraph").Device(DEVICE_CPU),
191                         DatasetFromGraphOp);
192 
193 }  // namespace data
194 }  // namespace tensorflow
195 #endif  // !IS_MOBILE_PLATFORM
196