cc2e0a05f4e5b865c59387cd6aae6c2512335d83
[metze/samba/wip.git] / lib / util / tfork.c
1 /*
2    fork on steroids to avoid SIGCHLD and waitpid
3
4    Copyright (C) Stefan Metzmacher 2010
5    Copyright (C) Ralph Boehme 2017
6
7    This program is free software; you can redistribute it and/or modify
8    it under the terms of the GNU General Public License as published by
9    the Free Software Foundation; either version 3 of the License, or
10    (at your option) any later version.
11
12    This program is distributed in the hope that it will be useful,
13    but WITHOUT ANY WARRANTY; without even the implied warranty of
14    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15    GNU General Public License for more details.
16
17    You should have received a copy of the GNU General Public License
18    along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 */
20
21 #include "replace.h"
22 #include "system/wait.h"
23 #include "system/filesys.h"
24 #include "system/network.h"
25 #include "lib/util/samba_util.h"
26 #include "lib/util/sys_rw.h"
27 #include "lib/util/tfork.h"
28 #include "lib/util/debug.h"
29
30 #ifdef HAVE_PTHREAD
31 #include <pthread.h>
32 #endif
33
34 #ifdef NDEBUG
35 #undef NDEBUG
36 #endif
37 #include <assert.h>
38
39 /*
 * This is what the process hierarchy looks like:
41  *
42  *   +----------+
43  *   |  caller  |
44  *   +----------+
45  *         |
46  *       fork
47  *         |
48  *         v
49  *   +----------+
50  *   |  waiter  |
51  *   +----------+
52  *         |
53  *       fork
54  *         |
55  *         v
56  *   +----------+
57  *   |  worker  |
58  *   +----------+
59  */
60
/*
 * The resulting (private) state per tfork_create() call, returned as an opaque
 * handle to the caller.
 */
struct tfork {
	/*
	 * This is returned to the caller with tfork_event_fd()
	 */
	int event_fd;

	/*
	 * This is used in the caller by tfork_status() to read the worker exit
	 * status and to tell the waiter to exit by closing the fd.
	 */
	int status_fd;

	/* pid of the waiter process, -1 if there is none (e.g. in the worker) */
	pid_t waiter_pid;
	/* pid of the worker process, 0 in the handle returned to the worker */
	pid_t worker_pid;
};
80
/*
 * Internal per-thread state maintained while inside tfork.
 *
 * It is allocated with anonymous_shared_allocate() (see tfork_global_get())
 * because the waiter process writes worker_pid and waiter_errno after fork()
 * and the caller reads them afterwards.
 */
struct tfork_state {
	/* set by the caller after forking the waiter, -1 before */
	pid_t waiter_pid;
	/* errno of a failure inside the waiter, ECANCELED if none was recorded */
	int waiter_errno;

	/* set by the waiter after forking the worker, -1 before */
	pid_t worker_pid;
};
90
91 /*
92  * A global state that synchronizes access to handling SIGCHLD and waiting for
93  * childs.
94  */
95 struct tfork_signal_state {
96         bool available;
97
98 #ifdef HAVE_PTHREAD
99         pthread_cond_t cond;
100         pthread_mutex_t mutex;
101 #endif
102
103         /*
104          * pid of the waiter child. This points at waiter_pid in either struct
105          * tfork or struct tfork_state, depending on who called
106          * tfork_install_sigchld_handler().
107          *
108          * When tfork_install_sigchld_handler() is called the waiter_pid is
109          * still -1 and only set later after fork(), that's why this is must be
110          * a pointer. The signal handler checks this.
111          */
112         pid_t *pid;
113
114         struct sigaction oldact;
115         sigset_t oldset;
116 };
117
118 static struct tfork_signal_state signal_state;
119
120 #ifdef HAVE_PTHREAD
121 static pthread_once_t tfork_global_is_initialized = PTHREAD_ONCE_INIT;
122 static pthread_key_t tfork_global_key;
123 #else
124 static struct tfork_state *global_state;
125 #endif
126
127 static void tfork_sigchld_handler(int signum, siginfo_t *si, void *p);
128
#ifdef HAVE_PTHREAD
/*
 * Destructor registered for tfork_global_key with pthread_key_create():
 * hands the per-thread tfork_state back to anonymous_shared_free().
 */
static void tfork_global_destructor(void *state)
{
	struct tfork_state *tstate = state;

	anonymous_shared_free(tstate);
}
#endif
135
/*
 * Wait until no other thread owns SIGCHLD handling and claim it.
 *
 * Returns 0 on success or a pthread error code. Without pthreads this is a
 * no-op returning 0.
 */
static int tfork_acquire_sighandling(void)
{
	int ret = 0;

#ifdef HAVE_PTHREAD
	ret = pthread_mutex_lock(&signal_state.mutex);
	if (ret != 0) {
		return ret;
	}

	while (!signal_state.available) {
		ret = pthread_cond_wait(&signal_state.cond,
					&signal_state.mutex);
		if (ret != 0) {
			/*
			 * Don't return with the mutex held: that would
			 * deadlock every later tfork call and the atfork
			 * prepare handler. This mirrors the error path in
			 * tfork_release_sighandling().
			 */
			pthread_mutex_unlock(&signal_state.mutex);
			return ret;
		}
	}

	signal_state.available = false;

	ret = pthread_mutex_unlock(&signal_state.mutex);
	if (ret != 0) {
		return ret;
	}
#endif

	return ret;
}
164
/*
 * Give up ownership of SIGCHLD handling and wake one thread waiting in
 * tfork_acquire_sighandling().
 *
 * Returns 0 on success or a pthread error code. Without pthreads this is a
 * no-op returning 0.
 */
static int tfork_release_sighandling(void)
{
#ifdef HAVE_PTHREAD
	int rc;

	rc = pthread_mutex_lock(&signal_state.mutex);
	if (rc != 0) {
		return rc;
	}

	signal_state.available = true;

	rc = pthread_cond_signal(&signal_state.cond);
	if (rc != 0) {
		pthread_mutex_unlock(&signal_state.mutex);
		return rc;
	}

	return pthread_mutex_unlock(&signal_state.mutex);
#else
	return 0;
#endif
}
191
#ifdef HAVE_PTHREAD
/*
 * pthread_atfork() prepare handler: hold signal_state.mutex across fork() so
 * the child does not inherit it in an inconsistent state.
 */
static void tfork_atfork_prepare(void)
{
	int rc = pthread_mutex_lock(&signal_state.mutex);

	assert(rc == 0);
}

/*
 * pthread_atfork() parent handler: drop the mutex taken in the prepare
 * handler.
 */
static void tfork_atfork_parent(void)
{
	int rc = pthread_mutex_unlock(&signal_state.mutex);

	assert(rc == 0);
}
#endif
209
210 static void tfork_atfork_child(void)
211 {
212         int ret;
213
214 #ifdef HAVE_PTHREAD
215         ret = pthread_mutex_unlock(&signal_state.mutex);
216         assert(ret == 0);
217
218         ret = pthread_key_delete(tfork_global_key);
219         assert(ret == 0);
220
221         ret = pthread_key_create(&tfork_global_key, tfork_global_destructor);
222         assert(ret == 0);
223
224         /*
225          * There's no way to destroy a condition variable if there are waiters,
226          * pthread_cond_destroy() will return EBUSY. Just zero out memory and
227          * then initialize again. This is not backed by POSIX but should be ok.
228          */
229         ZERO_STRUCT(signal_state.cond);
230         ret = pthread_cond_init(&signal_state.cond, NULL);
231         assert(ret == 0);
232 #endif
233
234         if (signal_state.pid != NULL) {
235
236                 ret = sigaction(SIGCHLD, &signal_state.oldact, NULL);
237                 assert(ret == 0);
238
239 #ifdef HAVE_PTHREAD
240                 ret = pthread_sigmask(SIG_SETMASK, &signal_state.oldset, NULL);
241 #else
242                 ret = sigprocmask(SIG_SETMASK, &signal_state.oldset, NULL);
243                 assert(ret == 0);
244 #endif
245
246                 signal_state.pid = NULL;
247         }
248
249         signal_state.available = true;
250 }
251
252 static void tfork_global_initialize(void)
253 {
254 #ifdef HAVE_PTHREAD
255         int ret;
256
257         pthread_atfork(tfork_atfork_prepare,
258                        tfork_atfork_parent,
259                        tfork_atfork_child);
260
261         ret = pthread_key_create(&tfork_global_key, tfork_global_destructor);
262         assert(ret == 0);
263
264         ret = pthread_mutex_init(&signal_state.mutex, NULL);
265         assert(ret == 0);
266
267         ret = pthread_cond_init(&signal_state.cond, NULL);
268         assert(ret == 0);
269 #endif
270
271         signal_state.available = true;
272 }
273
274 static struct tfork_state *tfork_global_get(void)
275 {
276         struct tfork_state *state = NULL;
277 #ifdef HAVE_PTHREAD
278         int ret;
279 #endif
280
281 #ifdef HAVE_PTHREAD
282         state = (struct tfork_state *)pthread_getspecific(tfork_global_key);
283 #else
284         state = global_state;
285 #endif
286         if (state != NULL) {
287                 return state;
288         }
289
290         state = (struct tfork_state *)anonymous_shared_allocate(
291                 sizeof(struct tfork_state));
292         if (state == NULL) {
293                 return NULL;
294         }
295
296 #ifdef HAVE_PTHREAD
297         ret = pthread_setspecific(tfork_global_key, state);
298         if (ret != 0) {
299                 anonymous_shared_free(state);
300                 return NULL;
301         }
302 #endif
303         return state;
304 }
305
306 static void tfork_global_free(void)
307 {
308         struct tfork_state *state = NULL;
309 #ifdef HAVE_PTHREAD
310         int ret;
311 #endif
312
313 #ifdef HAVE_PTHREAD
314         state = (struct tfork_state *)pthread_getspecific(tfork_global_key);
315 #else
316         state = global_state;
317 #endif
318         if (state == NULL) {
319                 return;
320         }
321
322 #ifdef HAVE_PTHREAD
323         ret = pthread_setspecific(tfork_global_key, NULL);
324         if (ret != 0) {
325                 return;
326         }
327 #endif
328         anonymous_shared_free(state);
329 }
330
331 /**
332  * Only one thread at a time is allowed to handle SIGCHLD signals
333  **/
334 static int tfork_install_sigchld_handler(pid_t *pid)
335 {
336         int ret;
337         struct sigaction act;
338         sigset_t set;
339
340         ret = tfork_acquire_sighandling();
341         if (ret != 0) {
342                 return -1;
343         }
344
345         assert(signal_state.pid == NULL);
346         signal_state.pid = pid;
347
348         act = (struct sigaction) {
349                 .sa_sigaction = tfork_sigchld_handler,
350                 .sa_flags = SA_SIGINFO,
351         };
352
353         ret = sigaction(SIGCHLD, &act, &signal_state.oldact);
354         if (ret != 0) {
355                 return -1;
356         }
357
358         sigemptyset(&set);
359         sigaddset(&set, SIGCHLD);
360 #ifdef HAVE_PTHREAD
361         ret = pthread_sigmask(SIG_UNBLOCK, &set, &signal_state.oldset);
362 #else
363         ret = sigprocmask(SIG_UNBLOCK, &set, &signal_state.oldset);
364 #endif
365         if (ret != 0) {
366                 return -1;
367         }
368
369         return 0;
370 }
371
372 static int tfork_uninstall_sigchld_handler(void)
373 {
374         int ret;
375
376         signal_state.pid = NULL;
377
378         ret = sigaction(SIGCHLD, &signal_state.oldact, NULL);
379         if (ret != 0) {
380                 return -1;
381         }
382
383 #ifdef HAVE_PTHREAD
384         ret = pthread_sigmask(SIG_SETMASK, &signal_state.oldset, NULL);
385 #else
386         ret = sigprocmask(SIG_SETMASK, &signal_state.oldset, NULL);
387 #endif
388         if (ret != 0) {
389                 return -1;
390         }
391
392         ret = tfork_release_sighandling();
393         if (ret != 0) {
394                 return -1;
395         }
396
397         return 0;
398 }
399
400 static void tfork_sigchld_handler(int signum, siginfo_t *si, void *p)
401 {
402         if ((signal_state.pid != NULL) &&
403             (*signal_state.pid != -1) &&
404             (si->si_pid == *signal_state.pid))
405         {
406                 return;
407         }
408
409         /*
410          * Not our child, forward to old handler
411          */
412         if (signal_state.oldact.sa_flags & SA_SIGINFO) {
413                 signal_state.oldact.sa_sigaction(signum, si, p);
414                 return;
415         }
416
417         if (signal_state.oldact.sa_handler == SIG_IGN) {
418                 return;
419         }
420         if (signal_state.oldact.sa_handler == SIG_DFL) {
421                 return;
422         }
423         signal_state.oldact.sa_handler(signum);
424 }
425
426 static pid_t tfork_start_waiter_and_worker(struct tfork_state *state,
427                                            int *_event_fd,
428                                            int *_status_fd)
429 {
430         int p[2];
431         int status_sp_caller_fd = -1;
432         int status_sp_waiter_fd = -1;
433         int event_pipe_caller_fd = -1;
434         int event_pipe_waiter_fd = -1;
435         int ready_pipe_caller_fd = -1;
436         int ready_pipe_worker_fd = -1;
437         ssize_t nwritten;
438         ssize_t nread;
439         pid_t pid;
440         int status;
441         int fd;
442         char c;
443         int ret;
444
445         *_event_fd = -1;
446         *_status_fd = -1;
447
448         if (state == NULL) {
449                 return -1;
450         }
451
452         ret = socketpair(AF_UNIX, SOCK_STREAM, 0, p);
453         if (ret != 0) {
454                 return -1;
455         }
456         set_close_on_exec(p[0]);
457         set_close_on_exec(p[1]);
458         status_sp_caller_fd = p[0];
459         status_sp_waiter_fd = p[1];
460
461         ret = pipe(p);
462         if (ret != 0) {
463                 close(status_sp_caller_fd);
464                 close(status_sp_waiter_fd);
465                 return -1;
466         }
467         set_close_on_exec(p[0]);
468         set_close_on_exec(p[1]);
469         event_pipe_caller_fd = p[0];
470         event_pipe_waiter_fd = p[1];
471
472
473         ret = pipe(p);
474         if (ret != 0) {
475                 close(status_sp_caller_fd);
476                 close(status_sp_waiter_fd);
477                 close(event_pipe_caller_fd);
478                 close(event_pipe_waiter_fd);
479                 return -1;
480         }
481         set_close_on_exec(p[0]);
482         set_close_on_exec(p[1]);
483         ready_pipe_worker_fd = p[0];
484         ready_pipe_caller_fd = p[1];
485
486         pid = fork();
487         if (pid == -1) {
488                 close(status_sp_caller_fd);
489                 close(status_sp_waiter_fd);
490                 close(event_pipe_caller_fd);
491                 close(event_pipe_waiter_fd);
492                 close(ready_pipe_caller_fd);
493                 close(ready_pipe_worker_fd);
494                 return -1;
495         }
496         if (pid != 0) {
497                 /* The caller */
498
499                 state->waiter_pid = pid;
500
501                 close(status_sp_waiter_fd);
502                 close(event_pipe_waiter_fd);
503                 close(ready_pipe_worker_fd);
504
505                 set_blocking(event_pipe_caller_fd, false);
506
507                 /*
508                  * wait for the waiter to get ready.
509                  */
510                 nread = sys_read(status_sp_caller_fd, &c, sizeof(char));
511                 if (nread != sizeof(char)) {
512                         return -1;
513                 }
514
515                 /*
516                  * Notify the worker to start.
517                  */
518                 nwritten = sys_write(ready_pipe_caller_fd,
519                                      &(char){0}, sizeof(char));
520                 if (nwritten != sizeof(char)) {
521                         close(ready_pipe_caller_fd);
522                         return -1;
523                 }
524                 close(ready_pipe_caller_fd);
525
526                 *_event_fd = event_pipe_caller_fd;
527                 *_status_fd = status_sp_caller_fd;
528
529                 return pid;
530         }
531
532 #ifndef HAVE_PTHREAD
533         /* cleanup sigchld_handler */
534         tfork_atfork_child();
535 #endif
536
537         /*
538          * The "waiter" child.
539          */
540         CatchSignal(SIGCHLD, SIG_DFL);
541
542         close(status_sp_caller_fd);
543         close(event_pipe_caller_fd);
544         close(ready_pipe_caller_fd);
545
546         pid = fork();
547         if (pid == -1) {
548                 state->waiter_errno = errno;
549                 _exit(0);
550         }
551         if (pid == 0) {
552                 /*
553                  * The worker child.
554                  */
555
556                 close(status_sp_waiter_fd);
557                 close(event_pipe_waiter_fd);
558
559                 /*
560                  * Wait for the caller to give us a go!
561                  */
562                 nread = sys_read(ready_pipe_worker_fd, &c, sizeof(char));
563                 if (nread != sizeof(char)) {
564                         _exit(1);
565                 }
566                 close(ready_pipe_worker_fd);
567
568                 return 0;
569         }
570         state->worker_pid = pid;
571
572         close(ready_pipe_worker_fd);
573
574         /*
575          * We're going to stay around until child2 exits, so lets close all fds
576          * other then the pipe fd we may have inherited from the caller.
577          *
578          * Dup event_sp_waiter_fd and status_sp_waiter_fd onto fds 0 and 1 so we
579          * can then call closefrom(2).
580          */
581         if (event_pipe_waiter_fd > 0) {
582                 int dup_fd = 0;
583
584                 if (status_sp_waiter_fd == 0) {
585                         dup_fd = 1;
586                 }
587
588                 do {
589                         fd = dup2(event_pipe_waiter_fd, dup_fd);
590                 } while ((fd == -1) && (errno == EINTR));
591                 if (fd == -1) {
592                         state->waiter_errno = errno;
593                         kill(state->worker_pid, SIGKILL);
594                         state->worker_pid = -1;
595                         _exit(1);
596                 }
597                 event_pipe_waiter_fd = fd;
598         }
599
600         if (status_sp_waiter_fd > 1) {
601                 do {
602                         fd = dup2(status_sp_waiter_fd, 1);
603                 } while ((fd == -1) && (errno == EINTR));
604                 if (fd == -1) {
605                         state->waiter_errno = errno;
606                         kill(state->worker_pid, SIGKILL);
607                         state->worker_pid = -1;
608                         _exit(1);
609                 }
610                 status_sp_waiter_fd = fd;
611         }
612
613         closefrom(2);
614
615         /* Tell the caller we're ready */
616         nwritten = sys_write(status_sp_waiter_fd, &(char){0}, sizeof(char));
617         if (nwritten != sizeof(char)) {
618                 _exit(1);
619         }
620
621         tfork_global_free();
622         state = NULL;
623
624         do {
625                 ret = waitpid(pid, &status, 0);
626         } while ((ret == -1) && (errno == EINTR));
627         if (ret == -1) {
628                 status = errno;
629                 kill(pid, SIGKILL);
630         }
631
632         /*
633          * This writes the worker child exit status via our internal socketpair
634          * so the tfork_status() implementation can read it from its end.
635          */
636         nwritten = sys_write(status_sp_waiter_fd, &status, sizeof(status));
637         if (nwritten == -1) {
638                 if (errno != EPIPE && errno != ECONNRESET) {
639                         _exit(errno);
640                 }
641                 /*
642                  * The caller exitted and didn't call tfork_status().
643                  */
644                 _exit(0);
645         }
646         if (nwritten != sizeof(status)) {
647                 _exit(1);
648         }
649
650         /*
651          * This write to the event_fd returned by tfork_event_fd() and notifies
652          * the caller that the worker child is done and he may now call
653          * tfork_status().
654          */
655         nwritten = sys_write(event_pipe_waiter_fd, &(char){0}, sizeof(char));
656         if (nwritten != sizeof(char)) {
657                 _exit(1);
658         }
659
660         /*
661          * Wait for our parent (the process that called tfork_create()) to
662          * close() the socketpair fd in tfork_status().
663          *
664          * Again, the caller might have exitted without calling tfork_status().
665          */
666         nread = sys_read(status_sp_waiter_fd, &c, 1);
667         if (nread == -1) {
668                 if (errno == EPIPE || errno == ECONNRESET) {
669                         _exit(0);
670                 }
671                 _exit(errno);
672         }
673         if (nread != 0) {
674                 _exit(255);
675         }
676
677         _exit(0);
678 }
679
680 static int tfork_create_reap_waiter(pid_t waiter_pid)
681 {
682         pid_t pid;
683         int waiter_status;
684
685         if (waiter_pid == -1) {
686                 return 0;
687         }
688
689         kill(waiter_pid, SIGKILL);
690
691         do {
692                 pid = waitpid(waiter_pid, &waiter_status, 0);
693         } while ((pid == -1) && (errno == EINTR));
694         assert(pid == waiter_pid);
695
696         return 0;
697 }
698
699 struct tfork *tfork_create(void)
700 {
701         struct tfork_state *state = NULL;
702         struct tfork *t = NULL;
703         pid_t pid;
704         int saved_errno;
705         int ret = 0;
706
707 #ifdef HAVE_PTHREAD
708         ret = pthread_once(&tfork_global_is_initialized,
709                            tfork_global_initialize);
710         if (ret != 0) {
711                 return NULL;
712         }
713 #else
714         tfork_global_initialize();
715 #endif
716
717         state = tfork_global_get();
718         if (state == NULL) {
719                 return NULL;
720         }
721         *state = (struct tfork_state) {
722                 .waiter_pid = -1,
723                 .waiter_errno = ECANCELED,
724                 .worker_pid = -1,
725         };
726
727         t = malloc(sizeof(struct tfork));
728         if (t == NULL) {
729                 ret = -1;
730                 goto cleanup;
731         }
732
733         *t = (struct tfork) {
734                 .event_fd = -1,
735                 .status_fd = -1,
736                 .waiter_pid = -1,
737                 .worker_pid = -1,
738         };
739
740         ret = tfork_install_sigchld_handler(&state->waiter_pid);
741         if (ret != 0) {
742                 goto cleanup;
743         }
744
745         pid = tfork_start_waiter_and_worker(state,
746                                             &t->event_fd,
747                                             &t->status_fd);
748         if (pid == -1) {
749                 ret = -1;
750                 goto cleanup;
751         }
752         if (pid == 0) {
753                 /* In the worker */
754                 tfork_global_free();
755                 t->worker_pid = 0;
756                 return t;
757         }
758
759         t->waiter_pid = pid;
760         t->worker_pid = state->worker_pid;
761
762 cleanup:
763         if (ret == -1) {
764                 saved_errno = errno;
765
766                 if (t != NULL) {
767                         if (t->status_fd != -1) {
768                                 close(t->status_fd);
769                         }
770                         if (t->event_fd != -1) {
771                                 close(t->event_fd);
772                         }
773
774                         ret = tfork_create_reap_waiter(state->waiter_pid);
775                         assert(ret == 0);
776
777                         free(t);
778                         t = NULL;
779                 }
780         }
781
782         ret = tfork_uninstall_sigchld_handler();
783         assert(ret == 0);
784
785         tfork_global_free();
786
787         if (ret == -1) {
788                 errno = saved_errno;
789         }
790         return t;
791 }
792
793 pid_t tfork_child_pid(const struct tfork *t)
794 {
795         return t->worker_pid;
796 }
797
798 int tfork_event_fd(const struct tfork *t)
799 {
800         return t->event_fd;
801 }
802
803 int tfork_status(struct tfork **_t, bool wait)
804 {
805         struct tfork *t = *_t;
806         int status;
807         ssize_t nread;
808         int waiter_status;
809         pid_t pid;
810         int ret;
811
812         if (t == NULL) {
813                 return -1;
814         }
815
816         if (wait) {
817                 set_blocking(t->status_fd, true);
818
819                 nread = sys_read(t->status_fd, &status, sizeof(int));
820         } else {
821                 set_blocking(t->status_fd, false);
822
823                 nread = read(t->status_fd, &status, sizeof(int));
824                 if ((nread == -1) &&
825                     ((errno == EAGAIN) || (errno == EWOULDBLOCK) || errno == EINTR)) {
826                         errno = EAGAIN;
827                         return -1;
828                 }
829         }
830         if (nread != sizeof(int)) {
831                 return -1;
832         }
833
834         ret = tfork_install_sigchld_handler(&t->waiter_pid);
835         if (ret != 0) {
836                 return -1;
837         }
838
839         /*
840          * This triggers process exit in the waiter.
841          */
842         close(t->status_fd);
843
844         do {
845                 pid = waitpid(t->waiter_pid, &waiter_status, 0);
846         } while ((pid == -1) && (errno == EINTR));
847         assert(pid == t->waiter_pid);
848
849         close(t->event_fd);
850
851         free(t);
852         t = NULL;
853         *_t = NULL;
854
855         ret = tfork_uninstall_sigchld_handler();
856         assert(ret == 0);
857
858         return status;
859 }
860
861 int tfork_destroy(struct tfork **_t)
862 {
863         struct tfork *t = *_t;
864         int ret;
865
866         if (t == NULL) {
867                 errno = EINVAL;
868                 return -1;
869         }
870
871         kill(t->worker_pid, SIGKILL);
872
873         ret = tfork_status(_t, true);
874         if (ret == -1) {
875                 return -1;
876         }
877
878         return 0;
879 }