b912c5b1b7c030e30e04b8c0487df65be5e029e0
[bbaumbach/samba.git] / third_party / waf / waflib / extras / prefork.py
1 #! /usr/bin/env python
2 # encoding: utf-8
3 # Thomas Nagy, 2015 (ita)
4
5 """
6 Execute commands through pre-forked servers. This tool creates as many servers as build threads.
7 On a benchmark executed on Linux Kubuntu 14, 8 virtual cores and SSD drive::
8
9     ./genbench.py /tmp/build 200 100 15 5
10     waf clean build -j24
11     # no prefork: 2m7.179s
12     # prefork:    0m55.400s
13
14 To use::
15
16     def options(opt):
17         # optional, will spawn 40 servers early
18         opt.load('prefork')
19
20     def build(bld):
21         bld.load('prefork')
22         ...
23         more code
24
25 The servers and the build process are using a shared nonce to prevent undesirable external connections.
26 """
27
28 import os, re, socket, threading, sys, subprocess, time, atexit, traceback, random, signal
29 try:
30         import SocketServer
31 except ImportError:
32         import socketserver as SocketServer
33 try:
34         from queue import Queue
35 except ImportError:
36         from Queue import Queue
37 try:
38         import cPickle
39 except ImportError:
40         import pickle as cPickle
41
# Secret cookie checked by the server for each request; create_server() fills it in
# from the SHARED_KEY environment variable.
SHARED_KEY = None
# Every message starts with a fixed-size header of this many characters
# (make_header() pads to this length, the server reads exactly this many bytes).
HEADER_SIZE = 64

# Message types carried as the first comma-separated field of a header:
# REQ is sent by exec_command(), RES is sent by req.send_response(),
# BYE makes req.process_command() stop processing the connection.
REQ = 'REQ'
RES = 'RES'
BYE = 'BYE'
48
def make_header(params, cookie=''):
        """
        Build a fixed-size message header.

        The parameters are joined with commas, padded with spaces and followed by
        the cookie so that the result is exactly HEADER_SIZE characters long.

        :param params: list of strings to place at the start of the header
        :param cookie: optional string placed at the very end of the header
        :return: the header; encoded to bytes (iso8859-1) on Python 3
        :raises ValueError: if the parameters and the cookie do not fit in HEADER_SIZE
        """
        header = ','.join(params)
        width = HEADER_SIZE - len(cookie)
        # explicit check instead of an assert: asserts are removed with "python -O",
        # and an oversized header would silently desynchronize the protocol
        if len(header) > width:
                raise ValueError('Header too long (%d > %d): %r' % (len(header), width, header))
        header = header.ljust(width) + cookie
        if sys.hexversion > 0x3000000:
                header = header.encode('iso8859-1')
        return header
57
def safe_compare(x, y):
        """
        Compare two strings character by character without stopping at the
        first difference.

        :param x: first string
        :param y: second string
        :return: True only if both strings have the same length and the same characters
        """
        # zip() stops at the shortest input, so without this check an empty or
        # shorter value would compare equal to anything
        if len(x) != len(y):
                return False
        diff = 0
        for (a, b) in zip(x, y):
                diff |= ord(a) ^ ord(b)
        return diff == 0
63
# characters accepted in the parameter part of a request header
re_valid_query = re.compile('^[a-zA-Z0-9_, ]+$')
class req(SocketServer.StreamRequestHandler):
        """
        Request handler of the command server.

        Each message starts with a header of HEADER_SIZE characters: comma-separated
        parameters padded with spaces, followed by the shared key in the last
        KEY_SIZE characters. A REQ header is followed by a pickled dict describing
        the command to run.
        """

        # number of characters reserved for the shared key at the end of a header
        KEY_SIZE = 20

        def handle(self):
                """
                Process the commands received on the connection until the stream
                ends or an error is raised (the BYE message raises on purpose).
                """
                try:
                        while self.process_command():
                                pass
                except KeyboardInterrupt:
                        return
                except Exception as e:
                        print(e)

        def send_response(self, ret, out, err, exc):
                """
                Send a RES header, followed by the pickled tuple (out, err, exc)
                when at least one of the three values is set.

                :param ret: exit status, sent as text in the header
                :param out: captured standard output or None
                :param err: captured standard error or None
                :param exc: error description or None
                """
                if out or err or exc:
                        data = cPickle.dumps((out, err, exc), -1)
                else:
                        # bytes, so that the payload has the same type in both cases
                        data = b''

                params = [RES, str(ret), str(len(data))]

                # no need for the cookie in the response
                self.wfile.write(make_header(params))
                if data:
                        self.wfile.write(data)
                self.wfile.flush()

        def process_command(self):
                """
                Read one header and execute the corresponding action.

                :return: None when the stream has ended, a true value otherwise
                :raises ValueError: on truncated or invalid headers, and on BYE
                """
                query = self.rfile.read(HEADER_SIZE)
                if not query:
                        return None
                # explicit check instead of an assert, which "python -O" would remove
                if len(query) != HEADER_SIZE:
                        raise ValueError('Truncated header: %d bytes instead of %d' % (len(query), HEADER_SIZE))
                if sys.hexversion > 0x3000000:
                        query = query.decode('iso8859-1')

                # magic cookie
                key = query[-self.KEY_SIZE:]
                if not safe_compare(key, SHARED_KEY):
                        # do not print the expected key: it is the only secret protecting the server
                        print('Invalid key received: %r' % key)
                        self.send_response(-1, '', '', 'Invalid key given!')
                        return 'meh'

                query = query[:-self.KEY_SIZE]
                if not re_valid_query.match(query):
                        self.send_response(-1, '', '', 'Invalid query %r' % query)
                        raise ValueError('Invalid query %r' % query)

                query = query.strip().split(',')

                if query[0] == REQ:
                        self.run_command(query[1:])
                elif query[0] == BYE:
                        raise ValueError('Exit')
                else:
                        raise ValueError('Invalid query %r' % query)
                return 'ok'

        def run_command(self, query):
                """
                Read the pickled command description, execute it with subprocess
                and send the result back.

                :param query: header parameters after the message type; the first one is the payload size
                :raises ValueError: if the payload is shorter than announced
                """
                size = int(query[0])
                data = self.rfile.read(size)
                if len(data) != size:
                        raise ValueError('Truncated payload: %d bytes instead of %d' % (len(data), size))
                # NOTE(review): unpickling data read from a socket is unsafe in general;
                # this is only reached after the shared key check in process_command
                kw = cPickle.loads(data)

                # run command
                ret = out = err = exc = None
                cmd = kw.pop('cmd')

                try:
                        if kw['stdout'] or kw['stderr']:
                                p = subprocess.Popen(cmd, **kw)
                                (out, err) = p.communicate()
                                ret = p.returncode
                        else:
                                ret = subprocess.Popen(cmd, **kw).wait()
                except KeyboardInterrupt:
                        raise
                except Exception as e:
                        # report the failure to the client instead of dropping the connection
                        ret = -1
                        exc = str(e) + traceback.format_exc()

                self.send_response(ret, out, err, exc)
149
def create_server(conn, cls):
        """
        Run a TCP command server until it is interrupted.

        The shared key is read from the SHARED_KEY environment variable and the
        pid of the process to watch from PREFORKPID. The port the server listens
        on is printed on the standard output. A background thread kills this
        process when the watched process is gone.

        :param conn: (host, port) address given to the TCP server
        :param cls: request handler class given to the TCP server
        """
        # child processes do not need the key, so we remove it from the OS environment
        global SHARED_KEY
        SHARED_KEY = os.environ['SHARED_KEY']
        os.environ['SHARED_KEY'] = ''

        ppid = int(os.environ['PREFORKPID'])
        def reap():
                if os.sep != '/':
                        os.waitpid(ppid, 0)
                else:
                        # signal 0 only checks that the process exists
                        while 1:
                                try:
                                        os.kill(ppid, 0)
                                except OSError:
                                        break
                                else:
                                        time.sleep(1)
                # SIGKILL is not defined on all platforms (the os.sep != '/' case above)
                os.kill(os.getpid(), getattr(signal, 'SIGKILL', signal.SIGTERM))
        t = threading.Thread(target=reap)
        # attribute form; Thread.setDaemon() is deprecated
        t.daemon = True
        t.start()

        # use the handler class given by the caller
        server = SocketServer.TCPServer(conn, cls)
        # the process that started the server reads the port from this line
        print(server.server_address[1])
        sys.stdout.flush()
        #server.timeout = 6000 # seconds
        server.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        try:
                server.serve_forever(poll_interval=0.001)
        except KeyboardInterrupt:
                pass
182
if __name__ == '__main__':
        # Executed as a script: this process is a command server.
        # It listens on the loopback interface; the port number is 0 here and
        # create_server() prints the port actually used by the server.
        conn = ("127.0.0.1", 0)
        #print("listening - %r %r\n" % conn)
        create_server(conn, req)
187 else:
188
189         from waflib import Logs, Utils, Runner, Errors, Options
190
191         def init_task_pool(self):
192                 # lazy creation, and set a common pool for all task consumers
193                 pool = self.pool = []
194                 for i in range(self.numjobs):
195                         consumer = Runner.get_pool()
196                         pool.append(consumer)
197                         consumer.idx = i
198                 self.ready = Queue(0)
199                 def setq(consumer):
200                         consumer.ready = self.ready
201                         try:
202                                 threading.current_thread().idx = consumer.idx
203                         except Exception as e:
204                                 print(e)
205                 for x in pool:
206                         x.ready.put(setq)
207                 return pool
208         Runner.Parallel.init_task_pool = init_task_pool
209
210         def make_server(bld, idx):
211                 cmd = [sys.executable, os.path.abspath(__file__)]
212                 proc = subprocess.Popen(cmd, stdout=subprocess.PIPE)
213                 return proc
214
215         def make_conn(bld, srv):
216                 port = srv.port
217                 conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
218                 conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
219                 conn.connect(('127.0.0.1', port))
220                 return conn
221
222
223         SERVERS = []
224         CONNS = []
225         def close_all():
226                 global SERVERS, CONNS
227                 while CONNS:
228                         conn = CONNS.pop()
229                         try:
230                                 conn.close()
231                         except:
232                                 pass
233                 while SERVERS:
234                         srv = SERVERS.pop()
235                         try:
236                                 srv.kill()
237                         except:
238                                 pass
239         atexit.register(close_all)
240
241         def put_data(conn, data):
242                 cnt = 0
243                 while cnt < len(data):
244                         sent = conn.send(data[cnt:])
245                         if sent == 0:
246                                 raise RuntimeError('connection ended')
247                         cnt += sent
248
249         def read_data(conn, siz):
250                 cnt = 0
251                 buf = []
252                 while cnt < siz:
253                         data = conn.recv(min(siz - cnt, 1024))
254                         if not data:
255                                 raise RuntimeError('connection ended %r %r' % (cnt, siz))
256                         buf.append(data)
257                         cnt += len(data)
258                 if sys.hexversion > 0x3000000:
259                         ret = ''.encode('iso8859-1').join(buf)
260                 else:
261                         ret = ''.join(buf)
262                 return ret
263
264         def exec_command(self, cmd, **kw):
265                 if 'stdout' in kw:
266                         if kw['stdout'] not in (None, subprocess.PIPE):
267                                 return self.exec_command_old(cmd, **kw)
268                 elif 'stderr' in kw:
269                         if kw['stderr'] not in (None, subprocess.PIPE):
270                                 return self.exec_command_old(cmd, **kw)
271
272                 kw['shell'] = isinstance(cmd, str)
273                 Logs.debug('runner: %r' % cmd)
274                 Logs.debug('runner_env: kw=%s' % kw)
275
276                 if self.logger:
277                         self.logger.info(cmd)
278
279                 if 'stdout' not in kw:
280                         kw['stdout'] = subprocess.PIPE
281                 if 'stderr' not in kw:
282                         kw['stderr'] = subprocess.PIPE
283
284                 if Logs.verbose and not kw['shell'] and not Utils.check_exe(cmd[0]):
285                         raise Errors.WafError("Program %s not found!" % cmd[0])
286
287                 idx = threading.current_thread().idx
288                 kw['cmd'] = cmd
289
290                 # serialization..
291                 #print("sub %r %r" % (idx, cmd))
292                 #print("write to %r %r" % (idx, cmd))
293
294                 data = cPickle.dumps(kw, -1)
295                 params = [REQ, str(len(data))]
296                 header = make_header(params, self.SHARED_KEY)
297
298                 conn = CONNS[idx]
299
300                 put_data(conn, header + data)
301                 #put_data(conn, data)
302
303                 #print("running %r %r" % (idx, cmd))
304                 #print("read from %r %r" % (idx, cmd))
305
306                 data = read_data(conn, HEADER_SIZE)
307                 if sys.hexversion > 0x3000000:
308                         data = data.decode('iso8859-1')
309
310                 #print("received %r" % data)
311                 lst = data.split(',')
312                 ret = int(lst[1])
313                 dlen = int(lst[2])
314
315                 out = err = None
316                 if dlen:
317                         data = read_data(conn, dlen)
318                         (out, err, exc) = cPickle.loads(data)
319                         if exc:
320                                 raise Errors.WafError('Execution failure: %s' % exc)
321
322                 if out:
323                         if not isinstance(out, str):
324                                 out = out.decode(sys.stdout.encoding or 'iso8859-1')
325                         if self.logger:
326                                 self.logger.debug('out: %s' % out)
327                         else:
328                                 Logs.info(out, extra={'stream':sys.stdout, 'c1': ''})
329                 if err:
330                         if not isinstance(err, str):
331                                 err = err.decode(sys.stdout.encoding or 'iso8859-1')
332                         if self.logger:
333                                 self.logger.error('err: %s' % err)
334                         else:
335                                 Logs.info(err, extra={'stream':sys.stderr, 'c1': ''})
336
337                 return ret
338
339         def init_key(ctx):
340                 try:
341                         key = ctx.SHARED_KEY = os.environ['SHARED_KEY']
342                 except KeyError:
343                         key = "".join([chr(random.SystemRandom().randint(40, 126)) for x in range(20)])
344                         os.environ['SHARED_KEY'] = ctx.SHARED_KEY = key
345
346                 os.environ['PREFORKPID'] = str(os.getpid())
347                 return key
348
349         def init_servers(ctx, maxval):
350                 while len(SERVERS) < maxval:
351                         i = len(SERVERS)
352                         srv = make_server(ctx, i)
353                         SERVERS.append(srv)
354                 while len(CONNS) < maxval:
355                         i = len(CONNS)
356                         srv = SERVERS[i]
357
358                         # postpone the connection
359                         srv.port = int(srv.stdout.readline())
360
361                         conn = None
362                         for x in range(30):
363                                 try:
364                                         conn = make_conn(ctx, srv)
365                                         break
366                                 except socket.error:
367                                         time.sleep(0.01)
368                         if not conn:
369                                 raise ValueError('Could not start the server!')
370                         if srv.poll() is not None:
371                                 Logs.warn('Looks like it it not our server process - concurrent builds are unsupported at this stage')
372                                 raise ValueError('Could not start the server')
373                         CONNS.append(conn)
374
375         def init_smp(self):
376                 if not getattr(Options.options, 'smp', getattr(self, 'smp', None)):
377                         return
378                 if Utils.unversioned_sys_platform() in ('freebsd',):
379                         pid = os.getpid()
380                         cmd = ['cpuset', '-l', '0', '-p', str(pid)]
381                 elif Utils.unversioned_sys_platform() in ('linux',):
382                         pid = os.getpid()
383                         cmd = ['taskset', '-pc', '0', str(pid)]
384                 if cmd:
385                         self.cmd_and_log(cmd, quiet=0)
386
387         def options(opt):
388                 init_key(opt)
389                 init_servers(opt, 40)
390                 opt.add_option('--pin-process', action='store_true', dest='smp', default=False)
391
392         def build(bld):
393                 if bld.cmd == 'clean':
394                         return
395
396                 init_key(bld)
397                 init_servers(bld, bld.jobs)
398                 init_smp(bld)
399
400                 bld.__class__.exec_command_old = bld.__class__.exec_command
401                 bld.__class__.exec_command = exec_command