1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/graph/tensor_id.h"

#include <cstdint>
#include <utility>
#include <vector>

#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
22
23 namespace tensorflow {
24 namespace {
25
// Parses `n` with ParseTensorName and renders the resulting TensorId back to
// a string, so tests can compare against the expected textual form.
string ParseHelper(const string& n) {
  TensorId parsed = ParseTensorName(n);
  return parsed.ToString();
}
27
TEST(TensorIdTest, ParseTensorName) {
  // Each entry pairs an input name with its expected ToString() rendering:
  // names without an explicit index are rendered with ":0", explicitly
  // indexed names keep their index, and "^name" round-trips unchanged.
  struct Case {
    const char* input;
    const char* expected;
  };
  const Case kCases[] = {
      {"W1", "W1:0"},
      {"W1:0", "W1:0"},
      {"weights:0", "weights:0"},
      {"W1:1", "W1:1"},
      {"W1:17", "W1:17"},
      {"xyz1_17", "xyz1_17:0"},
      {"^foo", "^foo"},
  };
  for (const Case& c : kCases) {
    EXPECT_EQ(ParseHelper(c.input), c.expected);
  }
}
37
Skewed(random::SimplePhilox * rnd,int max_log)38 uint32 Skewed(random::SimplePhilox* rnd, int max_log) {
39 const uint32 space = 1 << (rnd->Rand32() % (max_log + 1));
40 return rnd->Rand32() % space;
41 }
42
BM_ParseTensorName(::testing::benchmark::State & state)43 void BM_ParseTensorName(::testing::benchmark::State& state) {
44 const int arg = state.range(0);
45 random::PhiloxRandom philox(301, 17);
46 random::SimplePhilox rnd(&philox);
47 std::vector<string> names;
48 for (int i = 0; i < 100; i++) {
49 string name;
50 switch (arg) {
51 case 0: { // Generate random names
52 size_t len = Skewed(&rnd, 4);
53 while (name.size() < len) {
54 name += rnd.OneIn(4) ? '0' : 'a';
55 }
56 if (rnd.OneIn(3)) {
57 strings::StrAppend(&name, ":", rnd.Uniform(12));
58 }
59 break;
60 }
61 case 1:
62 name = "W1";
63 break;
64 case 2:
65 name = "t0003";
66 break;
67 case 3:
68 name = "weights";
69 break;
70 case 4:
71 name = "weights:17";
72 break;
73 case 5:
74 name = "^weights";
75 break;
76 default:
77 LOG(FATAL) << "Unexpected arg";
78 break;
79 }
80 names.push_back(name);
81 }
82
83 TensorId id;
84 int index = 0;
85 int sum = 0;
86 for (auto s : state) {
87 id = ParseTensorName(names[index++ % names.size()]);
88 sum += id.second;
89 }
90 VLOG(2) << sum; // Prevent compiler from eliminating loop body
91 }
92 BENCHMARK(BM_ParseTensorName)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4)->Arg(5);
93
TEST(TensorIdTest, IsTensorIdControl) {
  // A name with a leading '^' parses to a control TensorId.
  const string control_input = "^foo";
  EXPECT_TRUE(IsTensorIdControl(ParseTensorName(control_input)));

  // Plain and explicitly indexed names are not control inputs.
  for (const string data_input : {"foo", "foo:2"}) {
    const TensorId parsed = ParseTensorName(data_input);
    EXPECT_FALSE(IsTensorIdControl(parsed));
  }
}
107
TEST(TensorIdTest, PortZero) {
  // Omitting the index and writing ":0" explicitly must both yield node
  // "foo" at index 0.
  const std::vector<string> inputs = {"foo", "foo:0"};
  for (const string& name : inputs) {
    const TensorId parsed = ParseTensorName(name);
    EXPECT_EQ("foo", parsed.node());
    EXPECT_EQ(0, parsed.index());
  }
}
115
116 } // namespace
117 } // namespace tensorflow
118