• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/dataset_ops.h"
16 
17 // On mobile we do not provide this functionality because not all of its
18 // dependencies are available there.
19 #if !defined(IS_MOBILE_PLATFORM)
20 #include "tensorflow/core/common_runtime/graph_constructor.h"
21 #include "tensorflow/core/common_runtime/graph_runner.h"
22 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
23 #include "tensorflow/core/framework/dataset.h"
24 #include "tensorflow/core/framework/dataset_stateful_op_allowlist.h"
25 #include "tensorflow/core/framework/graph.pb.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/graph/graph_def_builder.h"
28 #include "tensorflow/core/grappler/graph_topology_view.h"
29 #include "tensorflow/core/grappler/utils/traversal.h"
30 #include "tensorflow/core/kernels/data/captured_function.h"
31 #include "tensorflow/core/kernels/data/dataset_utils.h"
32 #include "tensorflow/core/util/device_name_utils.h"
33 
34 namespace tensorflow {
35 namespace data {
36 
37 /* static */ constexpr const char* const DatasetToGraphOp::kAllowStateful;
38 /* static */ constexpr const char* const
39     DatasetToGraphOp::kStripDeviceAssignment;
40 /* static */ constexpr const char* const DatasetToGraphOp::kExternalStatePolicy;
41 /* static */ constexpr const char* const DatasetToGraphOp::kDatasetToGraph;
42 /* static */ constexpr const char* const DatasetFromGraphOp::kGraphDef;
43 /* static */ constexpr const char* const DatasetFromGraphOp::kHandle;
44 
45 namespace {
46 constexpr char kPyFunc[] = "PyFunc";
47 }  // namespace
48 
49 // See documentation in ../../ops/dataset_ops.cc for a high-level
50 // description of the following op.
DatasetToGraphOp(OpKernelConstruction * ctx)51 DatasetToGraphOp::DatasetToGraphOp(OpKernelConstruction* ctx)
52     : OpKernel(ctx), op_version_(ctx->def().op() == kDatasetToGraph ? 1 : 2) {
53   if (op_version_ == 2) {
54     if (ctx->HasAttr(kExternalStatePolicy)) {
55       int64 state_change_option;
56       OP_REQUIRES_OK(ctx,
57                      ctx->GetAttr(kExternalStatePolicy, &state_change_option));
58       external_state_policy_ =
59           SerializationContext::ExternalStatePolicy(state_change_option);
60     }
61   } else {
62     if (ctx->HasAttr(kAllowStateful)) {
63       bool allow_stateful;
64       OP_REQUIRES_OK(ctx, ctx->GetAttr(kAllowStateful, &allow_stateful));
65       if (allow_stateful) {
66         external_state_policy_ =
67             SerializationContext::ExternalStatePolicy::kWarn;
68       } else {
69         external_state_policy_ =
70             SerializationContext::ExternalStatePolicy::kFail;
71       }
72     }
73   }
74 
75   if (ctx->HasAttr(kStripDeviceAssignment)) {
76     OP_REQUIRES_OK(
77         ctx, ctx->GetAttr(kStripDeviceAssignment, &strip_device_assignment_));
78   }
79 }
80 
Compute(OpKernelContext * ctx)81 void DatasetToGraphOp::Compute(OpKernelContext* ctx) {
82   DatasetBase* dataset;
83   OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
84   SerializationContext::Params params;
85   params.external_state_policy = external_state_policy_;
86 
87   GraphDef graph_def;
88   Status s = AsGraphDef(ctx, dataset, SerializationContext(params), &graph_def);
89   if (!s.ok()) {
90     ctx->CtxFailure(errors::FailedPrecondition(
91         "Failed to serialize the input pipeline graph: ", s.error_message()));
92     return;
93   }
94   if (strip_device_assignment_) {
95     auto library = graph_def.mutable_library();
96     for (auto& function : (*library->mutable_function())) {
97       for (auto& node : (*function.mutable_node_def())) {
98         // We do not strip the device assignment from `PyFunc` ops because they
99         // need to be pinned to a host that is known to have Python interpreter.
100         if (!node.device().empty() && node.op() != kPyFunc) {
101           *node.mutable_device() = DeviceNameUtils::LocalName(node.device());
102         }
103       }
104     }
105   }
106 
107   Tensor* result;
108   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
109   result->scalar<tstring>()() = graph_def.SerializeAsString();
110 }
111 
Compute(OpKernelContext * ctx)112 void DatasetCardinalityOp::Compute(OpKernelContext* ctx) {
113   DatasetBase* dataset;
114   OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
115   Tensor* result;
116   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
117   result->scalar<int64>()() = dataset->Cardinality();
118 }
119 
Compute(OpKernelContext * ctx)120 void DatasetFromGraphOp::Compute(OpKernelContext* ctx) {
121   tstring graph_def_string;
122   OP_REQUIRES_OK(ctx,
123                  ParseScalarArgument(ctx, kGraphDef, &graph_def_string));
124   GraphDef graph_def;
125   OP_REQUIRES(ctx, graph_def.ParseFromString(graph_def_string),
126               errors::InvalidArgument("Could not parse GraphDef"));
127   string output_node;
128   for (const auto& node : graph_def.node()) {
129     if (node.op() == FunctionLibraryDefinition::kRetOp) {
130       output_node = node.input(0);
131     }
132   }
133   Graph graph(OpRegistry::Global());
134   OP_REQUIRES_OK(ctx, ImportGraphDef({}, graph_def, &graph, nullptr));
135 
136   FunctionLibraryRuntime* flr;
137   std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
138   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
139   OP_REQUIRES_OK(ctx,
140                  ctx->function_library()->Clone(&flib_def, &pflr, &flr, true));
141 
142   // Some function names may be duplicated (for example, if the serialized
143   // graph has an optimized function that retains its original name). We
144   // override functions in flib_def in the event of conflict. It is
145   // safe to assume that any node in the serialized graph is referring to the
146   // serialized function when there is a conflict.
147   OP_REQUIRES_OK(ctx,
148                  AddToFunctionLibrary(flib_def.get(), graph_def.library()));
149 
150   std::vector<Tensor> outputs;
151   GraphRunner graph_runner(flr->device());
152   OP_REQUIRES_OK(ctx,
153                  graph_runner.Run(&graph, flr, {}, {output_node}, &outputs));
154   OP_REQUIRES_OK(ctx, ctx->set_output(kHandle, outputs[0]));
155 }
156 
157 REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU),
158                         DatasetToGraphOp);
159 REGISTER_KERNEL_BUILDER(Name("DatasetToGraphV2").Device(DEVICE_CPU),
160                         DatasetToGraphOp);
161 
162 REGISTER_KERNEL_BUILDER(Name("DatasetCardinality").Device(DEVICE_CPU),
163                         DatasetCardinalityOp);
164 REGISTER_KERNEL_BUILDER(
165     Name("ExperimentalDatasetCardinality").Device(DEVICE_CPU),
166     DatasetCardinalityOp);
167 
168 REGISTER_KERNEL_BUILDER(Name("DatasetFromGraph").Device(DEVICE_CPU),
169                         DatasetFromGraphOp);
170 
171 }  // namespace data
172 }  // namespace tensorflow
173 #endif  // !IS_MOBILE_PLATFORM
174