| /external/tensorflow/tensorflow/core/grappler/optimizers/ |
| D | remapper_test.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 62 item.fetch = {"batch_norm"}; in TEST_F() 64 auto tensors_expected = EvaluateNodes(item.graph, item.fetch); in TEST_F() 71 auto tensors = EvaluateNodes(output, item.fetch); in TEST_F() 73 test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6); in TEST_F() 99 item.fetch = {"batch_norm"}; in TEST_F() 108 auto tensors_expected = EvaluateNodes(item.graph, item.fetch); in TEST_F() 110 auto tensors = EvaluateNodes(output, item.fetch); in TEST_F() 112 test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-3); in TEST_F() 159 auto fetch = ops::Identity(s.WithOpName("fetch"), relu); in TEST_F() local [all …]
|
| D | mkl_remapper_test.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 78 auto fetch = s.WithOpName("fetch"); in FuseConv2DWithBiasAndAddNOrAdd() local 80 ops::Identity(fetch, ops::Relu(activate, addop)); in FuseConv2DWithBiasAndAddNOrAdd() 82 ops::Identity(fetch, ops::Relu6(activate, addop)); in FuseConv2DWithBiasAndAddNOrAdd() 84 ops::Identity(fetch, ops::Elu(activate, addop)); in FuseConv2DWithBiasAndAddNOrAdd() 86 ops::Identity(fetch, ops::internal::LeakyRelu(activate, addop)); in FuseConv2DWithBiasAndAddNOrAdd() 89 ops::Identity(fetch, addop); in FuseConv2DWithBiasAndAddNOrAdd() 114 item.fetch = {"fetch"}; in FuseConv2DWithBiasAndAddNOrAdd() 123 item.graph.mutable_node(i)->set_device("/device:CPU:0"); in FuseConv2DWithBiasAndAddNOrAdd() 136 for (const NodeDef& node : output.node()) { in FuseConv2DWithBiasAndAddNOrAdd() local [all …]
|
| D | auto_parallel.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 36 NodeDef* node = graph_.add_node(); in AddNodeDivConst() local 37 node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-Const")); in AddNodeDivConst() 38 node->set_op("Const"); in AddNodeDivConst() 42 node->mutable_attr()->insert({"dtype", attr_data_type}); in AddNodeDivConst() 46 tensor->add_float_val(static_cast<float>(num_replicas_)); in AddNodeDivConst() 47 tensor->set_dtype(DT_FLOAT); in AddNodeDivConst() 48 node->mutable_attr()->insert({"value", attr_tensor}); in AddNodeDivConst() 49 return node; in AddNodeDivConst() 54 NodeDef* node = graph_.add_node(); in AddNodeDiv() local [all …]
|
| D | constant_folding_test.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 83 item.fetch = {"mul1", "mul2", "add1", "add2"}; in SimpleNeutralElementTest() 93 const NodeDef& node = output.node(i); in SimpleNeutralElementTest() local 94 const string& name = node.name(); in SimpleNeutralElementTest() 96 EXPECT_EQ("Const", node.op()); in SimpleNeutralElementTest() 97 EXPECT_EQ("^x", node.input(0)); in SimpleNeutralElementTest() 98 EXPECT_EQ("^zeros", node.input(1)); in SimpleNeutralElementTest() 100 EXPECT_EQ(snapshot_or_identity, node.op()); in SimpleNeutralElementTest() 101 EXPECT_EQ("x", node.input(0)); in SimpleNeutralElementTest() 102 EXPECT_EQ("^ones", node.input(1)); in SimpleNeutralElementTest() [all …]
|
| D | arithmetic_optimizer_test.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 58 // Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation. 63 // Optimized name of outer Div node by HoistCommonFactorOutOfAggregation. 68 // Optimized name of inner Add node by HoistCommonFactorOutOfAggregation. 73 // Optimized name of Const node by SimplifyAggregation. 78 // Optimized name of Mul node by SimplifyAggregation. 87 const NodeDef& original = original_graph.node(i); in VerifyGraphsMatch() 88 const NodeDef& optimized = optimized_graph.node(i); in VerifyGraphsMatch() 123 item.fetch = {"output"}; in TEST_F() 127 auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}}); in TEST_F() [all …]
|
| D | dependency_optimizer_test.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 42 const NodeDef& original = original_graph.node(i); in VerifyGraphsEqual() 43 const NodeDef& optimized = optimized_graph.node(i); in VerifyGraphsEqual() 81 item.fetch.push_back("id1"); in TEST_F() 82 item.fetch.push_back("id2"); in TEST_F() 93 // The 'z' node should have been optimized away leaving only 5 nodes. in TEST_F() 96 for (const NodeDef& node : item.graph.node()) { in TEST_F() local 97 if (node.name() == "id1" || node.name() == "id2") { in TEST_F() 98 EXPECT_EQ(1, node.input_size()); in TEST_F() 99 EXPECT_EQ("add", node.input(0)); in TEST_F() [all …]
|
| D | pin_to_host_optimizer_test.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 57 item.fetch = {"a", "c", "d", "e", "f"}; in TEST_F() 60 auto tensors_expected = EvaluateNodes(item.graph, item.fetch); in TEST_F() 66 auto tensors = EvaluateNodes(item.graph, item.fetch); in TEST_F() 77 for (const NodeDef& node : output.node()) { in TEST_F() local 78 if (node.name() == "a" || node.name() == "c") { in TEST_F() 79 EXPECT_TRUE(node.device().empty()); in TEST_F() 80 } else if (node.name() == "d" || node.name() == "e" || node.name() == "f") { in TEST_F() 81 EXPECT_EQ(node.device(), "/device:CPU:0"); in TEST_F() 97 item.fetch = {"b"}; in TEST_F() [all …]
|
| D | shape_optimizer_test.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 45 item.fetch = {"e", "f"}; in TEST_F() 48 auto tensors_expected = EvaluateNodes(item.graph, item.fetch); in TEST_F() 55 for (const NodeDef& node : output.node()) { in TEST_F() local 56 if (node.name() == "e") { in TEST_F() 58 EXPECT_EQ("Size", node.op()); in TEST_F() 59 EXPECT_EQ("a", node.input(0)); in TEST_F() 60 } else if (node.name() == "f") { in TEST_F() 62 EXPECT_EQ("Prod", node.op()); in TEST_F() 63 EXPECT_EQ("c", node.input(0)); in TEST_F() [all …]
|
| D | model_pruner_test.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 87 std::vector<string> fetch = {"e"}; in TEST_F() local 88 auto expected_tensors = EvaluateNodes(item.graph, fetch); in TEST_F() 89 auto actual_tensors = EvaluateNodes(output, fetch); in TEST_F() 109 item.fetch.push_back("e"); in TEST_F() 128 auto actual_tensors = EvaluateNodes(output, item.fetch); in TEST_F() 130 auto expected_tensors = EvaluateNodes(item.graph, item.fetch); in TEST_F() 142 // Node "c" is pruned along with fanins of node "c". in TEST_F() 144 // Node "d" will be pruned because it only has control outputs. in TEST_F() 155 item.fetch = {"g", "h"}; in TEST_F() [all …]
|
| D | auto_mixed_precision_test.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 66 (random::New64() % 65536 / 65536.0) * (maxval - minval) + minval; in GenerateRandomTensorInRange() 76 const NodeDef& original = original_graph.node(i); in VerifyGraphsEquivalent() 115 device_properties.mutable_environment()->insert({"architecture", "7"}); in SetUp() 116 device_properties.mutable_environment()->insert({"cuda", "9010"}); in SetUp() 118 device_properties.mutable_environment()->insert( in SetUp() 124 TF_CHECK_OK(virtual_cluster_->Provision()); in SetUp() 127 void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); } in TearDown() 154 type_list.mutable_list()->add_type(DT_FLOAT); in AddSimpleNode() 182 item.fetch = {"fetch1"}; in TestSimpleUnaryInferOp() [all …]
|
| /external/tensorflow/tensorflow/core/grappler/ |
| D | grappler_item_builder.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 58 auto flat = tensor->flat<float>(); in InitializeTensor() 64 auto flat = tensor->flat<int64_t>(); in InitializeTensor() 72 // Allocator will run non-trivial constructor/destructor for a Tensor with in InitializeTensor() 74 memset(const_cast<char*>(tensor->tensor_data().data()), 0, in InitializeTensor() 75 tensor->tensor_data().size()); in InitializeTensor() 86 item->graph = std::move(pruned_graph); in PruneGraph() 99 dim_proto.size() == -1) { in ReplaceUnknownShapeDim() 101 shape_pb_out->add_dim()->set_size( in ReplaceUnknownShapeDim() 105 shape_pb_out->add_dim()->set_size(dim_proto.size()); in ReplaceUnknownShapeDim() [all …]
|
| D | grappler_item.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 37 // prune all nodes that are not in the transitive fanin of the fetch nodes. in CreateOptOptionsForEager() 56 item.fetch = fetch; in WithGraph() 72 TF_CHECK_OK(ComputeTransitiveFanin(graph, fetch, &fanin_nodes)); in MainOpsFanin() 84 TF_CHECK_OK(ComputeTransitiveFanin(graph, fetch, &fanin_nodes)); in EnqueueOpsFanin() 98 for (const NodeDef* node : fanin) { in MainVariables() local 99 if (IsVariable(*node)) { in MainVariables() 100 vars.push_back(node); in MainVariables() 108 for (const string& f : fetch) { in NodesToPreserve() 109 VLOG(1) << "Add fetch " << f; in NodesToPreserve() [all …]
|
| /external/tensorflow/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/ |
| D | target.pbtxt | 1 # RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-contro… 2 …-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-prune-unused-n… 3 …-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-prune-unused-n… 14 node { 35 node { 64 node { 96 node { 115 node { 136 node { 166 # Tests single target node with no pruning set. All nodes will remain in the [all …]
|
| /external/tensorflow/tensorflow/compiler/tf2xla/ |
| D | graph_compiler_util.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 47 typedef std::unordered_map<string, Node*> NodeMap; 49 // Each feed id identifies the positional output of some node, which may consist 52 // point from a new _Arg node instead. The newly created _Arg nodes are added to 57 std::unordered_set<const Node*>* arg_nodes) { in AddArgNodes() 65 auto node_it = node_map.find(remap_it->second); in AddArgNodes() 68 absl::string_view name(remap_it->second); in AddArgNodes() 72 "Node is fed but not needed for fetching: ", name); in AddArgNodes() 74 const Node* feed_node = node_it->second; in AddArgNodes() 80 Node* arg_node = nullptr; in AddArgNodes() [all …]
|
| D | tf2xla_util_test.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 52 feed->mutable_id()->set_node_name("foo"); in TEST() 53 feed->mutable_id()->set_output_index(123); in TEST() 54 feed->set_name("foo_debug"); in TEST() 56 feed->mutable_id()->set_node_name("bar"); in TEST() 57 feed->mutable_id()->set_output_index(0); in TEST() 58 tf2xla::Fetch* fetch = config.add_fetch(); in TEST() local 59 fetch->mutable_id()->set_node_name("baz"); in TEST() 60 fetch->mutable_id()->set_output_index(456); in TEST() 61 fetch->set_name("baz_debug"); in TEST() [all …]
|
| D | tf2xla.proto | 14 // index of a particular node in the graph. If the output of the named node 15 // feeds into other node(s), this corresponds to one or more edges. Otherwise 30 // contains this information. However, if the node being fed is an op that is 31 // not linked into the binary, then the type cannot be inferred from the node; 36 // Fetch represents a single fetch tensor in the graph, which corresponds to an 38 message Fetch { message 55 // Flag for variables that are never assigned. Assignments to a read-only 56 // variable or unassigned variables that are not read-only are invalid. 65 // Each fetch is a positional output argument for the generated computation. 67 repeated Fetch fetch = 2; field
|
| /external/tensorflow/tensorflow/core/grappler/optimizers/data/ |
| D | graph_utils.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 57 names[graph.node(i).name()] = i; in CreateNameIndex() 67 std::vector<int> CreateInputIndex(const NodeDef& node) { in CreateInputIndex() argument 69 for (int i = 0; i < node.input_size(); ++i) { in CreateInputIndex() 70 inputs[node.input(i)] = i; in CreateInputIndex() 72 std::vector<int> index(node.input_size()); in CreateInputIndex() 83 NodeDef node; in AddScalarConstNodeHelper() local 84 node.set_op(kConstOpName); in AddScalarConstNodeHelper() 85 SetUniqueGraphNodeName(kConstOpName, graph->graph(), &node); in AddScalarConstNodeHelper() 87 (*node.mutable_attr())["dtype"].set_type(dtype); in AddScalarConstNodeHelper() [all …]
|
| D | use_private_thread_pool.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 47 if (item.fetch.size() != 1) { in OptimizeAndCollectStats() 49 "Expected only one fetch node but there were ", item.fetch.size(), ": ", in OptimizeAndCollectStats() 50 absl::StrJoin(item.fetch, ", ")); in OptimizeAndCollectStats() 53 for (const NodeDef& node : item.graph.node()) { in OptimizeAndCollectStats() local 54 if (node.op() == kPrivateThreadPoolDataset) { in OptimizeAndCollectStats() 61 NodeDef* sink_node = graph.GetNode(item.fetch.at(0)); in OptimizeAndCollectStats() 70 if (last_node->op() == kModelDataset) { in OptimizeAndCollectStats() 74 // Add a const node with value 0 to indicate it is not set by users. in OptimizeAndCollectStats() 84 *insert_node.mutable_input()->Add() = last_node->name(); in OptimizeAndCollectStats() [all …]
|
| D | disable_intra_op_parallelism.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 53 if (item.fetch.size() != 1) { in OptimizeAndCollectStats() 55 "Expected only one fetch node but there were ", item.fetch.size(), ": ", in OptimizeAndCollectStats() 56 absl::StrJoin(item.fetch, ", ")); in OptimizeAndCollectStats() 59 for (const NodeDef& node : item.graph.node()) { in OptimizeAndCollectStats() local 61 if (node.op() == target_dataset_op) { in OptimizeAndCollectStats() 69 NodeDef* sink_node = graph.GetNode(item.fetch.at(0)); in OptimizeAndCollectStats() 78 if (last_node->op() == kModelDataset) { in OptimizeAndCollectStats() 82 // Add a const node with value 1 in OptimizeAndCollectStats() 92 *insert_node.mutable_input()->Add() = last_node->name(); in OptimizeAndCollectStats() [all …]
|
| D | slack.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 39 bool IsDatasetNodeOfType(const NodeDef& node, in IsDatasetNodeOfType() argument 42 if (node.op() == dataset_op_name) return true; in IsDatasetNodeOfType() 82 if (dataset_node->op() == kPrefetchDatasetOp) { in RecursivelyHandleOp() 84 (*dataset_node->mutable_attr())["slack_period"].set_i(slack_period_); in RecursivelyHandleOp() 96 for (int i = 0; i < dataset_node->input_size(); ++i) { in RecursivelyHandleOp() 124 if (item.fetch.size() != 1) { in OptimizeAndCollectStats() 126 "Expected only one fetch node but there were ", item.fetch.size(), ": ", in OptimizeAndCollectStats() 127 absl::StrJoin(item.fetch, ", ")); in OptimizeAndCollectStats() 129 // Walks the input pipeline backwards from the fetch node to find the last in OptimizeAndCollectStats() [all …]
|
| /external/tensorflow/tensorflow/core/grappler/clusters/ |
| D | single_machine_test.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 53 TF_CHECK_OK(cluster_->EnablePeakMemoryStats()); in SetUp() 54 TF_CHECK_OK(cluster_->Provision()); in SetUp() 59 TF_CHECK_OK(cluster_->Shutdown()); in TearDown() 69 CHECK_EQ("single_machine", cluster_->type()); in TEST_F() 74 cluster_->GetDeviceNames()); in TEST_F() 78 TF_CHECK_OK(cluster_->Initialize(item)); in TEST_F() 81 const int64_t start_micros = Env::Default()->NowMicros(); in TEST_F() 82 TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata)); in TEST_F() 84 Env::Default()->NowMicros() - start_micros; in TEST_F() [all …]
|
| D | single_machine.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 56 options_.config.add_session_inter_op_thread_pool()->set_num_threads( in SingleMachine() 66 // Reset the thread-pool so that there are no outstanding Session::Run(...)s in ~SingleMachine() 84 TF_RETURN_IF_ERROR(session_->ListDevices(&devices)); in Provision() 125 mutex_lock l(this->last_graph_mu_); in Initialize() 139 mutex_lock l(this->last_graph_mu_); in Shutdown() 148 const std::vector<string>& fetch, in Run() argument 150 mutex_lock l(this->last_graph_mu_); in Run() 153 TF_RETURN_IF_ERROR(session_->Create(graph_def)); in Run() 162 for (auto node : *init_metadata_.mutable_cost_graph()->mutable_node()) { in Run() [all …]
|
| /external/tensorflow/tensorflow/core/common_runtime/ |
| D | graph_execution_state.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 102 if (options.session_options->config.graph_options().place_pruned_graph() || in MakeForBaseGraph() 103 !options.session_options->config.experimental() in MakeForBaseGraph() 112 if (!options.session_options->config.graph_options().place_pruned_graph()) { in MakeForBaseGraph() 114 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *ret->original_graph_def_, in MakeForBaseGraph() 116 TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph))); in MakeForBaseGraph() 125 TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph))); in MakeForBaseGraph() 137 if (!(base_execution_state.session_options_->config.graph_options() in MakeForPrunedGraph() 139 options.session_options->config.graph_options().place_pruned_graph())) { in MakeForPrunedGraph() 150 "the Session-level `GraphExecutionState`."); in MakeForPrunedGraph() [all …]
|
| /external/tensorflow/tensorflow/c/experimental/grappler/ |
| D | grappler_test.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 37 params->struct_size = TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE; in PopulateDefaultParam() 38 params->optimizer_configs->struct_size = TP_OPTIMIZER_CONFIGS_STRUCT_SIZE; in PopulateDefaultParam() 39 params->optimizer->struct_size = TP_OPTIMIZER_STRUCT_SIZE; in PopulateDefaultParam() 40 params->optimizer->create_func = nullptr; in PopulateDefaultParam() 41 params->optimizer->optimize_func = optimize_func; in PopulateDefaultParam() 42 params->optimizer->destroy_func = nullptr; in PopulateDefaultParam() 47 TF_Status* const status) -> void { in TEST() 50 params->device_type = "Success"; in TEST() 51 params->optimizer_configs->remapping = TF_TriState_Off; in TEST() [all …]
|
| /external/tensorflow/tensorflow/core/graph/ |
| D | subgraph_test.cc | 7 http://www.apache.org/licenses/LICENSE-2.0 58 Node* FindNode(const string& name) { in FindNode() 59 for (Node* n : g_->nodes()) { in FindNode() 60 if (n->name() == name) return n; in FindNode() 70 for (Node* n : g_->nodes()) { in ExpectNodes() 71 if (n->IsOp()) { in ExpectNodes() 73 actual_nodes.push_back(n->name()); in ExpectNodes() 83 Node* n = FindNode(s); in ExpectNodes() 85 if (n->type_string() == "_Send" || n->type_string() == "_Recv") { in ExpectNodes() 86 EXPECT_EQ(device_info_.name(), n->assigned_device_name()) << s; in ExpectNodes() [all …]
|