• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/tensor_id.h"
17 #include <vector>
18 #include "tensorflow/core/lib/random/simple_philox.h"
19 #include "tensorflow/core/platform/logging.h"
20 #include "tensorflow/core/platform/test.h"
21 #include "tensorflow/core/platform/test_benchmark.h"
22 
23 namespace tensorflow {
24 namespace {
25 
// Test helper: runs ParseTensorName() on `n` and returns the ToString() form
// of the result, so expectations can be written as plain strings.
string ParseHelper(const string& n) {
  auto parsed = ParseTensorName(n);
  return parsed.ToString();
}
27 
// Checks the parse -> ToString() round trip for a fixed table of tensor names.
TEST(TensorIdTest, ParseTensorName) {
  struct RoundTrip {
    const char* input;     // Name handed to ParseHelper().
    const char* expected;  // String ParseHelper() is expected to return.
  };
  const RoundTrip kRoundTrips[] = {
      {"W1", "W1:0"},
      {"W1:0", "W1:0"},
      {"weights:0", "weights:0"},
      {"W1:1", "W1:1"},
      {"W1:17", "W1:17"},
      {"xyz1_17", "xyz1_17:0"},
      {"^foo", "^foo"},
  };
  for (const RoundTrip& round_trip : kRoundTrips) {
    EXPECT_EQ(ParseHelper(round_trip.input), round_trip.expected);
  }
}
37 
Skewed(random::SimplePhilox * rnd,int max_log)38 uint32 Skewed(random::SimplePhilox* rnd, int max_log) {
39   const uint32 space = 1 << (rnd->Rand32() % (max_log + 1));
40   return rnd->Rand32() % space;
41 }
42 
BM_ParseTensorName(::testing::benchmark::State & state)43 void BM_ParseTensorName(::testing::benchmark::State& state) {
44   const int arg = state.range(0);
45   random::PhiloxRandom philox(301, 17);
46   random::SimplePhilox rnd(&philox);
47   std::vector<string> names;
48   for (int i = 0; i < 100; i++) {
49     string name;
50     switch (arg) {
51       case 0: {  // Generate random names
52         size_t len = Skewed(&rnd, 4);
53         while (name.size() < len) {
54           name += rnd.OneIn(4) ? '0' : 'a';
55         }
56         if (rnd.OneIn(3)) {
57           strings::StrAppend(&name, ":", rnd.Uniform(12));
58         }
59         break;
60       }
61       case 1:
62         name = "W1";
63         break;
64       case 2:
65         name = "t0003";
66         break;
67       case 3:
68         name = "weights";
69         break;
70       case 4:
71         name = "weights:17";
72         break;
73       case 5:
74         name = "^weights";
75         break;
76       default:
77         LOG(FATAL) << "Unexpected arg";
78         break;
79     }
80     names.push_back(name);
81   }
82 
83   TensorId id;
84   int index = 0;
85   int sum = 0;
86   for (auto s : state) {
87     id = ParseTensorName(names[index++ % names.size()]);
88     sum += id.second;
89   }
90   VLOG(2) << sum;  // Prevent compiler from eliminating loop body
91 }
92 BENCHMARK(BM_ParseTensorName)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4)->Arg(5);
93 
// Checks IsTensorIdControl() on ids parsed from three names: only the
// '^'-prefixed one is expected to be reported as a control id.
TEST(TensorIdTest, IsTensorIdControl) {
  // Each name lives in its own string that stays in scope while the parsed
  // TensorId is inspected, as in the straightforward parse-then-check form.
  const string control_name = "^foo";
  const string plain_name = "foo";
  const string indexed_name = "foo:2";

  const TensorId control_id = ParseTensorName(control_name);
  EXPECT_TRUE(IsTensorIdControl(control_id));

  const TensorId plain_id = ParseTensorName(plain_name);
  EXPECT_FALSE(IsTensorIdControl(plain_id));

  const TensorId indexed_id = ParseTensorName(indexed_name);
  EXPECT_FALSE(IsTensorIdControl(indexed_id));
}
107 
// Checks that "foo" and "foo:0" both parse to node "foo" with index 0.
TEST(TensorIdTest, PortZero) {
  const string kInputs[] = {"foo", "foo:0"};
  for (const string& input : kInputs) {
    const TensorId parsed = ParseTensorName(input);
    EXPECT_EQ("foo", parsed.node());
    EXPECT_EQ(0, parsed.index());
  }
}
115 
116 }  // namespace
117 }  // namespace tensorflow
118