python:gkdi: Add Gkdi.from_key_envelope() method
[gd/samba/.git] / python / samba / gkdi.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Catalyst.Net Ltd 2023
3 #
4 #
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
17 #
18
19 """Group Key Distribution Service module"""
20
21 from enum import Enum
22 from functools import total_ordering
23 from typing import Optional, Tuple
24
25 from cryptography.hazmat.primitives import hashes
26
27 from samba import _glue
28 from samba.dcerpc import gkdi, misc
29 from samba.ndr import ndr_pack, ndr_unpack
30 from samba.nt_time import NtTime, NtTimeDelta
31
32
33 uint64_max: int = 2**64 - 1
34
35 L1_KEY_ITERATION: int = _glue.GKDI_L1_KEY_ITERATION
36 L2_KEY_ITERATION: int = _glue.GKDI_L2_KEY_ITERATION
37 KEY_CYCLE_DURATION: NtTimeDelta = _glue.GKDI_KEY_CYCLE_DURATION
38 MAX_CLOCK_SKEW: NtTimeDelta = _glue.GKDI_MAX_CLOCK_SKEW
39
40 KEY_LEN_BYTES = 64
41
42
43 class Algorithm(Enum):
44     SHA1 = "SHA1"
45     SHA256 = "SHA256"
46     SHA384 = "SHA384"
47     SHA512 = "SHA512"
48
49     def algorithm(self) -> hashes.HashAlgorithm:
50         if self is Algorithm.SHA1:
51             return hashes.SHA1()
52
53         if self is Algorithm.SHA256:
54             return hashes.SHA256()
55
56         if self is Algorithm.SHA384:
57             return hashes.SHA384()
58
59         if self is Algorithm.SHA512:
60             return hashes.SHA512()
61
62         raise RuntimeError("unknown hash algorithm {self}")
63
64     def __repr__(self) -> str:
65         return str(self)
66
67     @staticmethod
68     def from_kdf_parameters(kdf_param: Optional[bytes]) -> "Algorithm":
69         if not kdf_param:
70             return Algorithm.SHA256  # the default used by Windows.
71
72         kdf_parameters = ndr_unpack(gkdi.KdfParameters, kdf_param)
73         return Algorithm(kdf_parameters.hash_algorithm)
74
75
76 class GkidType(Enum):
77     DEFAULT = object()
78     L0_SEED_KEY = object()
79     L1_SEED_KEY = object()
80     L2_SEED_KEY = object()
81
82     def description(self) -> str:
83         if self is GkidType.DEFAULT:
84             return "a default GKID"
85
86         if self is GkidType.L0_SEED_KEY:
87             return "an L0 seed key"
88
89         if self is GkidType.L1_SEED_KEY:
90             return "an L1 seed key"
91
92         if self is GkidType.L2_SEED_KEY:
93             return "an L2 seed key"
94
95         raise RuntimeError("unknown GKID type {self}")
96
97
98 class InvalidDerivation(Exception):
99     pass
100
101
102 class UndefinedStartTime(Exception):
103     pass
104
105
106 @total_ordering
107 class Gkid:
108     # L2 increments every 10 hours. It rolls over after 320 hours (13 days and 8 hours).
109     # L1 increments every 320 hours. It rolls over after 10240 hours (426 days and 16 hours).
110     # L0 increments every 10240 hours. It rolls over after 43980465111040 hours (five billion years).
111
112     __slots__ = ["_l0_idx", "_l1_idx", "_l2_idx"]
113
114     max_l0_idx = 0x7FFF_FFFF
115
116     def __init__(self, l0_idx: int, l1_idx: int, l2_idx: int) -> None:
117         if not -1 <= l0_idx <= Gkid.max_l0_idx:
118             raise ValueError(f"L0 index {l0_idx} out of range")
119
120         if not -1 <= l1_idx < L1_KEY_ITERATION:
121             raise ValueError(f"L1 index {l1_idx} out of range")
122
123         if not -1 <= l2_idx < L2_KEY_ITERATION:
124             raise ValueError(f"L2 index {l2_idx} out of range")
125
126         if l0_idx == -1 and l1_idx != -1:
127             raise ValueError("invalid combination of negative and non‐negative indices")
128
129         if l1_idx == -1 and l2_idx != -1:
130             raise ValueError("invalid combination of negative and non‐negative indices")
131
132         self._l0_idx = l0_idx
133         self._l1_idx = l1_idx
134         self._l2_idx = l2_idx
135
136     @property
137     def l0_idx(self) -> int:
138         return self._l0_idx
139
140     @property
141     def l1_idx(self) -> int:
142         return self._l1_idx
143
144     @property
145     def l2_idx(self) -> int:
146         return self._l2_idx
147
148     def gkid_type(self) -> GkidType:
149         if self.l0_idx == -1:
150             return GkidType.DEFAULT
151
152         if self.l1_idx == -1:
153             return GkidType.L0_SEED_KEY
154
155         if self.l2_idx == -1:
156             return GkidType.L1_SEED_KEY
157
158         return GkidType.L2_SEED_KEY
159
160     def wrapped_l1_idx(self) -> int:
161         if self.l1_idx == -1:
162             return L1_KEY_ITERATION
163
164         return self.l1_idx
165
166     def wrapped_l2_idx(self) -> int:
167         if self.l2_idx == -1:
168             return L2_KEY_ITERATION
169
170         return self.l2_idx
171
172     def derive_l1_seed_key(self) -> "Gkid":
173         gkid_type = self.gkid_type()
174         if (
175             gkid_type is not GkidType.L0_SEED_KEY
176             and gkid_type is not GkidType.L1_SEED_KEY
177         ):
178             raise InvalidDerivation(
179                 "Invalid attempt to derive an L1 seed key from"
180                 f" {gkid_type.description()}"
181             )
182
183         if self.l1_idx == 0:
184             raise InvalidDerivation("No further derivation of L1 seed keys is possible")
185
186         return Gkid(self.l0_idx, self.wrapped_l1_idx() - 1, self.l2_idx)
187
188     def derive_l2_seed_key(self) -> "Gkid":
189         gkid_type = self.gkid_type()
190         if (
191             gkid_type is not GkidType.L1_SEED_KEY
192             and gkid_type is not GkidType.L2_SEED_KEY
193         ):
194             raise InvalidDerivation(
195                 f"Attempt to derive an L2 seed key from {gkid_type.description()}"
196             )
197
198         if self.l2_idx == 0:
199             raise InvalidDerivation("No further derivation of L2 seed keys is possible")
200
201         return Gkid(self.l0_idx, self.l1_idx, self.wrapped_l2_idx() - 1)
202
203     def __str__(self) -> str:
204         return f"Gkid({self.l0_idx}, {self.l1_idx}, {self.l2_idx})"
205
206     def __repr__(self) -> str:
207         cls = type(self)
208         return (
209             f"{cls.__qualname__}({repr(self.l0_idx)}, {repr(self.l1_idx)},"
210             f" {repr(self.l2_idx)})"
211         )
212
213     def __eq__(self, other: object) -> bool:
214         if not isinstance(other, Gkid):
215             return NotImplemented
216
217         return (self.l0_idx, self.l1_idx, self.l2_idx) == (
218             other.l0_idx,
219             other.l1_idx,
220             other.l2_idx,
221         )
222
223     def __lt__(self, other: object) -> bool:
224         if not isinstance(other, Gkid):
225             return NotImplemented
226
227         def as_tuple(gkid: Gkid) -> Tuple[int, int, int]:
228             l0_idx, l1_idx, l2_idx = gkid.l0_idx, gkid.l1_idx, gkid.l2_idx
229
230             # DEFAULT is considered less than everything else, so that the
231             # lexical ordering requirement in [MS-GKDI] 3.1.4.1.3 (GetKey) makes
232             # sense.
233             if gkid.gkid_type() is not GkidType.DEFAULT:
234                 # Use the wrapped indices so that L1 seed keys are considered
235                 # greater than their children L2 seed keys, and L0 seed keys are
236                 # considered greater than their children L1 seed keys.
237                 l1_idx = gkid.wrapped_l1_idx()
238                 l2_idx = gkid.wrapped_l2_idx()
239
240             return l0_idx, l1_idx, l2_idx
241
242         return as_tuple(self) < as_tuple(other)
243
244     def __hash__(self) -> int:
245         return hash((self.l0_idx, self.l1_idx, self.l2_idx))
246
247     @staticmethod
248     def default() -> "Gkid":
249         return Gkid(-1, -1, -1)
250
251     @staticmethod
252     def l0_seed_key(l0_idx: int) -> "Gkid":
253         return Gkid(l0_idx, -1, -1)
254
255     @staticmethod
256     def l1_seed_key(l0_idx: int, l1_idx: int) -> "Gkid":
257         return Gkid(l0_idx, l1_idx, -1)
258
259     @staticmethod
260     def from_nt_time(nt_time: NtTime) -> "Gkid":
261         l0 = nt_time // (L1_KEY_ITERATION * L2_KEY_ITERATION * KEY_CYCLE_DURATION)
262         l1 = (
263             nt_time
264             % (L1_KEY_ITERATION * L2_KEY_ITERATION * KEY_CYCLE_DURATION)
265             // (L2_KEY_ITERATION * KEY_CYCLE_DURATION)
266         )
267         l2 = nt_time % (L2_KEY_ITERATION * KEY_CYCLE_DURATION) // KEY_CYCLE_DURATION
268
269         return Gkid(l0, l1, l2)
270
271     def start_nt_time(self) -> NtTime:
272         gkid_type = self.gkid_type()
273         if gkid_type is not GkidType.L2_SEED_KEY:
274             raise UndefinedStartTime(
275                 f"{gkid_type.description()} has no defined start time"
276             )
277
278         start_time = NtTime(
279             (
280                 self.l0_idx * L1_KEY_ITERATION * L2_KEY_ITERATION
281                 + self.l1_idx * L2_KEY_ITERATION
282                 + self.l2_idx
283             )
284             * KEY_CYCLE_DURATION
285         )
286
287         if not 0 <= start_time <= uint64_max:
288             raise OverflowError(f"start time {start_time} out of range")
289
290         return start_time
291
292     @staticmethod
293     def from_key_envelope(env: gkdi.KeyEnvelope) -> "Gkid":
294         return Gkid(env.l0_index, env.l1_index, env.l2_index)
295
296
297 class SeedKeyPair:
298     __slots__ = ["l1_key", "l2_key", "gkid", "hash_algorithm", "root_key_id"]
299
300     def __init__(
301         self,
302         l1_key: Optional[bytes],
303         l2_key: Optional[bytes],
304         gkid: Gkid,
305         hash_algorithm: Algorithm,
306         root_key_id: misc.GUID,
307     ) -> None:
308         if l1_key is not None and len(l1_key) != KEY_LEN_BYTES:
309             raise ValueError(f"L1 key ({repr(l1_key)}) must be {KEY_LEN_BYTES} bytes")
310         if l2_key is not None and len(l2_key) != KEY_LEN_BYTES:
311             raise ValueError(f"L2 key ({repr(l2_key)}) must be {KEY_LEN_BYTES} bytes")
312
313         self.l1_key = l1_key
314         self.l2_key = l2_key
315         self.gkid = gkid
316         self.hash_algorithm = hash_algorithm
317         self.root_key_id = root_key_id
318
319     def __str__(self) -> str:
320         l1_key_hex = None if self.l1_key is None else self.l1_key.hex()
321         l2_key_hex = None if self.l2_key is None else self.l2_key.hex()
322
323         return (
324             f"SeedKeyPair(L1Key({l1_key_hex}), L2Key({l2_key_hex}), {self.gkid},"
325             f" {self.root_key_id}, {self.hash_algorithm})"
326         )
327
328     def __repr__(self) -> str:
329         cls = type(self)
330         return (
331             f"{cls.__qualname__}({repr(self.l1_key)}, {repr(self.l2_key)},"
332             f" {repr(self.gkid)}, {repr(self.hash_algorithm)},"
333             f" {repr(self.root_key_id)})"
334         )
335
336     def __eq__(self, other: object) -> bool:
337         if not isinstance(other, SeedKeyPair):
338             return NotImplemented
339
340         return (
341             self.l1_key,
342             self.l2_key,
343             self.gkid,
344             self.hash_algorithm,
345             self.root_key_id,
346         ) == (
347             other.l1_key,
348             other.l2_key,
349             other.gkid,
350             other.hash_algorithm,
351             other.root_key_id,
352         )
353
354     def __hash__(self) -> int:
355         return hash((
356             self.l1_key,
357             self.l2_key,
358             self.gkid,
359             self.hash_algorithm,
360             ndr_pack(self.root_key_id),
361         ))
362
363
364 class GroupKey:
365     __slots__ = ["gkid", "key", "hash_algorithm", "root_key_id"]
366
367     def __init__(
368         self, key: bytes, gkid: Gkid, hash_algorithm: Algorithm, root_key_id: misc.GUID
369     ) -> None:
370         if key is not None and len(key) != KEY_LEN_BYTES:
371             raise ValueError(f"Key ({repr(key)}) must be {KEY_LEN_BYTES} bytes")
372
373         self.key = key
374         self.gkid = gkid
375         self.hash_algorithm = hash_algorithm
376         self.root_key_id = root_key_id
377
378     def __str__(self) -> str:
379         return (
380             f"GroupKey(Key({self.key.hex()}), {self.gkid}, {self.hash_algorithm},"
381             f" {self.root_key_id})"
382         )
383
384     def __repr__(self) -> str:
385         cls = type(self)
386         return (
387             f"{cls.__qualname__}({repr(self.key)}, {repr(self.gkid)},"
388             f" {repr(self.hash_algorithm)}, {repr(self.root_key_id)})"
389         )
390
391     def __eq__(self, other: object) -> bool:
392         if not isinstance(other, GroupKey):
393             return NotImplemented
394
395         return (self.key, self.gkid, self.hash_algorithm, self.root_key_id) == (
396             other.key,
397             other.gkid,
398             other.hash_algorithm,
399             other.root_key_id,
400         )
401
402     def __hash__(self) -> int:
403         return hash(
404             (self.key, self.gkid, self.hash_algorithm, ndr_pack(self.root_key_id))
405         )