• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "SocketProfilingConnection.hpp"
7 
8 #include "common/include/SocketConnectionException.hpp"
9 
10 #include <cerrno>
11 #include <fcntl.h>
12 #include <string>
13 
14 
15 namespace armnn
16 {
17 namespace profiling
18 {
19 
SocketProfilingConnection()20 SocketProfilingConnection::SocketProfilingConnection()
21 {
22     arm::pipe::Initialize();
23     memset(m_Socket, 0, sizeof(m_Socket));
24     // Note: we're using Linux specific SOCK_CLOEXEC flag.
25     m_Socket[0].fd = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
26     if (m_Socket[0].fd == -1)
27     {
28         throw arm::pipe::SocketConnectionException(
29             std::string("SocketProfilingConnection: Socket construction failed: ")  + strerror(errno),
30             m_Socket[0].fd,
31             errno);
32     }
33 
34     // Connect to the named unix domain socket.
35     sockaddr_un server{};
36     memset(&server, 0, sizeof(sockaddr_un));
37     // As m_GatorNamespace begins with a null character we need to ignore that when getting its length.
38     memcpy(server.sun_path, m_GatorNamespace, strlen(m_GatorNamespace + 1) + 1);
39     server.sun_family = AF_UNIX;
40     if (0 != connect(m_Socket[0].fd, reinterpret_cast<const sockaddr*>(&server), sizeof(sockaddr_un)))
41     {
42         Close();
43         throw arm::pipe::SocketConnectionException(
44             std::string("SocketProfilingConnection: Cannot connect to stream socket: ")  + strerror(errno),
45             m_Socket[0].fd,
46             errno);
47     }
48 
49     // Our socket will only be interested in polling reads.
50     m_Socket[0].events = POLLIN;
51 
52     // Make the socket non blocking.
53     if (!arm::pipe::SetNonBlocking(m_Socket[0].fd))
54     {
55         Close();
56         throw arm::pipe::SocketConnectionException(
57             std::string("SocketProfilingConnection: Failed to set socket as non blocking: ")  + strerror(errno),
58             m_Socket[0].fd,
59             errno);
60     }
61 }
62 
IsOpen() const63 bool SocketProfilingConnection::IsOpen() const
64 {
65     return m_Socket[0].fd > 0;
66 }
67 
Close()68 void SocketProfilingConnection::Close()
69 {
70     if (arm::pipe::Close(m_Socket[0].fd) != 0)
71     {
72         throw arm::pipe::SocketConnectionException(
73             std::string("SocketProfilingConnection: Cannot close stream socket: ")  + strerror(errno),
74             m_Socket[0].fd,
75             errno);
76     }
77 
78     memset(m_Socket, 0, sizeof(m_Socket));
79 }
80 
WritePacket(const unsigned char * buffer,uint32_t length)81 bool SocketProfilingConnection::WritePacket(const unsigned char* buffer, uint32_t length)
82 {
83     if (buffer == nullptr || length == 0)
84     {
85         return false;
86     }
87 
88     return arm::pipe::Write(m_Socket[0].fd, buffer, length) != -1;
89 }
90 
ReadPacket(uint32_t timeout)91 arm::pipe::Packet SocketProfilingConnection::ReadPacket(uint32_t timeout)
92 {
93     // Is there currently at least a header worth of data waiting to be read?
94     int bytes_available = 0;
95     arm::pipe::Ioctl(m_Socket[0].fd, FIONREAD, &bytes_available);
96     if (bytes_available >= 8)
97     {
98         // Yes there is. Read it:
99         return ReceivePacket();
100     }
101 
102     // Poll for data on the socket or until timeout occurs
103     int pollResult = arm::pipe::Poll(&m_Socket[0], 1, static_cast<int>(timeout));
104 
105     switch (pollResult)
106     {
107     case -1: // Error
108         throw arm::pipe::SocketConnectionException(
109             std::string("SocketProfilingConnection: Error occured while reading from socket: ") + strerror(errno),
110             m_Socket[0].fd,
111             errno);
112 
113     case 0: // Timeout
114         throw arm::pipe::TimeoutException("SocketProfilingConnection: Timeout while reading from socket");
115 
116     default: // Normal poll return but it could still contain an error signal
117         // Check if the socket reported an error
118         if (m_Socket[0].revents & (POLLNVAL | POLLERR | POLLHUP))
119         {
120             if (m_Socket[0].revents == POLLNVAL)
121             {
122                 // This is an unrecoverable error.
123                 Close();
124                 throw arm::pipe::SocketConnectionException(
125                     std::string("SocketProfilingConnection: Error occured while polling receiving socket: POLLNVAL."),
126                     m_Socket[0].fd);
127             }
128             if (m_Socket[0].revents == POLLERR)
129             {
130                 throw arm::pipe::SocketConnectionException(
131                     std::string(
132                         "SocketProfilingConnection: Error occured while polling receiving socket: POLLERR: ")
133                         + strerror(errno),
134                     m_Socket[0].fd,
135                     errno);
136             }
137             if (m_Socket[0].revents == POLLHUP)
138             {
139                 // This is an unrecoverable error.
140                 Close();
141                 throw arm::pipe::SocketConnectionException(
142                     std::string("SocketProfilingConnection: Connection closed by remote client: POLLHUP."),
143                     m_Socket[0].fd);
144             }
145         }
146 
147         // Check if there is data to read
148         if (!(m_Socket[0].revents & (POLLIN)))
149         {
150             // This is a corner case. The socket as been woken up but not with any data.
151             // We'll throw a timeout exception to loop around again.
152             throw armnn::TimeoutException(
153                 "SocketProfilingConnection: File descriptor was polled but no data was available to receive.");
154         }
155 
156         return ReceivePacket();
157     }
158 }
159 
ReceivePacket()160 arm::pipe::Packet SocketProfilingConnection::ReceivePacket()
161 {
162     char header[8] = {};
163     long receiveResult = arm::pipe::Read(m_Socket[0].fd, &header, sizeof(header));
164     // We expect 8 as the result here. 0 means EOF, socket is closed. -1 means there been some other kind of error.
165     switch( receiveResult )
166     {
167         case 0:
168             // Socket has closed.
169             Close();
170             throw arm::pipe::SocketConnectionException(
171                 std::string("SocketProfilingConnection: Remote socket has closed the connection."),
172                 m_Socket[0].fd);
173         case -1:
174             // There's been a socket error. We will presume it's unrecoverable.
175             Close();
176             throw arm::pipe::SocketConnectionException(
177                 std::string("SocketProfilingConnection: Error occured while reading the packet: ") + strerror(errno),
178                 m_Socket[0].fd,
179                 errno);
180         default:
181             if (receiveResult < 8)
182             {
183                  throw arm::pipe::SocketConnectionException(
184                      std::string(
185                          "SocketProfilingConnection: The received packet did not contains a valid PIPE header."),
186                      m_Socket[0].fd);
187             }
188             break;
189     }
190 
191     // stream_metadata_identifier is the first 4 bytes
192     uint32_t metadataIdentifier = 0;
193     std::memcpy(&metadataIdentifier, header, sizeof(metadataIdentifier));
194 
195     // data_length is the next 4 bytes
196     uint32_t dataLength = 0;
197     std::memcpy(&dataLength, header + 4u, sizeof(dataLength));
198 
199     std::unique_ptr<unsigned char[]> packetData;
200     if (dataLength > 0)
201     {
202         packetData = std::make_unique<unsigned char[]>(dataLength);
203         long receivedLength = arm::pipe::Read(m_Socket[0].fd, packetData.get(), dataLength);
204         if (receivedLength < 0)
205         {
206             throw arm::pipe::SocketConnectionException(
207                 std::string("SocketProfilingConnection: Error occured while reading the packet: ")  + strerror(errno),
208                 m_Socket[0].fd,
209                 errno);
210         }
211         if (dataLength != static_cast<uint32_t>(receivedLength))
212         {
213             // What do we do here if we can't read in a full packet?
214             throw arm::pipe::SocketConnectionException(
215                 std::string("SocketProfilingConnection: Invalid PIPE packet."),
216                 m_Socket[0].fd);
217         }
218     }
219 
220     return arm::pipe::Packet(metadataIdentifier, dataLength, packetData);
221 }
222 
223 } // namespace profiling
224 } // namespace armnn
225