uwrap: Check the value of UID_WRAPPER env variable.
[uid_wrapper.git] / src / uid_wrapper.c
1 /*
2  * Copyright (c) 2009      Andrew Tridgell
3  * Copyright (c) 2011-2013 Andreas Schneider <asn@samba.org>
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17  */
18
19 #include "config.h"
20
21 #include <errno.h>
22 #include <stdarg.h>
23 #include <stdbool.h>
24 #include <stdlib.h>
25 #include <stdio.h>
26 #include <string.h>
27 #include <sys/types.h>
28 #include <unistd.h>
29 #include <grp.h>
30 #ifdef HAVE_SYS_SYSCALL_H
31 #include <sys/syscall.h>
32 #endif
33 #ifdef HAVE_SYSCALL_H
34 #include <syscall.h>
35 #endif
36 #include <dlfcn.h>
37
38 #ifdef HAVE_GCC_THREAD_LOCAL_STORAGE
39 # define UWRAP_THREAD __thread
40 #else
41 # define UWRAP_THREAD
42 #endif
43
44 #ifdef NDEBUG
45 #define UWRAP_DEBUG(...)
46 #else
47 #define UWRAP_DEBUG(...) fprintf(stderr, __VA_ARGS__)
48 #endif
49
50 #define LIBC_NAME "libc.so"
51
52 struct uwrap_libc_fns {
53         int (*_libc_setuid)(uid_t uid);
54         uid_t (*_libc_getuid)(void);
55
56 #ifdef HAVE_SETEUID
57         int (*_libc_seteuid)(uid_t euid);
58 #endif
59 #ifdef HAVE_SETREUID
60         int (*_libc_setreuid)(uid_t ruid, uid_t euid);
61 #endif
62 #ifdef HAVE_SETREUID
63         int (*_libc_setresuid)(uid_t ruid, uid_t euid, uid_t suid);
64 #endif
65         uid_t (*_libc_geteuid)(void);
66
67         int (*_libc_setgid)(gid_t gid);
68         gid_t (*_libc_getgid)(void);
69 #ifdef HAVE_SETEGID
70         int (*_libc_setegid)(uid_t egid);
71 #endif
72 #ifdef HAVE_SETREGID
73         int (*_libc_setregid)(uid_t rgid, uid_t egid);
74 #endif
75 #ifdef HAVE_SETREGID
76         int (*_libc_setresgid)(uid_t rgid, uid_t egid, uid_t sgid);
77 #endif
78         gid_t (*_libc_getegid)(void);
79         int (*_libc_getgroups)(int size, gid_t list[]);
80         int (*_libc_setgroups)(size_t size, const gid_t *list);
81 #ifdef HAVE_SYSCALL
82         long int (*_libc_syscall)(long int sysno, ...);
83 #endif
84 };
85
86 /*
87  * We keep the virtualised euid/egid/groups information here
88  */
89 struct uwrap {
90         struct {
91                 void *handle;
92                 struct uwrap_libc_fns fns;
93         } libc;
94         bool initialised;
95         bool enabled;
96         uid_t myuid;
97         uid_t ruid;
98         uid_t euid;
99         uid_t suid;
100         uid_t mygid;
101         gid_t rgid;
102         gid_t egid;
103         gid_t sgid;
104         gid_t *groups;
105         int ngroups;
106 };
107
108 static struct uwrap uwrap;
109
110 static void *uwrap_libc_fn(struct uwrap *u, const char *fn_name)
111 {
112         void *func;
113
114         if (u->libc.handle == NULL) {
115                 return NULL;
116         }
117
118         func = dlsym(u->libc.handle, fn_name);
119         if (func == NULL) {
120                 printf("Failed to find %s in %s: %s\n",
121                                 fn_name, LIBC_NAME, dlerror());
122                 exit(-1);
123         }
124
125         return func;
126 }
127
128 static void uwrap_libc_init(struct uwrap *u)
129 {
130         unsigned int i;
131         int flags = RTLD_LAZY;
132
133 #ifdef RTLD_DEEPBIND
134         flags |= RTLD_DEEPBIND;
135 #endif
136
137         for (u->libc.handle = NULL, i = 10; u->libc.handle == NULL; i--) {
138                 char soname[256] = {0};
139
140                 snprintf(soname, sizeof(soname), "%s.%u", LIBC_NAME, i);
141                 u->libc.handle = dlopen(soname, flags);
142         }
143
144         if (u->libc.handle == NULL) {
145                 printf("Failed to dlopen %s.%u: %s\n", LIBC_NAME, i, dlerror());
146                 exit(-1);
147         }
148
149         *(void **) (&u->libc.fns._libc_setuid) = uwrap_libc_fn(u, "setuid");
150         *(void **) (&u->libc.fns._libc_getuid) = uwrap_libc_fn(u, "getuid");
151
152 #ifdef HAVE_SETEUID
153         *(void **) (&u->libc.fns._libc_seteuid) = uwrap_libc_fn(u, "seteuid");
154 #endif
155 #ifdef HAVE_SETREUID
156         *(void **) (&u->libc.fns._libc_setreuid) = uwrap_libc_fn(u, "setreuid");
157 #endif
158 #ifdef HAVE_SETRESUID
159         *(void **) (&u->libc.fns._libc_setresuid) = uwrap_libc_fn(u, "setresuid");
160 #endif
161         *(void **) (&u->libc.fns._libc_geteuid) = uwrap_libc_fn(u, "geteuid");
162
163         *(void **) (&u->libc.fns._libc_setgid) = uwrap_libc_fn(u, "setgid");
164         *(void **) (&u->libc.fns._libc_getgid) = uwrap_libc_fn(u, "getgid");
165 #ifdef HAVE_SETEGID
166         *(void **) (&u->libc.fns._libc_setegid) = uwrap_libc_fn(u, "setegid");
167 #endif
168 #ifdef HAVE_SETREGID
169         *(void **) (&u->libc.fns._libc_setregid) = uwrap_libc_fn(u, "setregid");
170 #endif
171 #ifdef HAVE_SETRESGID
172         *(void **) (&u->libc.fns._libc_setresgid) = uwrap_libc_fn(u, "setresgid");
173 #endif
174         *(void **) (&u->libc.fns._libc_getegid) = uwrap_libc_fn(u, "getegid");
175         *(void **) (&u->libc.fns._libc_getgroups) = uwrap_libc_fn(u, "getgroups");
176         *(void **) (&u->libc.fns._libc_setgroups) = uwrap_libc_fn(u, "setgroups");
177         *(void **) (&u->libc.fns._libc_getuid) = uwrap_libc_fn(u, "getuid");
178         *(void **) (&u->libc.fns._libc_getgid) = uwrap_libc_fn(u, "getgid");
179 #ifdef HAVE_SYSCALL
180         *(void **) (&u->libc.fns._libc_syscall) = uwrap_libc_fn(u, "syscall");
181 #endif
182 }
183
184 static void uwrap_init(void)
185 {
186         const char *env = getenv("UID_WRAPPER");
187
188         if (uwrap.initialised) {
189                 return;
190         }
191
192         uwrap_libc_init(&uwrap);
193
194         uwrap.initialised = true;
195         uwrap.enabled = false;
196
197         if (env != NULL && env[0] == '1') {
198                 const char *root = getenv("UID_WRAPPER_ROOT");
199                 uwrap.enabled = true;
200                 /* put us in one group */
201                 if (root != NULL && root[0] == '1') {
202                         uwrap.myuid = 0;
203                         uwrap.mygid = 0;
204                 } else {
205                         uwrap.myuid = uwrap.libc.fns._libc_geteuid();
206                         uwrap.mygid = uwrap.libc.fns._libc_getegid();
207                 }
208
209                 uwrap.ruid = uwrap.euid = uwrap.suid = uwrap.myuid;
210                 uwrap.rgid = uwrap.egid = uwrap.sgid = uwrap.mygid;
211
212                 uwrap.ngroups = 1;
213                 uwrap.groups = malloc(sizeof(gid_t) * uwrap.ngroups);
214                 uwrap.groups[0] = uwrap.mygid;
215         }
216 }
217
218 static int uwrap_enabled(void)
219 {
220         uwrap_init();
221
222         return uwrap.enabled ? 1 : 0;
223 }
224
225 static int uwrap_setresuid(uid_t ruid, uid_t euid, uid_t suid)
226 {
227         if (ruid == (uid_t)-1 && euid == (uid_t)-1 && suid == (uid_t)-1) {
228                 errno = EINVAL;
229                 return -1;
230         }
231
232         if (ruid != (uid_t)-1) {
233                 uwrap.ruid = ruid;
234         }
235
236         if (euid != (uid_t)-1) {
237                 uwrap.euid = euid;
238         }
239
240         if (suid != (uid_t)-1) {
241                 uwrap.suid = suid;
242         }
243
244         return 0;
245 }
246
247 /*
248  * SETUID
249  */
250 int setuid(uid_t uid)
251 {
252         if (!uwrap_enabled()) {
253                 return uwrap.libc.fns._libc_setuid(uid);
254         }
255
256         return uwrap_setresuid(uid, -1, -1);
257 }
258
259 /*
260  * GETUID
261  */
262 static uid_t uwrap_getuid(void)
263 {
264         return uwrap.ruid;
265 }
266
267 uid_t getuid(void)
268 {
269         if (!uwrap_enabled()) {
270                 return uwrap.libc.fns._libc_getuid();
271         }
272
273         return uwrap_getuid();
274 }
275
276 #ifdef HAVE_SETEUID
277 int seteuid(uid_t euid)
278 {
279         if (euid == (uid_t)-1) {
280                 errno = EINVAL;
281                 return -1;
282         }
283
284         if (!uwrap_enabled()) {
285                 return uwrap.libc.fns._libc_seteuid(euid);
286         }
287
288         return uwrap_setresuid(-1, euid, -1);
289 }
290 #endif
291
292 #ifdef HAVE_SETREUID
293 int setreuid(uid_t ruid, uid_t euid)
294 {
295         if (ruid == (uid_t)-1 && euid == (uid_t)-1) {
296                 errno = EINVAL;
297                 return -1;
298         }
299
300         if (!uwrap_enabled()) {
301                 return uwrap.libc.fns._libc_setreuid(ruid, euid);
302         }
303
304         return uwrap_setresuid(ruid, euid, -1);
305 }
306 #endif
307
308 #ifdef HAVE_SETRESUID
309 int setresuid(uid_t ruid, uid_t euid, uid_t suid)
310 {
311         if (!uwrap_enabled()) {
312                 return uwrap.libc.fns._libc_setresuid(ruid, euid, suid);
313         }
314
315         return uwrap_setresuid(ruid, euid, suid);
316 }
317 #endif
318
319 static uid_t uwrap_geteuid(void)
320 {
321         return uwrap.euid;
322 }
323
324 uid_t geteuid(void)
325 {
326         if (!uwrap_enabled()) {
327                 return uwrap.libc.fns._libc_geteuid();
328         }
329
330         return uwrap_geteuid();
331 }
332
333 /*
334  * SETGID
335  */
336 static int uwrap_setgid(gid_t gid)
337 {
338         if (gid == (gid_t)-1) {
339                 errno = EINVAL;
340                 return -1;
341         }
342
343         uwrap.rgid = gid;
344
345         return 0;
346 }
347
348 int setgid(gid_t gid)
349 {
350         if (!uwrap_enabled()) {
351                 return uwrap.libc.fns._libc_setgid(gid);
352         }
353
354         return uwrap_setgid(gid);
355 }
356
357 /*
358  * GETGID
359  */
360 static gid_t uwrap_getgid(void)
361 {
362         return uwrap.rgid;
363 }
364
365 gid_t getgid(void)
366 {
367         if (!uwrap_enabled()) {
368                 return uwrap.libc.fns._libc_getgid();
369         }
370
371         return uwrap_getgid();
372 }
373
374 static int uwrap_setresgid(gid_t rgid, gid_t egid, gid_t sgid)
375 {
376         if (rgid == (gid_t)-1 && egid == (gid_t)-1 && sgid == (gid_t)-1) {
377                 errno = EINVAL;
378                 return -1;
379         }
380
381         if (rgid != (gid_t)-1) {
382                 uwrap.rgid = rgid;
383         }
384
385         if (egid != (gid_t)-1) {
386                 uwrap.egid = egid;
387         }
388
389         if (sgid != (gid_t)-1) {
390                 uwrap.sgid = sgid;
391         }
392
393         return 0;
394 }
395
396 #ifdef HAVE_SETEGID
397 int setegid(gid_t egid)
398 {
399         if (!uwrap_enabled()) {
400                 return uwrap.libc.fns._libc_setegid(egid);
401         }
402
403         return uwrap_setresgid(-1, egid, -1);
404 }
405 #endif
406
407 #ifdef HAVE_SETREGID
408 int setregid(gid_t rgid, gid_t egid)
409 {
410         if (!uwrap_enabled()) {
411                 return uwrap.libc.fns._libc_setregid(rgid, egid);
412         }
413
414         return uwrap_setresgid(rgid, egid, -1);
415 }
416 #endif
417
418 #ifdef HAVE_SETRESGID
419 int setresgid(gid_t rgid, gid_t egid, gid_t sgid)
420 {
421         if (!uwrap_enabled()) {
422                 return uwrap.libc.fns._libc_setregid(rgid, egid, sgid);
423         }
424
425         return uwrap_setresgid(rgid, egid, sgid);
426 }
427 #endif
428
429 static uid_t uwrap_getegid(void)
430 {
431         return uwrap.egid;
432 }
433
434 uid_t getegid(void)
435 {
436         if (!uwrap_enabled()) {
437                 return uwrap.libc.fns._libc_getegid();
438         }
439
440         return uwrap_getegid();
441 }
442
443 static int uwrap_setgroups(size_t size, const gid_t *list)
444 {
445         free(uwrap.groups);
446         uwrap.groups = NULL;
447         uwrap.ngroups = 0;
448
449         if (size != 0) {
450                 uwrap.groups = malloc(sizeof(gid_t) * size);
451                 if (uwrap.groups == NULL) {
452                         errno = ENOMEM;
453                         return -1;
454                 }
455                 uwrap.ngroups = size;
456                 memcpy(uwrap.groups, list, size*sizeof(gid_t));
457         }
458
459         return 0;
460 }
461
462 int setgroups(size_t size, const gid_t *list)
463 {
464         if (!uwrap_enabled()) {
465                 return uwrap.libc.fns._libc_setgroups(size, list);
466         }
467
468         return uwrap_setgroups(size, list);
469 }
470
471 static int uwrap_getgroups(int size, gid_t *list)
472 {
473         int ngroups;
474
475         ngroups = uwrap.ngroups;
476
477         if (size > ngroups) {
478                 size = ngroups;
479         }
480         if (size == 0) {
481                 return ngroups;
482         }
483         if (size < ngroups) {
484                 errno = EINVAL;
485                 return -1;
486         }
487         memcpy(list, uwrap.groups, size*sizeof(gid_t));
488
489         return ngroups;
490 }
491
492 int getgroups(int size, gid_t *list)
493 {
494         if (!uwrap_enabled()) {
495                 return uwrap.libc.fns._libc_getgroups(size, list);
496         }
497
498         return uwrap_getgroups(size, list);
499 }
500
501 static long int libc_vsyscall(long int sysno, va_list va)
502 {
503         long int args[8];
504         long int rc;
505         int i;
506
507         for (i = 0; i < 8; i++) {
508                 args[i] = va_arg(va, long int);
509         }
510
511         rc = uwrap.libc.fns._libc_syscall(sysno,
512                                           args[0],
513                                           args[1],
514                                           args[2],
515                                           args[3],
516                                           args[4],
517                                           args[5],
518                                           args[6],
519                                           args[7]);
520
521         return rc;
522 }
523
524 #if (defined(HAVE_SYS_SYSCALL_H) || defined(HAVE_SYSCALL_H)) \
525     && (defined(SYS_setreuid) || defined(SYS_setreuid32))
526 static long int uwrap_syscall (long int sysno, va_list vp)
527 {
528         long int rc;
529
530         switch (sysno) {
531                 /* gid */
532                 case SYS_setgid:
533 #ifdef HAVE_LINUX_32BIT_SYSCALLS
534                 case SYS_setgid32:
535 #endif
536                         {
537                                 gid_t gid = (gid_t) va_arg(vp, int);
538
539                                 rc = uwrap_setresgid(gid, -1, -1);
540                         }
541                         break;
542                 case SYS_setregid:
543 #ifdef HAVE_LINUX_32BIT_SYSCALLS
544                 case SYS_setregid32:
545 #endif
546                         {
547                                 uid_t rgid = (uid_t) va_arg(vp, int);
548                                 uid_t egid = (uid_t) va_arg(vp, int);
549
550                                 rc = uwrap_setresgid(rgid, egid, -1);
551                         }
552                         break;
553                 case SYS_setresgid:
554 #ifdef HAVE_LINUX_32BIT_SYSCALLS
555                 case SYS_setresgid32:
556 #endif
557                         {
558                                 uid_t rgid = (uid_t) va_arg(vp, int);
559                                 uid_t egid = (uid_t) va_arg(vp, int);
560                                 uid_t sgid = (uid_t) va_arg(vp, int);
561
562                                 rc = uwrap_setresgid(rgid, egid, sgid);
563                         }
564                         break;
565
566                 /* uid */
567                 case SYS_setuid:
568 #ifdef HAVE_LINUX_32BIT_SYSCALLS
569                 case SYS_setuid32:
570 #endif
571                         {
572                                 uid_t uid = (uid_t) va_arg(vp, int);
573
574                                 rc = uwrap_setresuid(uid, -1, -1);
575                         }
576                         break;
577                 case SYS_setreuid:
578 #ifdef HAVE_LINUX_32BIT_SYSCALLS
579                 case SYS_setreuid32:
580 #endif
581                         {
582                                 uid_t ruid = (uid_t) va_arg(vp, int);
583                                 uid_t euid = (uid_t) va_arg(vp, int);
584
585                                 rc = uwrap_setresuid(ruid, euid, -1);
586                         }
587                         break;
588                 case SYS_setresuid:
589 #ifdef HAVE_LINUX_32BIT_SYSCALLS
590                 case SYS_setresuid32:
591 #endif
592                         {
593                                 uid_t ruid = (uid_t) va_arg(vp, int);
594                                 uid_t euid = (uid_t) va_arg(vp, int);
595                                 uid_t suid = (uid_t) va_arg(vp, int);
596
597                                 rc = uwrap_setresuid(ruid, euid, suid);
598                         }
599                         break;
600
601                 /* groups */
602                 case SYS_setgroups:
603 #ifdef HAVE_LINUX_32BIT_SYSCALLS
604                 case SYS_setgroups32:
605 #endif
606                         {
607                                 size_t size = (size_t) va_arg(vp, size_t);
608                                 gid_t *list = (gid_t *) va_arg(vp, int *);
609
610                                 rc = uwrap_setgroups(size, list);
611                         }
612                         break;
613 #ifdef SYS_initgroups
614                 case SYS_initgroups:
615 #ifdef HAVE_LINUX_32BIT_SYSCALLS
616                 case SYS_initgroups32:
617 #endif
618                         {
619                                 const char *user = (const char *) va_arg(vp, char*);
620                                 gid_t group = (gid_t) va_arg(vp, int);
621
622                                 rc = initgroups(user, group);
623                         }
624                         break;
625 #endif
626                 default:
627                         UWRAP_DEBUG("UID_WRAPPER calling non-wrapped syscall "
628                                     "%lu\n", sysno);
629
630                         rc = libc_vsyscall(sysno, vp);
631                         break;
632         }
633
634         return rc;
635 }
636
637 #ifdef HAVE_SYSCALL
638 long int syscall (long int sysno, ...)
639 {
640         long int rc;
641         va_list va;
642
643         va_start(va, sysno);
644
645         if (!uwrap_enabled()) {
646                 rc = libc_vsyscall(sysno, va);
647                 va_end(va);
648                 return rc;
649         }
650
651         rc = uwrap_syscall(sysno, va);
652         va_end(va);
653
654         return rc;
655 }
656 #endif /* HAVE_SYSCALL */
657 #endif /* HAVE_SYS_SYSCALL_H || HAVE_SYSCALL_H */