3a5d1bdeb676426fe681140469be716bfc62c49b
[obnox/samba-ctdb.git] / source3 / utils / net_lua.c
1 /*
2  *  Unix SMB/CIFS implementation.
3  *  Lua experiments
4  *  Copyright (C) Volker Lendecke 2006
5  *
6  *  This program is free software; you can redistribute it and/or modify
7  *  it under the terms of the GNU General Public License as published by
8  *  the Free Software Foundation; either version 3 of the License, or
9  *  (at your option) any later version.
10  *
11  *  This program is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU General Public License for more details.
15  *
16  *  You should have received a copy of the GNU General Public License
17  *  along with this program; if not, see <http://www.gnu.org/licenses/>.
18  */
19
20
21 #include "includes.h"
22 #include "utils/net.h"
23
24 #include "lua-5.1.4/src/lualib.h"
25 #include "lua-5.1.4/src/lauxlib.h"
26
27 #define SOCK_METATABLE "cade1208-9029-4d76-8748-426dfc1436f7"
28
29 struct sock_userdata {
30         int fd;
31 };
32
33 static int sock_userdata_gc(lua_State *L)
34 {
35         struct sock_userdata *p = (struct sock_userdata *)
36                 luaL_checkudata(L, 1, SOCK_METATABLE);
37         close(p->fd);
38         return 0;
39 }
40
41 static int sock_userdata_tostring(lua_State *L)
42 {
43         struct sock_userdata *p = (struct sock_userdata *)
44                 luaL_checkudata(L, 1, SOCK_METATABLE);
45
46         lua_pushfstring(L, "socket: %d", p->fd);
47         return 1;
48 }
49
50 static int sock_userdata_connect(lua_State *L)
51 {
52         struct sock_userdata *p = (struct sock_userdata *)
53                 luaL_checkudata(L, 1, SOCK_METATABLE);
54         const char *hostname;
55         int port;
56         struct sockaddr_in addr;
57         int res;
58
59         if (!lua_isstring(L, 2)) {
60                 luaL_error(L, "connect: Expected IP-Address");
61         }
62         hostname = lua_tostring(L, 2);
63
64         if (!lua_isnumber(L, 3)) {
65                 luaL_error(L, "connect: Expected port");
66         }
67         port = lua_tointeger(L, 3);
68
69         if (lua_gettop(L) == 4) {
70                 /*
71                  * Here we expect an event context in the last argument to
72                  * make connect() asynchronous.
73                  */
74         }
75
76         addr.sin_family = AF_INET;
77         inet_aton(hostname, &addr.sin_addr);
78         addr.sin_port = htons(port);
79
80         res = connect(p->fd, (struct sockaddr *)&addr, sizeof(addr));
81         if (res == -1) {
82                 int err = errno;
83                 lua_pushnil(L);
84                 lua_pushfstring(L, "connect failed: %s", strerror(err));
85                 return 2;
86         }
87
88         lua_pushboolean(L, 1);
89         return 1;
90 }
91
92 static const struct luaL_Reg sock_methods[] = {
93         {"__gc",        sock_userdata_gc},
94         {"__tostring",  sock_userdata_tostring},
95         {"connect",     sock_userdata_connect},
96         {NULL, NULL}
97 };
98
99 static const struct {
100         const char *name;
101         int domain;
102 } socket_domains[] = {
103         {"PF_UNIX", PF_UNIX},
104         {"PF_INET", PF_INET},
105         {NULL, 0},
106 };
107
108 static const struct {
109         const char *name;
110         int type;
111 } socket_types[] = {
112         {"SOCK_STREAM", SOCK_STREAM},
113         {"SOCK_DGRAM", SOCK_DGRAM},
114         {NULL, 0},
115 };
116
117 static int sock_userdata_new(lua_State *L)
118 {
119         struct sock_userdata *result;
120         const char *domain_str = luaL_checkstring(L, 1);
121         const char *type_str = luaL_checkstring(L, 2);
122         int i, domain, type;
123
124         i = 0;
125         while (socket_domains[i].name != NULL) {
126                 if (strcmp(domain_str, socket_domains[i].name) == 0) {
127                         break;
128                 }
129                 i += 1;
130         }
131         if (socket_domains[i].name == NULL) {
132                 return luaL_error(L, "socket domain %s unknown", domain_str);
133         }
134         domain = socket_domains[i].domain;
135
136         i = 0;
137         while (socket_types[i].name != NULL) {
138                 if (strcmp(type_str, socket_types[i].name) == 0) {
139                         break;
140                 }
141                 i += 1;
142         }
143         if (socket_types[i].name == NULL) {
144                 return luaL_error(L, "socket type %s unknown", type_str);
145         }
146         type = socket_types[i].type;
147
148         result = (struct sock_userdata *)lua_newuserdata(L, sizeof(*result));
149         ZERO_STRUCTP(result);
150
151         result->fd = socket(domain, type, 0);
152         if (result->fd == -1) {
153                 int err = errno;
154                 lua_pushnil(L);
155                 lua_pushfstring(L, "socket() failed: %s", strerror(errno));
156                 lua_pushinteger(L, err);
157                 return 3;
158         }
159
160         luaL_getmetatable(L, SOCK_METATABLE);
161         lua_setmetatable(L, -2);
162         return 1;
163 }
164
165 static const struct luaL_Reg sock_funcs[] = {
166         {"new",         sock_userdata_new},
167         {NULL, NULL}
168 };
169
170 static int sock_lua_init(lua_State *L, const char *libname) {
171         luaL_newmetatable(L, SOCK_METATABLE);
172
173         lua_pushvalue(L, -1);
174         lua_setfield(L, -2, "__index");
175
176         luaL_register(L, NULL, sock_methods);
177         luaL_register(L, libname, sock_funcs);
178         return 1;
179 }
180
181 #define EVT_METATABLE "c42e0642-b24a-40f0-8483-d8eb4aee9ea3"
182
183 /*
184  * The userdata we allocate from lua when a new event context is created
185  */
186 struct evt_userdata {
187         struct event_context *ev;
188 };
189
190 static bool evt_is_main_thread(lua_State *L) {
191         int ret;
192
193         ret = lua_pushthread(L);
194         lua_pop(L, 1);
195         return (ret != 0);
196 }
197
198 /*
199  * Per event we allocate a struct thread_reference to keep the coroutine from
200  * being garbage-collected. This is also the hook to find the right thread to
201  * be resumed.
202  */
203
204 struct thread_reference {
205         struct lua_State *L;
206         /*
207          * Reference to the Thread (i.e. lua_State) this event is hanging on
208          */
209         int thread_ref;
210 };
211
212 static int thread_reference_destructor(struct thread_reference *ref)
213 {
214         luaL_unref(ref->L, LUA_REGISTRYINDEX, ref->thread_ref);
215         return 0;
216 }
217
218 static struct thread_reference *evt_reference_thread(TALLOC_CTX *mem_ctx,
219                                                      lua_State *L)
220 {
221         struct thread_reference *result;
222
223         result = talloc(mem_ctx, struct thread_reference);
224         if (result == NULL) {
225                 return NULL;
226         }
227
228         lua_pushthread(L);
229         result->thread_ref = luaL_ref(L, LUA_REGISTRYINDEX);
230         result->L = L;
231         talloc_set_destructor(result, thread_reference_destructor);
232
233         return result;
234 }
235
236 static int evt_userdata_gc(lua_State *L)
237 {
238         struct evt_userdata *p = (struct evt_userdata *)
239                 luaL_checkudata(L, 1, EVT_METATABLE);
240         TALLOC_FREE(p->ev);
241         return 0;
242 }
243
244 static int evt_userdata_tostring(lua_State *L) {
245         lua_pushstring(L, "event context");
246         return 1;
247 }
248
249 static void evt_userdata_sleep_done(struct event_context *event_ctx,
250                                    struct timed_event *te,
251                                    const struct timeval *now,
252                                    void *priv)
253 {
254         struct thread_reference *ref = talloc_get_type_abort(
255                 priv, struct thread_reference);
256         lua_resume(ref->L, 0);
257         TALLOC_FREE(ref);
258 }
259
260 static int evt_userdata_sleep(lua_State *L)
261 {
262         struct evt_userdata *p = (struct evt_userdata *)
263                 luaL_checkudata(L, 1, EVT_METATABLE);
264         lua_Integer usecs = luaL_checkint(L, 2);
265         struct thread_reference *ref;
266         struct timed_event *te;
267
268         if (evt_is_main_thread(L)) {
269                 /*
270                  * Block in the main thread
271                  */
272                 smb_msleep(usecs/1000);
273                 return 0;
274         }
275
276         ref = evt_reference_thread(p->ev, L);
277         if (ref == NULL) {
278                 return luaL_error(L, "evt_reference_thread failed\n");
279         }
280
281         te = event_add_timed(p->ev, ref, timeval_current_ofs(0, usecs),
282                              "evt_userdata_sleep", evt_userdata_sleep_done,
283                              ref);
284
285         if (te == NULL) {
286                 TALLOC_FREE(ref);
287                 return luaL_error(L, "event_add_timed failed");
288         }
289
290         return lua_yield(L, 0);
291 }
292
293 static int evt_userdata_once(lua_State *L)
294 {
295         struct evt_userdata *p = (struct evt_userdata *)
296                 luaL_checkudata(L, 1, EVT_METATABLE);
297
298         if (!evt_is_main_thread(L)) {
299                 return luaL_error(L, "event_once called from non-base thread");
300         }
301
302         lua_pushinteger(L, event_loop_once(p->ev));
303         return 1;
304 }
305
306 static const struct luaL_Reg evt_methods[] = {
307         {"__gc",        evt_userdata_gc},
308         {"__tostring",  evt_userdata_tostring},
309         {"sleep",       evt_userdata_sleep},
310         {"once",        evt_userdata_once},
311         {NULL, NULL}
312 };
313
314 static int evt_userdata_new(lua_State *L) {
315         struct evt_userdata *result;
316
317         result = (struct evt_userdata *)lua_newuserdata(L, sizeof(*result));
318         ZERO_STRUCTP(result);
319
320         result->ev = event_context_init(NULL);
321         if (result->ev == NULL) {
322                 return luaL_error(L, "event_context_init failed");
323         }
324
325         luaL_getmetatable(L, EVT_METATABLE);
326         lua_setmetatable(L, -2);
327         return 1;
328 }
329
330 static const struct luaL_Reg evt_funcs[] = {
331         {"new",         evt_userdata_new},
332         {NULL, NULL}
333 };
334
335 static int evt_lua_init(lua_State *L, const char *libname) {
336         luaL_newmetatable(L, EVT_METATABLE);
337
338         lua_pushvalue(L, -1);
339         lua_setfield(L, -2, "__index");
340
341         luaL_register(L, NULL, evt_methods);
342         luaL_register(L, libname, evt_funcs);
343         return 1;
344 }
345
346 int net_lua(struct net_context *c, int argc, const char **argv)
347 {
348         lua_State *state;
349
350         state = lua_open();
351         if (state == NULL) {
352                 d_fprintf(stderr, "lua_newstate failed\n");
353                 return -1;
354         }
355
356         luaL_openlibs(state);
357         evt_lua_init(state, "event");
358         sock_lua_init(state, "socket");
359
360         while (1) {
361                 char *line = NULL;
362
363                 line = smb_readline("lua> ", NULL, NULL);
364                 if (line == NULL) {
365                         break;
366                 }
367
368                 if (line[0] == ':') {
369                         if (luaL_dofile(state, &line[1])) {
370                                 d_printf("luaL_dofile returned an error\n");
371                                 continue;
372                         }
373                 } else if (line[0] != '\n') {
374                         if (luaL_dostring(state, line) != 0) {
375                                 d_printf("luaL_dostring returned an error\n");
376                         }
377                 }
378
379                 SAFE_FREE(line);
380         }
381
382         lua_close(state);
383         return -1;
384 }