1 /* Copyright 2019 The Chromium OS Authors. All rights reserved.
2 * Use of this source code is governed by a BSD-style license that can be
3 * found in the LICENSE file.
4 */
5
6 #include <syslog.h>
7
8 #include "cras_iodev_list.h"
9 #include "cras_messages.h"
10 #include "cras_observer.h"
11 #include "cras_rclient.h"
12 #include "cras_rclient_util.h"
13 #include "cras_rstream.h"
14 #include "cras_server_metrics.h"
15 #include "cras_system_state.h"
16 #include "cras_tm.h"
17 #include "cras_types.h"
18 #include "cras_util.h"
19 #include "stream_list.h"
20
rclient_send_message_to_client(const struct cras_rclient * client,const struct cras_client_message * msg,int * fds,unsigned int num_fds)21 int rclient_send_message_to_client(const struct cras_rclient *client,
22 const struct cras_client_message *msg,
23 int *fds, unsigned int num_fds)
24 {
25 return cras_send_with_fds(client->fd, (const void *)msg, msg->length,
26 fds, num_fds);
27 }
28
rclient_destroy(struct cras_rclient * client)29 void rclient_destroy(struct cras_rclient *client)
30 {
31 cras_observer_remove(client->observer);
32 stream_list_rm_all_client_streams(cras_iodev_list_get_stream_list(),
33 client);
34 free(client);
35 }
36
rclient_validate_message_fds(const struct cras_server_message * msg,int * fds,unsigned int num_fds)37 int rclient_validate_message_fds(const struct cras_server_message *msg,
38 int *fds, unsigned int num_fds)
39 {
40 switch (msg->id) {
41 case CRAS_SERVER_CONNECT_STREAM:
42 if (num_fds > 2)
43 goto error;
44 break;
45 case CRAS_SERVER_SET_AEC_DUMP:
46 if (num_fds > 1)
47 goto error;
48 break;
49 default:
50 if (num_fds > 0)
51 goto error;
52 break;
53 }
54
55 return 0;
56
57 error:
58 syslog(LOG_ERR, "Message %d should not have %u fds attached.", msg->id,
59 num_fds);
60 return -EINVAL;
61 }
62
63 static int
rclient_validate_stream_connect_message(const struct cras_rclient * client,const struct cras_connect_message * msg)64 rclient_validate_stream_connect_message(const struct cras_rclient *client,
65 const struct cras_connect_message *msg)
66 {
67 if (!cras_valid_stream_id(msg->stream_id, client->id)) {
68 syslog(LOG_ERR,
69 "stream_connect: invalid stream_id: %x for "
70 "client: %zx.\n",
71 msg->stream_id, client->id);
72 return -EINVAL;
73 }
74
75 int direction = cras_stream_direction_mask(msg->direction);
76 if (direction < 0 || !(client->supported_directions & direction)) {
77 syslog(LOG_ERR,
78 "stream_connect: invalid stream direction: %x for "
79 "client: %zx.\n",
80 msg->direction, client->id);
81 return -EINVAL;
82 }
83
84 if (!cras_validate_client_type(msg->client_type)) {
85 syslog(LOG_ERR,
86 "stream_connect: invalid stream client_type: %x for "
87 "client: %zx.\n",
88 msg->client_type, client->id);
89 }
90 return 0;
91 }
92
/* Checks the descriptors received with a stream connect request.
 *
 * |audio_fd| must always be non-negative. |client_shm_fd| must be
 * non-negative exactly when the client requested client-provided shm, i.e.
 * when |client_shm_size| is non-zero.
 *
 * Returns 0 on success, -EBADF when a required descriptor is missing, and
 * -EINVAL when a shm descriptor is supplied with a zero shm size.
 */
static int rclient_validate_stream_connect_fds(int audio_fd, int client_shm_fd,
					       size_t client_shm_size)
{
	const int wants_client_shm = client_shm_size > 0;
	const int has_client_shm_fd = client_shm_fd >= 0;

	if (audio_fd < 0) {
		syslog(LOG_ERR, "Invalid audio fd in stream connect.\n");
		return -EBADF;
	}

	if (wants_client_shm && !has_client_shm_fd) {
		syslog(LOG_ERR,
		       "client_shm_fd must be valid if client_shm_size > 0.\n");
		return -EBADF;
	}

	if (!wants_client_shm && has_client_shm_fd) {
		syslog(LOG_ERR,
		       "client_shm_fd can be valid only if client_shm_size > 0.\n");
		return -EINVAL;
	}

	return 0;
}
114
rclient_validate_stream_connect_params(const struct cras_rclient * client,const struct cras_connect_message * msg,int audio_fd,int client_shm_fd)115 int rclient_validate_stream_connect_params(
116 const struct cras_rclient *client,
117 const struct cras_connect_message *msg, int audio_fd, int client_shm_fd)
118 {
119 int rc;
120
121 rc = rclient_validate_stream_connect_message(client, msg);
122 if (rc)
123 return rc;
124
125 rc = rclient_validate_stream_connect_fds(audio_fd, client_shm_fd,
126 msg->client_shm_size);
127 if (rc)
128 return rc;
129
130 return 0;
131 }
132
rclient_handle_client_stream_connect(struct cras_rclient * client,const struct cras_connect_message * msg,int aud_fd,int client_shm_fd)133 int rclient_handle_client_stream_connect(struct cras_rclient *client,
134 const struct cras_connect_message *msg,
135 int aud_fd, int client_shm_fd)
136 {
137 struct cras_rstream *stream;
138 struct cras_client_stream_connected stream_connected;
139 struct cras_client_message *reply;
140 struct cras_audio_format remote_fmt;
141 struct cras_rstream_config stream_config;
142 int rc, header_fd, samples_fd;
143 size_t samples_size;
144 int stream_fds[2];
145
146 rc = rclient_validate_stream_connect_params(client, msg, aud_fd,
147 client_shm_fd);
148 remote_fmt = unpack_cras_audio_format(&msg->format);
149 if (rc == 0 && !cras_audio_format_valid(&remote_fmt)) {
150 rc = -EINVAL;
151 }
152 if (rc) {
153 if (client_shm_fd >= 0)
154 close(client_shm_fd);
155 if (aud_fd >= 0)
156 close(aud_fd);
157 goto reply_err;
158 }
159
160 /* When full, getting an error is preferable to blocking. */
161 cras_make_fd_nonblocking(aud_fd);
162
163 stream_config = cras_rstream_config_init_with_message(
164 client, msg, &aud_fd, &client_shm_fd, &remote_fmt);
165 /* Overwrite client_type if client->client_type is set. */
166 if (client->client_type != CRAS_CLIENT_TYPE_UNKNOWN)
167 stream_config.client_type = client->client_type;
168 rc = stream_list_add(cras_iodev_list_get_stream_list(), &stream_config,
169 &stream);
170 if (rc)
171 goto cleanup_config;
172
173 detect_rtc_stream_pair(cras_iodev_list_get_stream_list(), stream);
174
175 /* Tell client about the stream setup. */
176 syslog(LOG_DEBUG, "Send connected for stream %x\n", msg->stream_id);
177
178 // Check that shm size is at most UINT32_MAX for non-shm streams.
179 samples_size = cras_rstream_get_samples_shm_size(stream);
180 if (samples_size > UINT32_MAX && stream_config.client_shm_fd < 0) {
181 syslog(LOG_ERR,
182 "Non client-provided shm stream has samples shm larger "
183 "than uint32_t: %zu",
184 samples_size);
185 if (aud_fd >= 0)
186 close(aud_fd);
187 rc = -EINVAL;
188 goto cleanup_config;
189 }
190 cras_fill_client_stream_connected(&stream_connected, 0, /* No error. */
191 msg->stream_id, &remote_fmt,
192 samples_size,
193 cras_rstream_get_effects(stream));
194 reply = &stream_connected.header;
195
196 rc = cras_rstream_get_shm_fds(stream, &header_fd, &samples_fd);
197 if (rc)
198 goto cleanup_config;
199
200 stream_fds[0] = header_fd;
201 /* If we're using client-provided shm, samples_fd here refers to the
202 * same shm area as client_shm_fd */
203 stream_fds[1] = samples_fd;
204
205 rc = client->ops->send_message_to_client(client, reply, stream_fds, 2);
206 if (rc < 0) {
207 syslog(LOG_ERR, "Failed to send connected messaged\n");
208 stream_list_rm(cras_iodev_list_get_stream_list(),
209 stream->stream_id);
210 goto cleanup_config;
211 }
212
213 /* Cleanup local object explicitly. */
214 cras_rstream_config_cleanup(&stream_config);
215 return 0;
216
217 cleanup_config:
218 cras_rstream_config_cleanup(&stream_config);
219
220 reply_err:
221 /* Send the error code to the client. */
222 cras_fill_client_stream_connected(&stream_connected, rc, msg->stream_id,
223 &remote_fmt, 0, msg->effects);
224 reply = &stream_connected.header;
225 client->ops->send_message_to_client(client, reply, NULL, 0);
226
227 return rc;
228 }
229
230 /* Handles messages from the client requesting that a stream be removed from the
231 * server. */
rclient_handle_client_stream_disconnect(struct cras_rclient * client,const struct cras_disconnect_stream_message * msg)232 int rclient_handle_client_stream_disconnect(
233 struct cras_rclient *client,
234 const struct cras_disconnect_stream_message *msg)
235 {
236 if (!cras_valid_stream_id(msg->stream_id, client->id)) {
237 syslog(LOG_ERR,
238 "stream_disconnect: invalid stream_id: %x for "
239 "client: %zx.\n",
240 msg->stream_id, client->id);
241 return -EINVAL;
242 }
243 return stream_list_rm(cras_iodev_list_get_stream_list(),
244 msg->stream_id);
245 }
246
247 /* Creates a client structure and sends a message back informing the client that
248 * the connection has succeeded. */
rclient_generic_create(int fd,size_t id,const struct cras_rclient_ops * ops,int supported_directions)249 struct cras_rclient *rclient_generic_create(int fd, size_t id,
250 const struct cras_rclient_ops *ops,
251 int supported_directions)
252 {
253 struct cras_rclient *client;
254 struct cras_client_connected msg;
255 int state_fd;
256
257 client = (struct cras_rclient *)calloc(1, sizeof(struct cras_rclient));
258 if (!client)
259 return NULL;
260
261 client->fd = fd;
262 client->id = id;
263 client->ops = ops;
264 client->supported_directions = supported_directions;
265
266 cras_fill_client_connected(&msg, client->id);
267 state_fd = cras_sys_state_shm_fd();
268 client->ops->send_message_to_client(client, &msg.header, &state_fd, 1);
269
270 return client;
271 }
272
273 /* A generic entry point for handling a message from the client. Called from
274 * the main server context. */
rclient_handle_message_from_client(struct cras_rclient * client,const struct cras_server_message * msg,int * fds,unsigned int num_fds)275 int rclient_handle_message_from_client(struct cras_rclient *client,
276 const struct cras_server_message *msg,
277 int *fds, unsigned int num_fds)
278 {
279 int rc = 0;
280 assert(client && msg);
281
282 rc = rclient_validate_message_fds(msg, fds, num_fds);
283 if (rc < 0) {
284 for (int i = 0; i < (int)num_fds; i++)
285 if (fds[i] >= 0)
286 close(fds[i]);
287 return rc;
288 }
289 int fd = num_fds > 0 ? fds[0] : -1;
290
291 switch (msg->id) {
292 case CRAS_SERVER_CONNECT_STREAM: {
293 int client_shm_fd = num_fds > 1 ? fds[1] : -1;
294 if (MSG_LEN_VALID(msg, struct cras_connect_message)) {
295 rclient_handle_client_stream_connect(
296 client,
297 (const struct cras_connect_message *)msg, fd,
298 client_shm_fd);
299 } else {
300 return -EINVAL;
301 }
302 break;
303 }
304 case CRAS_SERVER_DISCONNECT_STREAM:
305 if (!MSG_LEN_VALID(msg, struct cras_disconnect_stream_message))
306 return -EINVAL;
307 rclient_handle_client_stream_disconnect(
308 client,
309 (const struct cras_disconnect_stream_message *)msg);
310 break;
311 default:
312 break;
313 }
314
315 return rc;
316 }
317