tools: ynl: move the new line in NlMsg __repr__
[sfrench/cifs-2.6.git] / tools / net / ynl / lib / ynl.py
1 # SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3 from collections import namedtuple
4 import functools
5 import os
6 import random
7 import socket
8 import struct
9 from struct import Struct
10 import yaml
11 import ipaddress
12 import uuid
13
14 from .nlspec import SpecFamily
15
16 #
17 # Generic Netlink code which should really be in some library, but I can't quickly find one.
18 #
19
20
21 class Netlink:
22     # Netlink socket
23     SOL_NETLINK = 270
24
25     NETLINK_ADD_MEMBERSHIP = 1
26     NETLINK_CAP_ACK = 10
27     NETLINK_EXT_ACK = 11
28     NETLINK_GET_STRICT_CHK = 12
29
30     # Netlink message
31     NLMSG_ERROR = 2
32     NLMSG_DONE = 3
33
34     NLM_F_REQUEST = 1
35     NLM_F_ACK = 4
36     NLM_F_ROOT = 0x100
37     NLM_F_MATCH = 0x200
38
39     NLM_F_REPLACE = 0x100
40     NLM_F_EXCL = 0x200
41     NLM_F_CREATE = 0x400
42     NLM_F_APPEND = 0x800
43
44     NLM_F_CAPPED = 0x100
45     NLM_F_ACK_TLVS = 0x200
46
47     NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH
48
49     NLA_F_NESTED = 0x8000
50     NLA_F_NET_BYTEORDER = 0x4000
51
52     NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER
53
54     # Genetlink defines
55     NETLINK_GENERIC = 16
56
57     GENL_ID_CTRL = 0x10
58
59     # nlctrl
60     CTRL_CMD_GETFAMILY = 3
61
62     CTRL_ATTR_FAMILY_ID = 1
63     CTRL_ATTR_FAMILY_NAME = 2
64     CTRL_ATTR_MAXATTR = 5
65     CTRL_ATTR_MCAST_GROUPS = 7
66
67     CTRL_ATTR_MCAST_GRP_NAME = 1
68     CTRL_ATTR_MCAST_GRP_ID = 2
69
70     # Extack types
71     NLMSGERR_ATTR_MSG = 1
72     NLMSGERR_ATTR_OFFS = 2
73     NLMSGERR_ATTR_COOKIE = 3
74     NLMSGERR_ATTR_POLICY = 4
75     NLMSGERR_ATTR_MISS_TYPE = 5
76     NLMSGERR_ATTR_MISS_NEST = 6
77
78
79 class NlError(Exception):
80   def __init__(self, nl_msg):
81     self.nl_msg = nl_msg
82
83   def __str__(self):
84     return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"
85
86
87 class NlAttr:
88     ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
89     type_formats = {
90         'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
91         's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
92         'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
93         's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
94         'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
95         's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
96         'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
97         's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
98     }
99
100     def __init__(self, raw, offset):
101         self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
102         self.type = self._type & ~Netlink.NLA_TYPE_MASK
103         self.is_nest = self._type & Netlink.NLA_F_NESTED
104         self.payload_len = self._len
105         self.full_len = (self.payload_len + 3) & ~3
106         self.raw = raw[offset + 4 : offset + self.payload_len]
107
108     @classmethod
109     def get_format(cls, attr_type, byte_order=None):
110         format = cls.type_formats[attr_type]
111         if byte_order:
112             return format.big if byte_order == "big-endian" \
113                 else format.little
114         return format.native
115
116     def as_scalar(self, attr_type, byte_order=None):
117         format = self.get_format(attr_type, byte_order)
118         return format.unpack(self.raw)[0]
119
120     def as_auto_scalar(self, attr_type, byte_order=None):
121         if len(self.raw) != 4 and len(self.raw) != 8:
122             raise Exception(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}")
123         real_type = attr_type[0] + str(len(self.raw) * 8)
124         format = self.get_format(real_type, byte_order)
125         return format.unpack(self.raw)[0]
126
127     def as_strz(self):
128         return self.raw.decode('ascii')[:-1]
129
130     def as_bin(self):
131         return self.raw
132
133     def as_c_array(self, type):
134         format = self.get_format(type)
135         return [ x[0] for x in format.iter_unpack(self.raw) ]
136
137     def __repr__(self):
138         return f"[type:{self.type} len:{self._len}] {self.raw}"
139
140
141 class NlAttrs:
142     def __init__(self, msg, offset=0):
143         self.attrs = []
144
145         while offset < len(msg):
146             attr = NlAttr(msg, offset)
147             offset += attr.full_len
148             self.attrs.append(attr)
149
150     def __iter__(self):
151         yield from self.attrs
152
153     def __repr__(self):
154         msg = ''
155         for a in self.attrs:
156             if msg:
157                 msg += '\n'
158             msg += repr(a)
159         return msg
160
161
162 class NlMsg:
163     def __init__(self, msg, offset, attr_space=None):
164         self.hdr = msg[offset : offset + 16]
165
166         self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
167             struct.unpack("IHHII", self.hdr)
168
169         self.raw = msg[offset + 16 : offset + self.nl_len]
170
171         self.error = 0
172         self.done = 0
173
174         extack_off = None
175         if self.nl_type == Netlink.NLMSG_ERROR:
176             self.error = struct.unpack("i", self.raw[0:4])[0]
177             self.done = 1
178             extack_off = 20
179         elif self.nl_type == Netlink.NLMSG_DONE:
180             self.done = 1
181             extack_off = 4
182
183         self.extack = None
184         if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
185             self.extack = dict()
186             extack_attrs = NlAttrs(self.raw[extack_off:])
187             for extack in extack_attrs:
188                 if extack.type == Netlink.NLMSGERR_ATTR_MSG:
189                     self.extack['msg'] = extack.as_strz()
190                 elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
191                     self.extack['miss-type'] = extack.as_scalar('u32')
192                 elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
193                     self.extack['miss-nest'] = extack.as_scalar('u32')
194                 elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
195                     self.extack['bad-attr-offs'] = extack.as_scalar('u32')
196                 else:
197                     if 'unknown' not in self.extack:
198                         self.extack['unknown'] = []
199                     self.extack['unknown'].append(extack)
200
201             if attr_space:
202                 # We don't have the ability to parse nests yet, so only do global
203                 if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
204                     miss_type = self.extack['miss-type']
205                     if miss_type in attr_space.attrs_by_val:
206                         spec = attr_space.attrs_by_val[miss_type]
207                         desc = spec['name']
208                         if 'doc' in spec:
209                             desc += f" ({spec['doc']})"
210                         self.extack['miss-type'] = desc
211
212     def cmd(self):
213         return self.nl_type
214
215     def __repr__(self):
216         msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
217         if self.error:
218             msg += '\n\terror: ' + str(self.error)
219         if self.extack:
220             msg += '\n\textack: ' + repr(self.extack)
221         return msg
222
223
224 class NlMsgs:
225     def __init__(self, data, attr_space=None):
226         self.msgs = []
227
228         offset = 0
229         while offset < len(data):
230             msg = NlMsg(data, offset, attr_space=attr_space)
231             offset += msg.nl_len
232             self.msgs.append(msg)
233
234     def __iter__(self):
235         yield from self.msgs
236
237
238 genl_family_name_to_id = None
239
240
241 def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
242     # we prepend length in _genl_msg_finalize()
243     if seq is None:
244         seq = random.randint(1, 1024)
245     nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
246     genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
247     return nlmsg + genlmsg
248
249
250 def _genl_msg_finalize(msg):
251     return struct.pack("I", len(msg) + 4) + msg
252
253
254 def _genl_load_families():
255     with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
256         sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
257
258         msg = _genl_msg(Netlink.GENL_ID_CTRL,
259                         Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
260                         Netlink.CTRL_CMD_GETFAMILY, 1)
261         msg = _genl_msg_finalize(msg)
262
263         sock.send(msg, 0)
264
265         global genl_family_name_to_id
266         genl_family_name_to_id = dict()
267
268         while True:
269             reply = sock.recv(128 * 1024)
270             nms = NlMsgs(reply)
271             for nl_msg in nms:
272                 if nl_msg.error:
273                     print("Netlink error:", nl_msg.error)
274                     return
275                 if nl_msg.done:
276                     return
277
278                 gm = GenlMsg(nl_msg)
279                 fam = dict()
280                 for attr in NlAttrs(gm.raw):
281                     if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
282                         fam['id'] = attr.as_scalar('u16')
283                     elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
284                         fam['name'] = attr.as_strz()
285                     elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
286                         fam['maxattr'] = attr.as_scalar('u32')
287                     elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
288                         fam['mcast'] = dict()
289                         for entry in NlAttrs(attr.raw):
290                             mcast_name = None
291                             mcast_id = None
292                             for entry_attr in NlAttrs(entry.raw):
293                                 if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
294                                     mcast_name = entry_attr.as_strz()
295                                 elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
296                                     mcast_id = entry_attr.as_scalar('u32')
297                             if mcast_name and mcast_id is not None:
298                                 fam['mcast'][mcast_name] = mcast_id
299                 if 'name' in fam and 'id' in fam:
300                     genl_family_name_to_id[fam['name']] = fam
301
302
303 class GenlMsg:
304     def __init__(self, nl_msg):
305         self.nl = nl_msg
306         self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
307         self.raw = nl_msg.raw[4:]
308
309     def cmd(self):
310         return self.genl_cmd
311
312     def __repr__(self):
313         msg = repr(self.nl)
314         msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
315         for a in self.raw_attrs:
316             msg += '\t\t' + repr(a) + '\n'
317         return msg
318
319
320 class NetlinkProtocol:
321     def __init__(self, family_name, proto_num):
322         self.family_name = family_name
323         self.proto_num = proto_num
324
325     def _message(self, nl_type, nl_flags, seq=None):
326         if seq is None:
327             seq = random.randint(1, 1024)
328         nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
329         return nlmsg
330
331     def message(self, flags, command, version, seq=None):
332         return self._message(command, flags, seq)
333
334     def _decode(self, nl_msg):
335         return nl_msg
336
337     def decode(self, ynl, nl_msg):
338         msg = self._decode(nl_msg)
339         fixed_header_size = 0
340         if ynl:
341             op = ynl.rsp_by_value[msg.cmd()]
342             fixed_header_size = ynl._struct_size(op.fixed_header)
343         msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
344         return msg
345
346     def get_mcast_id(self, mcast_name, mcast_groups):
347         if mcast_name not in mcast_groups:
348             raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
349         return mcast_groups[mcast_name].value
350
351
352 class GenlProtocol(NetlinkProtocol):
353     def __init__(self, family_name):
354         super().__init__(family_name, Netlink.NETLINK_GENERIC)
355
356         global genl_family_name_to_id
357         if genl_family_name_to_id is None:
358             _genl_load_families()
359
360         self.genl_family = genl_family_name_to_id[family_name]
361         self.family_id = genl_family_name_to_id[family_name]['id']
362
363     def message(self, flags, command, version, seq=None):
364         nlmsg = self._message(self.family_id, flags, seq)
365         genlmsg = struct.pack("BBH", command, version, 0)
366         return nlmsg + genlmsg
367
368     def _decode(self, nl_msg):
369         return GenlMsg(nl_msg)
370
371     def get_mcast_id(self, mcast_name, mcast_groups):
372         if mcast_name not in self.genl_family['mcast']:
373             raise Exception(f'Multicast group "{mcast_name}" not present in the family')
374         return self.genl_family['mcast'][mcast_name]
375
376
377
378 class SpaceAttrs:
379     SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])
380
381     def __init__(self, attr_space, attrs, outer = None):
382         outer_scopes = outer.scopes if outer else []
383         inner_scope = self.SpecValuesPair(attr_space, attrs)
384         self.scopes = [inner_scope] + outer_scopes
385
386     def lookup(self, name):
387         for scope in self.scopes:
388             if name in scope.spec:
389                 if name in scope.values:
390                     return scope.values[name]
391                 spec_name = scope.spec.yaml['name']
392                 raise Exception(
393                     f"No value for '{name}' in attribute space '{spec_name}'")
394         raise Exception(f"Attribute '{name}' not defined in any attribute-set")
395
396
397 #
398 # YNL implementation details.
399 #
400
401
402 class YnlFamily(SpecFamily):
403     def __init__(self, def_path, schema=None, process_unknown=False):
404         super().__init__(def_path, schema)
405
406         self.include_raw = False
407         self.process_unknown = process_unknown
408
409         try:
410             if self.proto == "netlink-raw":
411                 self.nlproto = NetlinkProtocol(self.yaml['name'],
412                                                self.yaml['protonum'])
413             else:
414                 self.nlproto = GenlProtocol(self.yaml['name'])
415         except KeyError:
416             raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
417
418         self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
419         self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
420         self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
421         self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
422
423         self.async_msg_ids = set()
424         self.async_msg_queue = []
425
426         for msg in self.msgs.values():
427             if msg.is_async:
428                 self.async_msg_ids.add(msg.rsp_value)
429
430         for op_name, op in self.ops.items():
431             bound_f = functools.partial(self._op, op_name)
432             setattr(self, op.ident_name, bound_f)
433
434
435     def ntf_subscribe(self, mcast_name):
436         mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
437         self.sock.bind((0, 0))
438         self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
439                              mcast_id)
440
441     def _encode_enum(self, attr_spec, value):
442         enum = self.consts[attr_spec['enum']]
443         if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
444             scalar = 0
445             if isinstance(value, str):
446                 value = [value]
447             for single_value in value:
448                 scalar += enum.entries[single_value].user_value(as_flags = True)
449             return scalar
450         else:
451             return enum.entries[value].user_value()
452
453     def _get_scalar(self, attr_spec, value):
454         try:
455             return int(value)
456         except (ValueError, TypeError) as e:
457             if 'enum' not in attr_spec:
458                 raise e
459         return self._encode_enum(attr_spec, value);
460
461     def _add_attr(self, space, name, value, search_attrs):
462         try:
463             attr = self.attr_sets[space][name]
464         except KeyError:
465             raise Exception(f"Space '{space}' has no attribute '{name}'")
466         nl_type = attr.value
467
468         if attr.is_multi and isinstance(value, list):
469             attr_payload = b''
470             for subvalue in value:
471                 attr_payload += self._add_attr(space, name, subvalue, search_attrs)
472             return attr_payload
473
474         if attr["type"] == 'nest':
475             nl_type |= Netlink.NLA_F_NESTED
476             attr_payload = b''
477             sub_attrs = SpaceAttrs(self.attr_sets[space], value, search_attrs)
478             for subname, subvalue in value.items():
479                 attr_payload += self._add_attr(attr['nested-attributes'],
480                                                subname, subvalue, sub_attrs)
481         elif attr["type"] == 'flag':
482             if not value:
483                 # If value is absent or false then skip attribute creation.
484                 return b''
485             attr_payload = b''
486         elif attr["type"] == 'string':
487             attr_payload = str(value).encode('ascii') + b'\x00'
488         elif attr["type"] == 'binary':
489             if isinstance(value, bytes):
490                 attr_payload = value
491             elif isinstance(value, str):
492                 attr_payload = bytes.fromhex(value)
493             elif isinstance(value, dict) and attr.struct_name:
494                 attr_payload = self._encode_struct(attr.struct_name, value)
495             else:
496                 raise Exception(f'Unknown type for binary attribute, value: {value}')
497         elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
498             scalar = self._get_scalar(attr, value)
499             if attr.is_auto_scalar:
500                 attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
501             else:
502                 attr_type = attr["type"]
503             format = NlAttr.get_format(attr_type, attr.byte_order)
504             attr_payload = format.pack(scalar)
505         elif attr['type'] in "bitfield32":
506             scalar_value = self._get_scalar(attr, value["value"])
507             scalar_selector = self._get_scalar(attr, value["selector"])
508             attr_payload = struct.pack("II", scalar_value, scalar_selector)
509         elif attr['type'] == 'sub-message':
510             msg_format = self._resolve_selector(attr, search_attrs)
511             attr_payload = b''
512             if msg_format.fixed_header:
513                 attr_payload += self._encode_struct(msg_format.fixed_header, value)
514             if msg_format.attr_set:
515                 if msg_format.attr_set in self.attr_sets:
516                     nl_type |= Netlink.NLA_F_NESTED
517                     sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
518                     for subname, subvalue in value.items():
519                         attr_payload += self._add_attr(msg_format.attr_set,
520                                                        subname, subvalue, sub_attrs)
521                 else:
522                     raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
523         else:
524             raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
525
526         pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
527         return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
528
529     def _decode_enum(self, raw, attr_spec):
530         enum = self.consts[attr_spec['enum']]
531         if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
532             i = 0
533             value = set()
534             while raw:
535                 if raw & 1:
536                     value.add(enum.entries_by_val[i].name)
537                 raw >>= 1
538                 i += 1
539         else:
540             value = enum.entries_by_val[raw].name
541         return value
542
543     def _decode_binary(self, attr, attr_spec):
544         if attr_spec.struct_name:
545             decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
546         elif attr_spec.sub_type:
547             decoded = attr.as_c_array(attr_spec.sub_type)
548         else:
549             decoded = attr.as_bin()
550             if attr_spec.display_hint:
551                 decoded = self._formatted_string(decoded, attr_spec.display_hint)
552         return decoded
553
554     def _decode_array_nest(self, attr, attr_spec):
555         decoded = []
556         offset = 0
557         while offset < len(attr.raw):
558             item = NlAttr(attr.raw, offset)
559             offset += item.full_len
560
561             subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
562             decoded.append({ item.type: subattrs })
563         return decoded
564
565     def _decode_unknown(self, attr):
566         if attr.is_nest:
567             return self._decode(NlAttrs(attr.raw), None)
568         else:
569             return attr.as_bin()
570
571     def _rsp_add(self, rsp, name, is_multi, decoded):
572         if is_multi == None:
573             if name in rsp and type(rsp[name]) is not list:
574                 rsp[name] = [rsp[name]]
575                 is_multi = True
576             else:
577                 is_multi = False
578
579         if not is_multi:
580             rsp[name] = decoded
581         elif name in rsp:
582             rsp[name].append(decoded)
583         else:
584             rsp[name] = [decoded]
585
586     def _resolve_selector(self, attr_spec, search_attrs):
587         sub_msg = attr_spec.sub_message
588         if sub_msg not in self.sub_msgs:
589             raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
590         sub_msg_spec = self.sub_msgs[sub_msg]
591
592         selector = attr_spec.selector
593         value = search_attrs.lookup(selector)
594         if value not in sub_msg_spec.formats:
595             raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
596
597         spec = sub_msg_spec.formats[value]
598         return spec
599
600     def _decode_sub_msg(self, attr, attr_spec, search_attrs):
601         msg_format = self._resolve_selector(attr_spec, search_attrs)
602         decoded = {}
603         offset = 0
604         if msg_format.fixed_header:
605             decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
606             offset = self._struct_size(msg_format.fixed_header)
607         if msg_format.attr_set:
608             if msg_format.attr_set in self.attr_sets:
609                 subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
610                 decoded.update(subdict)
611             else:
612                 raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
613         return decoded
614
615     def _decode(self, attrs, space, outer_attrs = None):
616         rsp = dict()
617         if space:
618             attr_space = self.attr_sets[space]
619             search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
620
621         for attr in attrs:
622             try:
623                 attr_spec = attr_space.attrs_by_val[attr.type]
624             except (KeyError, UnboundLocalError):
625                 if not self.process_unknown:
626                     raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
627                 attr_name = f"UnknownAttr({attr.type})"
628                 self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
629                 continue
630
631             if attr_spec["type"] == 'nest':
632                 subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
633                 decoded = subdict
634             elif attr_spec["type"] == 'string':
635                 decoded = attr.as_strz()
636             elif attr_spec["type"] == 'binary':
637                 decoded = self._decode_binary(attr, attr_spec)
638             elif attr_spec["type"] == 'flag':
639                 decoded = True
640             elif attr_spec.is_auto_scalar:
641                 decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
642             elif attr_spec["type"] in NlAttr.type_formats:
643                 decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
644                 if 'enum' in attr_spec:
645                     decoded = self._decode_enum(decoded, attr_spec)
646             elif attr_spec["type"] == 'array-nest':
647                 decoded = self._decode_array_nest(attr, attr_spec)
648             elif attr_spec["type"] == 'bitfield32':
649                 value, selector = struct.unpack("II", attr.raw)
650                 if 'enum' in attr_spec:
651                     value = self._decode_enum(value, attr_spec)
652                     selector = self._decode_enum(selector, attr_spec)
653                 decoded = {"value": value, "selector": selector}
654             elif attr_spec["type"] == 'sub-message':
655                 decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
656             else:
657                 if not self.process_unknown:
658                     raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
659                 decoded = self._decode_unknown(attr)
660
661             self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
662
663         return rsp
664
665     def _decode_extack_path(self, attrs, attr_set, offset, target):
666         for attr in attrs:
667             try:
668                 attr_spec = attr_set.attrs_by_val[attr.type]
669             except KeyError:
670                 raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
671             if offset > target:
672                 break
673             if offset == target:
674                 return '.' + attr_spec.name
675
676             if offset + attr.full_len <= target:
677                 offset += attr.full_len
678                 continue
679             if attr_spec['type'] != 'nest':
680                 raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
681             offset += 4
682             subpath = self._decode_extack_path(NlAttrs(attr.raw),
683                                                self.attr_sets[attr_spec['nested-attributes']],
684                                                offset, target)
685             if subpath is None:
686                 return None
687             return '.' + attr_spec.name + subpath
688
689         return None
690
691     def _decode_extack(self, request, op, extack):
692         if 'bad-attr-offs' not in extack:
693             return
694
695         msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set))
696         offset = 20 + self._struct_size(op.fixed_header)
697         path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
698                                         extack['bad-attr-offs'])
699         if path:
700             del extack['bad-attr-offs']
701             extack['bad-attr'] = path
702
703     def _struct_size(self, name):
704         if name:
705             members = self.consts[name].members
706             size = 0
707             for m in members:
708                 if m.type in ['pad', 'binary']:
709                     if m.struct:
710                         size += self._struct_size(m.struct)
711                     else:
712                         size += m.len
713                 else:
714                     format = NlAttr.get_format(m.type, m.byte_order)
715                     size += format.size
716             return size
717         else:
718             return 0
719
720     def _decode_struct(self, data, name):
721         members = self.consts[name].members
722         attrs = dict()
723         offset = 0
724         for m in members:
725             value = None
726             if m.type == 'pad':
727                 offset += m.len
728             elif m.type == 'binary':
729                 if m.struct:
730                     len = self._struct_size(m.struct)
731                     value = self._decode_struct(data[offset : offset + len],
732                                                 m.struct)
733                     offset += len
734                 else:
735                     value = data[offset : offset + m.len]
736                     offset += m.len
737             else:
738                 format = NlAttr.get_format(m.type, m.byte_order)
739                 [ value ] = format.unpack_from(data, offset)
740                 offset += format.size
741             if value is not None:
742                 if m.enum:
743                     value = self._decode_enum(value, m)
744                 elif m.display_hint:
745                     value = self._formatted_string(value, m.display_hint)
746                 attrs[m.name] = value
747         return attrs
748
749     def _encode_struct(self, name, vals):
750         members = self.consts[name].members
751         attr_payload = b''
752         for m in members:
753             value = vals.pop(m.name) if m.name in vals else None
754             if m.type == 'pad':
755                 attr_payload += bytearray(m.len)
756             elif m.type == 'binary':
757                 if m.struct:
758                     if value is None:
759                         value = dict()
760                     attr_payload += self._encode_struct(m.struct, value)
761                 else:
762                     if value is None:
763                         attr_payload += bytearray(m.len)
764                     else:
765                         attr_payload += bytes.fromhex(value)
766             else:
767                 if value is None:
768                     value = 0
769                 format = NlAttr.get_format(m.type, m.byte_order)
770                 attr_payload += format.pack(value)
771         return attr_payload
772
773     def _formatted_string(self, raw, display_hint):
774         if display_hint == 'mac':
775             formatted = ':'.join('%02x' % b for b in raw)
776         elif display_hint == 'hex':
777             formatted = bytes.hex(raw, ' ')
778         elif display_hint in [ 'ipv4', 'ipv6' ]:
779             formatted = format(ipaddress.ip_address(raw))
780         elif display_hint == 'uuid':
781             formatted = str(uuid.UUID(bytes=raw))
782         else:
783             formatted = raw
784         return formatted
785
786     def handle_ntf(self, decoded):
787         msg = dict()
788         if self.include_raw:
789             msg['raw'] = decoded
790         op = self.rsp_by_value[decoded.cmd()]
791         attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
792         if op.fixed_header:
793             attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
794
795         msg['name'] = op['name']
796         msg['msg'] = attrs
797         self.async_msg_queue.append(msg)
798
799     def check_ntf(self):
800         while True:
801             try:
802                 reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
803             except BlockingIOError:
804                 return
805
806             nms = NlMsgs(reply)
807             for nl_msg in nms:
808                 if nl_msg.error:
809                     print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
810                     print(nl_msg)
811                     continue
812                 if nl_msg.done:
813                     print("Netlink done while checking for ntf!?")
814                     continue
815
816                 decoded = self.nlproto.decode(self, nl_msg)
817                 if decoded.cmd() not in self.async_msg_ids:
818                     print("Unexpected msg id done while checking for ntf", decoded)
819                     continue
820
821                 self.handle_ntf(decoded)
822
823     def operation_do_attributes(self, name):
824       """
825       For a given operation name, find and return a supported
826       set of attributes (as a dict).
827       """
828       op = self.find_operation(name)
829       if not op:
830         return None
831
832       return op['do']['request']['attributes'].copy()
833
834     def _op(self, method, vals, flags=None, dump=False):
835         op = self.ops[method]
836
837         nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
838         for flag in flags or []:
839             nl_flags |= flag
840         if dump:
841             nl_flags |= Netlink.NLM_F_DUMP
842
843         req_seq = random.randint(1024, 65535)
844         msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
845         if op.fixed_header:
846             msg += self._encode_struct(op.fixed_header, vals)
847         search_attrs = SpaceAttrs(op.attr_set, vals)
848         for name, value in vals.items():
849             msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
850         msg = _genl_msg_finalize(msg)
851
852         self.sock.send(msg, 0)
853
854         done = False
855         rsp = []
856         while not done:
857             reply = self.sock.recv(128 * 1024)
858             nms = NlMsgs(reply, attr_space=op.attr_set)
859             for nl_msg in nms:
860                 if nl_msg.extack:
861                     self._decode_extack(msg, op, nl_msg.extack)
862
863                 if nl_msg.error:
864                     raise NlError(nl_msg)
865                 if nl_msg.done:
866                     if nl_msg.extack:
867                         print("Netlink warning:")
868                         print(nl_msg)
869                     done = True
870                     break
871
872                 decoded = self.nlproto.decode(self, nl_msg)
873
874                 # Check if this is a reply to our request
875                 if nl_msg.nl_seq != req_seq or decoded.cmd() != op.rsp_value:
876                     if decoded.cmd() in self.async_msg_ids:
877                         self.handle_ntf(decoded)
878                         continue
879                     else:
880                         print('Unexpected message: ' + repr(decoded))
881                         continue
882
883                 rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
884                 if op.fixed_header:
885                     rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
886                 rsp.append(rsp_msg)
887
888         if not rsp:
889             return None
890         if not dump and len(rsp) == 1:
891             return rsp[0]
892         return rsp
893
894     def do(self, method, vals, flags=None):
895         return self._op(method, vals, flags)
896
897     def dump(self, method, vals):
898         return self._op(method, vals, [], dump=True)