• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 //
3 // Copyright 2024 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include <grpc/impl/grpc_types.h>
20 #include <gtest/gtest.h>
21 
22 #include "src/core/client_channel/client_channel_filter.h"
23 #include "src/core/lib/security/context/security_context.h"
24 #include "src/core/tsi/transport_security.h"
25 #include "test/core/test_util/test_config.h"
26 
27 namespace grpc_core {
28 namespace testing {
29 namespace {
30 
GetLocalUnixAddress(grpc_endpoint *)31 absl::string_view GetLocalUnixAddress(grpc_endpoint* /*ep*/) { return "unix:"; }
32 
33 const grpc_endpoint_vtable kUnixEndpointVtable = {
34     nullptr, nullptr, nullptr, nullptr,
35     nullptr, nullptr, nullptr, GetLocalUnixAddress,
36     nullptr, nullptr};
37 
GetLocalTcpAddress(grpc_endpoint *)38 absl::string_view GetLocalTcpAddress(grpc_endpoint* /*ep*/) {
39   return "ipv4:127.0.0.1:12667";
40 }
41 
42 const grpc_endpoint_vtable kTcpEndpointVtable = {
43     nullptr, nullptr, nullptr, nullptr,
44     nullptr, nullptr, nullptr, GetLocalTcpAddress,
45     nullptr, nullptr};
46 
GetSecurityLevelForServer(grpc_local_connect_type connect_type,grpc_endpoint & ep)47 std::string GetSecurityLevelForServer(grpc_local_connect_type connect_type,
48                                       grpc_endpoint& ep) {
49   grpc_server_credentials* server_creds =
50       grpc_local_server_credentials_create(connect_type);
51   ChannelArgs args;
52   RefCountedPtr<grpc_server_security_connector> connector =
53       server_creds->create_security_connector(args);
54   tsi_peer peer;
55   CHECK(tsi_construct_peer(0, &peer) == TSI_OK);
56 
57   RefCountedPtr<grpc_auth_context> auth_context;
58   connector->check_peer(peer, &ep, args, &auth_context, nullptr);
59   tsi_peer_destruct(&peer);
60   auto it = grpc_auth_context_find_properties_by_name(
61       auth_context.get(), GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME);
62   const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it);
63   std::string actual_level;
64   if (prop != nullptr) {
65     actual_level = std::string(prop->value, prop->value_length);
66   }
67   connector.reset();
68   auth_context.reset();
69   grpc_server_credentials_release(server_creds);
70   return actual_level;
71 }
72 
GetSecurityLevelForChannel(grpc_local_connect_type connect_type,grpc_endpoint & ep)73 std::string GetSecurityLevelForChannel(grpc_local_connect_type connect_type,
74                                        grpc_endpoint& ep) {
75   grpc_channel_credentials* channel_creds =
76       grpc_local_credentials_create(connect_type);
77   ChannelArgs args;
78   args = args.Set((char*)GRPC_ARG_SERVER_URI, (char*)"unix:");
79   RefCountedPtr<grpc_channel_security_connector> connector =
80       channel_creds->create_security_connector(nullptr, "unix:", &args);
81   tsi_peer peer;
82   CHECK(tsi_construct_peer(0, &peer) == TSI_OK);
83   RefCountedPtr<grpc_auth_context> auth_context;
84   connector->check_peer(peer, &ep, args, &auth_context, nullptr);
85   tsi_peer_destruct(&peer);
86   auto it = grpc_auth_context_find_properties_by_name(
87       auth_context.get(), GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME);
88   const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it);
89   std::string actual_level;
90   if (prop != nullptr) {
91     actual_level = std::string(prop->value, prop->value_length);
92   }
93   connector.reset();
94   auth_context.reset();
95   grpc_channel_credentials_release(channel_creds);
96   return actual_level;
97 }
98 
// A server-side check over a UDS endpoint must report privacy and integrity.
TEST(LocalSecurityConnectorTest, CheckSecurityLevelOfUdsConnectionServer) {
  grpc_endpoint uds_endpoint;
  uds_endpoint.vtable = &kUnixEndpointVtable;
  const std::string level = GetSecurityLevelForServer(UDS, uds_endpoint);
  ASSERT_EQ(level, tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY));
}
106 
// A server-side check over a loopback TCP endpoint must report no security
// when IsLocalConnectorSecureEnabled() is true, and privacy and integrity
// otherwise.
TEST(LocalSecurityConnectorTest, SecurityLevelOfTcpConnectionServer) {
  grpc_endpoint tcp_endpoint;
  tcp_endpoint.vtable = &kTcpEndpointVtable;
  const std::string level =
      GetSecurityLevelForServer(LOCAL_TCP, tcp_endpoint);
  const tsi_security_level expected_level = IsLocalConnectorSecureEnabled()
                                                ? TSI_SECURITY_NONE
                                                : TSI_PRIVACY_AND_INTEGRITY;
  ASSERT_EQ(level, tsi_security_level_to_string(expected_level));
}
116 
// A channel-side check over a UDS endpoint must report privacy and integrity.
TEST(LocalSecurityConnectorTest, CheckSecurityLevelOfUdsConnectionChannel) {
  grpc_endpoint uds_endpoint;
  uds_endpoint.vtable = &kUnixEndpointVtable;
  const std::string level = GetSecurityLevelForChannel(UDS, uds_endpoint);
  ASSERT_EQ(level, tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY));
}
124 
// A channel-side check over a loopback TCP endpoint must report no security
// when IsLocalConnectorSecureEnabled() is true, and privacy and integrity
// otherwise.
TEST(LocalSecurityConnectorTest, SecurityLevelOfTcpConnectionChannel) {
  grpc_endpoint tcp_endpoint;
  tcp_endpoint.vtable = &kTcpEndpointVtable;
  const std::string level =
      GetSecurityLevelForChannel(LOCAL_TCP, tcp_endpoint);
  const tsi_security_level expected_level = IsLocalConnectorSecureEnabled()
                                                ? TSI_SECURITY_NONE
                                                : TSI_PRIVACY_AND_INTEGRITY;
  ASSERT_EQ(level, tsi_security_level_to_string(expected_level));
}
134 
135 }  // namespace
136 }  // namespace testing
137 }  // namespace grpc_core
138 
main(int argc,char ** argv)139 int main(int argc, char** argv) {
140   grpc::testing::TestEnvironment env(&argc, argv);
141   ::testing::InitGoogleTest(&argc, argv);
142   grpc_init();
143   int ret = RUN_ALL_TESTS();
144   grpc_shutdown();
145   return ret;
146 }
147