1 //
2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "SocketProfilingConnection.hpp"
7
8 #include "common/include/SocketConnectionException.hpp"
9
10 #include <cerrno>
11 #include <fcntl.h>
12 #include <string>
13
14
15 namespace armnn
16 {
17 namespace profiling
18 {
19
SocketProfilingConnection()20 SocketProfilingConnection::SocketProfilingConnection()
21 {
22 arm::pipe::Initialize();
23 memset(m_Socket, 0, sizeof(m_Socket));
24 // Note: we're using Linux specific SOCK_CLOEXEC flag.
25 m_Socket[0].fd = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
26 if (m_Socket[0].fd == -1)
27 {
28 throw arm::pipe::SocketConnectionException(
29 std::string("SocketProfilingConnection: Socket construction failed: ") + strerror(errno),
30 m_Socket[0].fd,
31 errno);
32 }
33
34 // Connect to the named unix domain socket.
35 sockaddr_un server{};
36 memset(&server, 0, sizeof(sockaddr_un));
37 // As m_GatorNamespace begins with a null character we need to ignore that when getting its length.
38 memcpy(server.sun_path, m_GatorNamespace, strlen(m_GatorNamespace + 1) + 1);
39 server.sun_family = AF_UNIX;
40 if (0 != connect(m_Socket[0].fd, reinterpret_cast<const sockaddr*>(&server), sizeof(sockaddr_un)))
41 {
42 Close();
43 throw arm::pipe::SocketConnectionException(
44 std::string("SocketProfilingConnection: Cannot connect to stream socket: ") + strerror(errno),
45 m_Socket[0].fd,
46 errno);
47 }
48
49 // Our socket will only be interested in polling reads.
50 m_Socket[0].events = POLLIN;
51
52 // Make the socket non blocking.
53 if (!arm::pipe::SetNonBlocking(m_Socket[0].fd))
54 {
55 Close();
56 throw arm::pipe::SocketConnectionException(
57 std::string("SocketProfilingConnection: Failed to set socket as non blocking: ") + strerror(errno),
58 m_Socket[0].fd,
59 errno);
60 }
61 }
62
IsOpen() const63 bool SocketProfilingConnection::IsOpen() const
64 {
65 return m_Socket[0].fd > 0;
66 }
67
Close()68 void SocketProfilingConnection::Close()
69 {
70 if (arm::pipe::Close(m_Socket[0].fd) != 0)
71 {
72 throw arm::pipe::SocketConnectionException(
73 std::string("SocketProfilingConnection: Cannot close stream socket: ") + strerror(errno),
74 m_Socket[0].fd,
75 errno);
76 }
77
78 memset(m_Socket, 0, sizeof(m_Socket));
79 }
80
WritePacket(const unsigned char * buffer,uint32_t length)81 bool SocketProfilingConnection::WritePacket(const unsigned char* buffer, uint32_t length)
82 {
83 if (buffer == nullptr || length == 0)
84 {
85 return false;
86 }
87
88 return arm::pipe::Write(m_Socket[0].fd, buffer, length) != -1;
89 }
90
ReadPacket(uint32_t timeout)91 arm::pipe::Packet SocketProfilingConnection::ReadPacket(uint32_t timeout)
92 {
93 // Is there currently at least a header worth of data waiting to be read?
94 int bytes_available = 0;
95 arm::pipe::Ioctl(m_Socket[0].fd, FIONREAD, &bytes_available);
96 if (bytes_available >= 8)
97 {
98 // Yes there is. Read it:
99 return ReceivePacket();
100 }
101
102 // Poll for data on the socket or until timeout occurs
103 int pollResult = arm::pipe::Poll(&m_Socket[0], 1, static_cast<int>(timeout));
104
105 switch (pollResult)
106 {
107 case -1: // Error
108 throw arm::pipe::SocketConnectionException(
109 std::string("SocketProfilingConnection: Error occured while reading from socket: ") + strerror(errno),
110 m_Socket[0].fd,
111 errno);
112
113 case 0: // Timeout
114 throw arm::pipe::TimeoutException("SocketProfilingConnection: Timeout while reading from socket");
115
116 default: // Normal poll return but it could still contain an error signal
117 // Check if the socket reported an error
118 if (m_Socket[0].revents & (POLLNVAL | POLLERR | POLLHUP))
119 {
120 if (m_Socket[0].revents == POLLNVAL)
121 {
122 // This is an unrecoverable error.
123 Close();
124 throw arm::pipe::SocketConnectionException(
125 std::string("SocketProfilingConnection: Error occured while polling receiving socket: POLLNVAL."),
126 m_Socket[0].fd);
127 }
128 if (m_Socket[0].revents == POLLERR)
129 {
130 throw arm::pipe::SocketConnectionException(
131 std::string(
132 "SocketProfilingConnection: Error occured while polling receiving socket: POLLERR: ")
133 + strerror(errno),
134 m_Socket[0].fd,
135 errno);
136 }
137 if (m_Socket[0].revents == POLLHUP)
138 {
139 // This is an unrecoverable error.
140 Close();
141 throw arm::pipe::SocketConnectionException(
142 std::string("SocketProfilingConnection: Connection closed by remote client: POLLHUP."),
143 m_Socket[0].fd);
144 }
145 }
146
147 // Check if there is data to read
148 if (!(m_Socket[0].revents & (POLLIN)))
149 {
150 // This is a corner case. The socket as been woken up but not with any data.
151 // We'll throw a timeout exception to loop around again.
152 throw armnn::TimeoutException(
153 "SocketProfilingConnection: File descriptor was polled but no data was available to receive.");
154 }
155
156 return ReceivePacket();
157 }
158 }
159
ReceivePacket()160 arm::pipe::Packet SocketProfilingConnection::ReceivePacket()
161 {
162 char header[8] = {};
163 long receiveResult = arm::pipe::Read(m_Socket[0].fd, &header, sizeof(header));
164 // We expect 8 as the result here. 0 means EOF, socket is closed. -1 means there been some other kind of error.
165 switch( receiveResult )
166 {
167 case 0:
168 // Socket has closed.
169 Close();
170 throw arm::pipe::SocketConnectionException(
171 std::string("SocketProfilingConnection: Remote socket has closed the connection."),
172 m_Socket[0].fd);
173 case -1:
174 // There's been a socket error. We will presume it's unrecoverable.
175 Close();
176 throw arm::pipe::SocketConnectionException(
177 std::string("SocketProfilingConnection: Error occured while reading the packet: ") + strerror(errno),
178 m_Socket[0].fd,
179 errno);
180 default:
181 if (receiveResult < 8)
182 {
183 throw arm::pipe::SocketConnectionException(
184 std::string(
185 "SocketProfilingConnection: The received packet did not contains a valid PIPE header."),
186 m_Socket[0].fd);
187 }
188 break;
189 }
190
191 // stream_metadata_identifier is the first 4 bytes
192 uint32_t metadataIdentifier = 0;
193 std::memcpy(&metadataIdentifier, header, sizeof(metadataIdentifier));
194
195 // data_length is the next 4 bytes
196 uint32_t dataLength = 0;
197 std::memcpy(&dataLength, header + 4u, sizeof(dataLength));
198
199 std::unique_ptr<unsigned char[]> packetData;
200 if (dataLength > 0)
201 {
202 packetData = std::make_unique<unsigned char[]>(dataLength);
203 long receivedLength = arm::pipe::Read(m_Socket[0].fd, packetData.get(), dataLength);
204 if (receivedLength < 0)
205 {
206 throw arm::pipe::SocketConnectionException(
207 std::string("SocketProfilingConnection: Error occured while reading the packet: ") + strerror(errno),
208 m_Socket[0].fd,
209 errno);
210 }
211 if (dataLength != static_cast<uint32_t>(receivedLength))
212 {
213 // What do we do here if we can't read in a full packet?
214 throw arm::pipe::SocketConnectionException(
215 std::string("SocketProfilingConnection: Invalid PIPE packet."),
216 m_Socket[0].fd);
217 }
218 }
219
220 return arm::pipe::Packet(metadataIdentifier, dataLength, packetData);
221 }
222
223 } // namespace profiling
224 } // namespace armnn
225