Merge in rpcproxy (wsgi based)
[jelmer/openchange.git] / mapiproxy / services / ocsmanager / rpcproxy / rpcproxy / channels.py
1 # channels.py -- OpenChange RPC-over-HTTP implementation
2 #
3 # Copyright (C) 2012  Julien Kerihuel <j.kerihuel@openchange.org>
4 #                     Wolfgang Sourdeau <wsourdeau@inverse.ca>
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10 #   
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #   
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 #
19
20 import os
21 from select import poll, POLLIN, POLLHUP
22 from socket import socket, AF_INET, AF_UNIX, SOCK_STREAM, MSG_WAITALL, \
23     error as socket_error
24 from struct import pack, unpack_from
25 import sys
26 from time import time, sleep
27 from uuid import UUID
28
29 # from rpcproxy.RPCH import RPCH, RTS_FLAG_ECHO
30 from fdunix import send_socket, receive_socket
31 from packets import RTS_CMD_CONNECTION_TIMEOUT, RTS_CMD_VERSION, \
32     RTS_CMD_RECEIVE_WINDOW_SIZE, RTS_CMD_CONNECTION_TIMEOUT, \
33     RTS_FLAG_ECHO, \
34     RPCPacket, RPCRTSPacket, RPCRTSOutPacket
35
36
37 """Documentation:
38
39 1. "Connection Establishment" sequence (from RPCH.pdf, 3.2.1.5.3.1)
40
41   client -> IN request -> proxy in
42   # server -> legacy server response -> proxy in
43   # server -> legacy server response -> proxy out
44   client -> Out request -> proxy out
45   client -> A1 -> proxy out
46   client -> B1 -> proxy in
47   # proxy out -> A2 -> server
48   proxy out -> OUT channel response -> client
49   # proxy in -> B2 -> server
50   proxy out -> A3 -> client
51   # server -> C1 -> proxy out
52   # server -> B3 -> proxy in
53   proxy out -> C2 -> client
54
55 2. internal unix socket protocols
56
57    Note: OUT proxy is always the server
58
59  * establishing virtual connection
60  OUT proxy listens on unix socket
61  IN proxy connects to OUT proxy
62  IN -> OUT: "IP"
63  IN -> OUT: in_window_size
64  IN -> OUT: in_conn_timeout
65  OUT -> IN: its socket connected to OpenChange (fdunix)
66  (both sides then close this unix connection; on shutdown, IN reconnects and sends "IPq")
67
68  * channel recycling (unused yet, hypothethical)
69  When new OUT conn arrives:
70  new OUT -> OUT: "OP"
71  OUT -> new OUT: OUT listening socket (fdunix)
72  OUT -> new OUT: IN socket (fdunix)
73  OUT -> new OUT: oc socket (fdunix)
74  close OUT socket locally
75 """
76
77
# Both proxy identifiers must have the same length: the OUT channel reads a
# fixed number of bytes from the unix socket to identify its peer.
INBOUND_PROXY_ID, OUTBOUND_PROXY_ID = "IP", "OP"

# Directory holding one unix socket per virtual connection, named after the
# connection cookie.
SOCKETS_DIR = "/tmp/rpcproxy"

# Address of the OpenChange server the OUT channel connects to.
OC_HOST = "127.0.0.1"
83
class RPCProxyChannelHandler(object):
    """State and helpers shared by the IN and OUT channel handlers.

    Keeps the per-request traffic counters, the request start time and the
    cookies identifying the channel and its virtual connection.
    """

    def __init__(self, logger):
        self.logger = logger

        # set by the subclasses to environ["wsgi.input"]
        self.client_socket = None

        # traffic counters reported by log_connection_stats()
        self.bytes_read = 0
        self.bytes_written = 0
        self.startup_time = time()

        # string forms of the UUID cookies received from the client
        self.channel_cookie = None
        self.connection_cookie = None

    def handle_echo_request(self, environ, start_response):
        """Answer an echo request with one RTS packet flagged ECHO.

        ``environ`` is not used. The packet size is added to
        ``bytes_written`` and a one-element list holding the encoded packet
        is returned as the WSGI body.
        """
        self.logger.info("handling echo request")

        echo = RPCRTSOutPacket()
        echo.flags = RTS_FLAG_ECHO
        payload = echo.make()
        self.bytes_written += echo.size

        headers = [("Content-length", "%d" % echo.size),
                   ("Content-Type", "application/rpc")]
        start_response("200 Success", headers)

        return [payload]

    def log_connection_stats(self):
        """Log the elapsed time and the byte counters of this request."""
        elapsed = time() - self.startup_time
        self.logger.info("request took %f secs; %d bytes received; %d bytes sent"
                         % (elapsed, self.bytes_read, self.bytes_written))
112                          % ((time() - self.startup_time),
113                             self.bytes_read, self.bytes_written))
114
115
class RPCProxyInboundChannelHandler(RPCProxyChannelHandler):
    """WSGI handler for the IN channel (RPC_IN_DATA) of a virtual connection.

    After reading CONN/B1 from the client, it obtains the socket connected
    to OpenChange from the OUT channel (through the unix socket named after
    the connection cookie) and forwards every non-RTS packet received from
    the client to OpenChange.
    """

    # number of attempts made to reach the OUT channel's unix socket,
    # one second apart
    OUT_CONNECT_ATTEMPTS = 10

    def __init__(self, logger):
        RPCProxyChannelHandler.__init__(self, logger)
        self.oc_conn = None
        self.window_size = 0
        self.conn_timeout = 0
        self.client_keepalive = 0
        self.association_group_id = None

    def _receive_conn_b1(self):
        """Read the CONN/B1 RTS PDU from the client and store its fields.

        Raises Exception if the packet read is not an RTS packet.
        """
        # TODO: validation of CONN/B1
        self.logger.info("IN: receiving CONN/B1")

        packet = RPCPacket.from_file(self.client_socket, self.logger)
        if not isinstance(packet, RPCRTSPacket):
            raise Exception("Unexpected non-rts packet received for CONN/B1")
        self.connection_cookie = str(UUID(bytes=packet.commands[1]["Cookie"]))
        self.channel_cookie = str(UUID(bytes=packet.commands[2]["Cookie"]))
        self.client_keepalive = packet.commands[4]["ClientKeepalive"]
        self.association_group_id = str(UUID(bytes=packet.commands[5]
                                             ["AssociationGroupId"]))
        self.bytes_read = self.bytes_read + packet.size

    def _connect_to_OUT_socket(self):
        """Connect to the OUT channel's unix socket, retrying on failure.

        Returns the connected socket, or None once OUT_CONNECT_ATTEMPTS
        attempts have failed (the original loop never gave up and spun
        without sleeping after the 10th failure).
        """
        socket_name = os.path.join(SOCKETS_DIR, self.connection_cookie)
        self.logger.info("IN: connecting to OUT via unix socket '%s'"
                         % socket_name)
        max_attempts = self.OUT_CONNECT_ATTEMPTS
        for attempt in range(1, max_attempts + 1):
            # use a fresh socket per attempt: POSIX leaves the state of a
            # socket unspecified after a failed connect()
            sock = socket(AF_UNIX, SOCK_STREAM)
            try:
                sock.connect(socket_name)
            except socket_error:
                self.logger.info("IN: handling socket.error: %s"
                                 % str(sys.exc_info()))
                sock.close()
                if attempt < max_attempts:
                    self.logger.warn("IN: reattempting to connect to OUT"
                                     " channel... (%d/%d)"
                                     % (attempt, max_attempts))
                    sleep(1)
            else:
                self.logger.info("IN: connection succeeded")
                return sock

        self.logger.error("too many failed attempts to establish a"
                          " connection to OUT channel")
        return None

    def _connect_to_OUT_channel(self):
        """Register with the OUT channel and receive the OpenChange socket.

        Sends INBOUND_PROXY_ID, then the packed ("<ll") receive window size
        and connection timeout, then receives (via fdunix) the socket the
        OUT channel opened to OpenChange and stores it in ``oc_conn``.

        Returns True on success, False if the OUT channel was unreachable.
        """
        # FIXME: we might need to keep a persistent connection to the OUT
        # channel
        sock = self._connect_to_OUT_socket()
        if sock is None:
            return False

        try:
            self.logger.info("IN: sending window size and connection timeout")

            # identify ourselves as the IN proxy
            sock.sendall(INBOUND_PROXY_ID)

            # receive window size: 256 KiB (max size allowed);
            # connection timeout: 14400000 milliseconds (max value allowed)
            sock.sendall(pack("<ll", (256 * 1024), 14400000))

            # receive the socket connected to OpenChange
            self.oc_conn = receive_socket(sock)
            self.logger.info("IN: oc_conn received (fileno=%d)"
                             % self.oc_conn.fileno())
        finally:
            # the unix socket is only needed for the hand-off
            sock.close()

        return True

    def _runloop(self):
        """Relay client packets to OpenChange until the client disconnects.

        RTS packets are consumed here; any other packet is written to
        ``oc_conn`` unchanged. An IOError ends the loop, after which the
        OUT channel is notified.
        """
        self.logger.info("IN: runloop")

        while True:
            try:
                oc_packet = RPCPacket.from_file(self.client_socket,
                                                self.logger)
                self.bytes_read = self.bytes_read + oc_packet.size

                self.logger.info("IN: packet headers = "
                                 + oc_packet.pretty_dump())

                if isinstance(oc_packet, RPCRTSPacket):
                    # we do not forward rts packets
                    self.logger.info("IN: ignored RTS packet")
                else:
                    self.logger.info("IN: sending packet to OC")
                    self.oc_conn.sendall(oc_packet.data)
                    self.bytes_written = self.bytes_written + oc_packet.size
            except IOError:
                self.logger.error("IN: client connection closed")
                break

        self._notify_OUT_channel()

    def _notify_OUT_channel(self):
        """Tell the OUT channel that this IN channel is shutting down.

        Best effort: connection or send failures are logged, not raised.
        """
        self.logger.info("IN: notifying OUT channel of shutdown")

        sock = self._connect_to_OUT_socket()
        if sock is None:
            return

        try:
            sock.sendall(INBOUND_PROXY_ID + "q")
        except socket_error:
            # UNIX socket might already have been closed by OUT channel
            self.logger.info("IN: could not notify OUT channel: %s"
                             % str(sys.exc_info()))
        finally:
            sock.close()

    def _terminate_oc_socket(self):
        """Close the OpenChange socket if one was received (idempotent)."""
        if self.oc_conn is not None:
            self.oc_conn.close()
            self.oc_conn = None

    def sequence(self, environ, start_response):
        """WSGI body generator for an RPC_IN_DATA request.

        A Content-Length of at most 0x10 is an echo request. One of at
        least 128 bytes opens an IN channel: CONN/B1 is read, the OUT
        channel is contacted and client packets are relayed to OpenChange
        until the client disconnects. Any other length raises Exception.
        """
        self.logger.info("IN: processing request")
        if "REMOTE_PORT" in environ:
            self.logger.info("IN: remote port = %s" % environ["REMOTE_PORT"])

        content_length = int(environ["CONTENT_LENGTH"])
        self.logger.info("IN: request size is %d" % content_length)

        if content_length <= 0x10:
            # echo request
            self.logger.info("IN: Exiting (1) from do_RPC_IN_DATA")
            for data in self.handle_echo_request(environ, start_response):
                yield data
        elif content_length >= 128:
            self.logger.info("IN: Processing IN channel request")

            self.client_socket = environ["wsgi.input"]
            self._receive_conn_b1()
            connected = self._connect_to_OUT_channel()

            # WSGI allows a single start_response call (without exc_info)
            # per request, so it is issued once for both outcomes.
            # TODO: error handling when the OUT channel is unreachable
            start_response("200 Success",
                           [("Content-Type", "application/rpc"),
                            ("Content-length", "0")])
            try:
                if connected:
                    self._runloop()
            finally:
                # oc_conn is None when the OUT channel was never reached
                self._terminate_oc_socket()

            self.log_connection_stats()
            self.logger.info("IN: Exiting (2) from do_RPC_IN_DATA")
            yield ""
        else:
            raise Exception("This content-length is not handled")
303
class RPCProxyOutboundChannelHandler(RPCProxyChannelHandler):
    """WSGI handler for the OUT channel (RPC_OUT_DATA) of a virtual connection.

    It connects to OpenChange, listens on a unix socket named after the
    connection cookie for the matching IN channel, hands that channel the
    OpenChange socket, and then streams every packet OpenChange sends as
    the response body.
    """

    def __init__(self, logger):
        RPCProxyChannelHandler.__init__(self, logger)
        self.unix_socket = None
        self.oc_conn = None
        self.in_window_size = 0
        self.in_conn_timeout = 0

    def _receive_conn_a1(self):
        """Read the CONN/A1 RTS PDU from the client and store its cookies.

        Raises Exception if the packet read is not an RTS packet.
        """
        # TODO: validation of CONN/A1
        self.logger.info("OUT: receiving CONN/A1")
        packet = RPCPacket.from_file(self.client_socket, self.logger)
        if not isinstance(packet, RPCRTSPacket):
            raise Exception("Unexpected non-rts packet received for CONN/A1")
        self.connection_cookie = str(UUID(bytes=packet.commands[1]["Cookie"]))
        self.channel_cookie = str(UUID(bytes=packet.commands[2]["Cookie"]))
        # account for the PDU, as the IN channel does for CONN/B1
        self.bytes_read = self.bytes_read + packet.size

    def _send_conn_a3(self):
        """Build the CONN/A3 RTS PDU and return its encoded bytes."""
        self.logger.info("OUT: sending CONN/A3 to client")
        packet = RPCRTSOutPacket(self.logger)
        # we set the min timeout value allowed, as we would actually need
        # either configuration values from Apache or from some config file
        packet.add_command(RTS_CMD_CONNECTION_TIMEOUT, 120000)
        data = packet.make()
        # read the size after make(), as handle_echo_request does
        # (presumably size is only final once the packet is built)
        self.bytes_written = self.bytes_written + packet.size

        return data

    def _send_conn_c2(self):
        """Build the CONN/C2 RTS PDU and return its encoded bytes.

        It carries the version and the receive window size and connection
        timeout received from the IN channel.
        """
        self.logger.info("OUT: sending CONN/C2 to client")
        packet = RPCRTSOutPacket(self.logger)
        packet.add_command(RTS_CMD_VERSION, 1)
        packet.add_command(RTS_CMD_RECEIVE_WINDOW_SIZE, self.in_window_size)
        packet.add_command(RTS_CMD_CONNECTION_TIMEOUT, self.in_conn_timeout)
        data = packet.make()
        # read the size after make(), as handle_echo_request does
        self.bytes_written = self.bytes_written + packet.size

        return data

    def _setup_oc_socket(self):
        """Open a TCP connection to OpenChange (OC_HOST:1024) into oc_conn.

        Retries once per second until the connection succeeds.
        """
        self.logger.info("OUT: connecting to OC_HOST:1024")
        while True:
            oc_conn = socket(AF_INET, SOCK_STREAM)
            try:
                oc_conn.connect((OC_HOST, 1024))
            except socket_error:
                # release the failed socket instead of leaking one
                # descriptor per retry
                oc_conn.close()
                self.logger.info("OUT: failure to connect, retrying...")
                sleep(1)
            else:
                break
        self.logger.info("OUT: connection to OC succeeeded (fileno=%d)"
                         % oc_conn.fileno())
        self.oc_conn = oc_conn

    def _setup_channel_socket(self):
        """Create the listening unix socket the IN channel connects to.

        The socket lives in SOCKETS_DIR (created if missing) and is named
        after the connection cookie; a stale file of that name is removed.
        """
        if not os.path.isdir(SOCKETS_DIR):
            try:
                os.makedirs(SOCKETS_DIR)
            except OSError:
                # another channel may have created it concurrently
                if not os.path.isdir(SOCKETS_DIR):
                    raise
        socket_name = os.path.join(SOCKETS_DIR, self.connection_cookie)
        self.logger.info("OUT: creating unix socket '%s'" % socket_name)
        if os.access(socket_name, os.F_OK):
            os.remove(socket_name)
        sock = socket(AF_UNIX, SOCK_STREAM)
        try:
            sock.bind(socket_name)
            sock.listen(2)
        except socket_error:
            sock.close()
            raise
        self.unix_socket = sock

    def _wait_IN_channel(self):
        """Accept the IN channel and hand it the OpenChange socket.

        Reads INBOUND_PROXY_ID, then the packed ("<ll") receive window
        size and connection timeout, then sends oc_conn via fdunix.
        Raises IOError if the peer does not identify as the IN proxy.
        """
        self.logger.info("OUT: waiting for connection from IN")
        # wait for the IN channel to connect as a B1 should be occurring
        # on the other side
        in_sock = self.unix_socket.accept()[0]
        try:
            data = in_sock.recv(len(INBOUND_PROXY_ID), MSG_WAITALL)
            if data != INBOUND_PROXY_ID:
                raise IOError("connection must be from IN proxy (1): /%s/"
                              % data)

            self.logger.info("OUT: receiving window size + conn_timeout")
            (self.in_window_size, self.in_conn_timeout) = \
                unpack_from("<ll", in_sock.recv(8, MSG_WAITALL))

            self.logger.info("OUT: sending OC socket to IN")
            send_socket(in_sock, self.oc_conn)
        finally:
            in_sock.close()

    def _runloop(self):
        """Yield the data of each packet OpenChange sends until shutdown.

        Polls oc_conn and the listening unix socket. The loop stops when
        OpenChange hangs up (or its socket reports an error), or on any
        event on the unix socket, which the IN channel connects to when it
        shuts down. Raises Exception if OpenChange sends an RTS packet.
        """
        from select import POLLERR, POLLNVAL

        self.logger.info("OUT: runloop")

        unix_fd = self.unix_socket.fileno()
        oc_fd = self.oc_conn.fileno()

        fd_pool = poll()
        fd_pool.register(unix_fd, POLLIN)
        fd_pool.register(oc_fd, POLLIN)

        # poll() reports POLLERR/POLLNVAL even if not requested; without
        # handling them the loop would spin on an immediately-ready event
        closed_mask = POLLHUP | POLLERR | POLLNVAL

        status = True
        while status:
            for event in fd_pool.poll(1000):
                fd, event_no = event
                if fd == oc_fd:
                    if event_no & closed_mask:
                        # FIXME: notify IN channel?
                        self.logger.info("OUT: connection closed from OC")
                        status = False
                    elif event_no & POLLIN:
                        oc_packet = RPCPacket.from_file(self.oc_conn,
                                                        self.logger)
                        self.logger.info("OUT: packet headers = "
                                         + oc_packet.pretty_dump())
                        if isinstance(oc_packet, RPCRTSPacket):
                            raise Exception("Unexpected rts packet received")

                        self.logger.info("OUT: sending data to client")
                        self.bytes_read = self.bytes_read + oc_packet.size
                        self.bytes_written = self.bytes_written + oc_packet.size
                        yield oc_packet.data
                elif fd == unix_fd:
                    self.logger.info("OUT: ignored event '%d' on unix socket"
                                     % event_no)
                    # FIXME: we should listen to what the IN channel has to say
                    status = False
                else:
                    raise Exception("invalid poll event: %s" % str(event))

    def _terminate_sockets(self):
        """Remove the unix socket file and close both sockets.

        Safe to call after a partial setup: sockets that were never
        created are skipped, and repeated calls are harmless.
        """
        if self.unix_socket is not None:
            socket_name = os.path.join(SOCKETS_DIR, self.connection_cookie)
            self.logger.info("OUT: removing and closing unix socket '%s'"
                             % socket_name)
            if os.access(socket_name, os.F_OK):
                os.remove(socket_name)
            self.unix_socket.close()
            self.unix_socket = None
        if self.oc_conn is not None:
            self.oc_conn.close()
            self.oc_conn = None

    def sequence(self, environ, start_response):
        """WSGI body generator for an RPC_OUT_DATA request.

        A Content-Length of at most 0x10 is an echo request. A length of
        76 opens a non-replacement OUT channel: after CONN/A1 the headers
        announce a 1 GiB body, CONN/A3 and CONN/C2 are sent, then packets
        from OpenChange are streamed until the channel closes. Replacement
        OUT channels (120) and any other length raise Exception.
        """
        self.logger.info("OUT: processing request")
        if "REMOTE_PORT" in environ:
            self.logger.info("OUT: remote port = %s" % environ["REMOTE_PORT"])
        content_length = int(environ["CONTENT_LENGTH"])
        self.logger.info("OUT: request size is %d" % content_length)

        if content_length <= 0x10:
            # echo request
            for data in self.handle_echo_request(environ, start_response):
                yield data
        elif content_length == 76:
            self.logger.info("OUT: Processing nonreplacement Out channel"
                             " request")

            self.client_socket = environ["wsgi.input"]
            self._receive_conn_a1()

            # Content-length = 1 GiB
            start_response("200 Success",
                           [("Content-Type", "application/rpc"),
                            ("Content-length", "%d" % (1024 ** 3))])

            yield self._send_conn_a3()
            # finally also runs when the server closes this generator early
            # (GeneratorExit at a yield, e.g. on client disconnect), so the
            # unix socket file and the OpenChange connection are released
            try:
                self._setup_oc_socket()
                self._setup_channel_socket()
                self._wait_IN_channel()

                yield self._send_conn_c2()
                self.logger.info("OUT: total bytes sent yet: %d"
                                 % self.bytes_written)
                for data in self._runloop():
                    yield data
            finally:
                self._terminate_sockets()
        elif content_length == 120:
            # Out channel request: replacement OUT channel
            raise Exception("Replacement OUT channel request not handled")
        else:
            raise Exception("This content-length is not handled")

        self.log_connection_stats()
        self.logger.info("OUT: Exiting from do_RPC_OUT_DATA")