lib:crypto: Add more unit tests for GKDI functions
authorJo Sutton <josutton@catalyst.net.nz>
Mon, 15 Apr 2024 00:19:12 +0000 (12:19 +1200)
committerAndrew Bartlett <abartlet@samba.org>
Fri, 19 Apr 2024 05:02:54 +0000 (17:02 +1200)
Signed-off-by: Jo Sutton <josutton@catalyst.net.nz>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
lib/crypto/test_gkdi.c

index e6d3b28ae583c0758a100e004ba8df000f4a352f..083d71eefd3916df78d43ed1da3202e450ebd982 100644 (file)
@@ -136,10 +136,193 @@ static void test_password_based_on_key_id(void **state)
        talloc_free(mem_ctx);
 }
 
+static void assert_gkid_equal(const struct Gkid g1, const struct Gkid g2)
+{
+       assert_int_equal(g1.l0_idx, g2.l0_idx);
+       assert_int_equal(g1.l1_idx, g2.l1_idx);
+       assert_int_equal(g1.l2_idx, g2.l2_idx);
+}
+
+static void test_gkdi_rollover_interval(void **state)
+{
+       NTTIME interval;
+       bool ok;
+
+       ok = gkdi_rollover_interval(0, &interval);
+       assert_true(ok);
+       assert_int_equal(0, interval);
+
+       ok = gkdi_rollover_interval(1, &interval);
+       assert_true(ok);
+       assert_int_equal(UINT64_C(720000000000), interval);
+
+       ok = gkdi_rollover_interval(2, &interval);
+       assert_true(ok);
+       assert_int_equal(UINT64_C(1440000000000), interval);
+
+       ok = gkdi_rollover_interval(3, &interval);
+       assert_true(ok);
+       assert_int_equal(UINT64_C(2520000000000), interval);
+
+       ok = gkdi_rollover_interval(4, &interval);
+       assert_true(ok);
+       assert_int_equal(UINT64_C(3240000000000), interval);
+
+       ok = gkdi_rollover_interval(5, &interval);
+       assert_true(ok);
+       assert_int_equal(UINT64_C(4320000000000), interval);
+
+       ok = gkdi_rollover_interval(-1, &interval);
+       assert_false(ok);
+
+       ok = gkdi_rollover_interval(-2, &interval);
+       assert_false(ok);
+
+       ok = gkdi_rollover_interval(10675199, &interval);
+       assert_true(ok);
+       assert_int_equal(UINT64_C(9223371720000000000), interval);
+
+       ok = gkdi_rollover_interval(-10675198, &interval);
+       assert_false(ok);
+
+       ok = gkdi_rollover_interval(10675200, &interval);
+       assert_true(ok);
+       assert_int_equal(UINT64_C(9223372800000000000), interval);
+
+       ok = gkdi_rollover_interval(-10675199, &interval);
+       assert_false(ok);
+
+       ok = gkdi_rollover_interval(21350398, &interval);
+       /*
+        * If we accepted this high of an interval, the result would be
+        * 18446743800000000000.
+        */
+       assert_false(ok);
+
+       ok = gkdi_rollover_interval(-21350397, &interval);
+       assert_false(ok);
+
+       ok = gkdi_rollover_interval(21350399, &interval);
+       assert_false(ok); /* too large to be represented */
+
+       ok = gkdi_rollover_interval(-21350398, &interval);
+       assert_false(ok); /* too small to be represented */
+}
+
+static void assert_get_interval_id(const NTTIME time,
+                                  const struct Gkid expected_gkid)
+{
+       {
+               const bool valid = gkid_is_valid(expected_gkid);
+               assert_true(valid);
+       }
+
+       {
+               const struct Gkid interval_id = gkdi_get_interval_id(time);
+               assert_gkid_equal(expected_gkid, interval_id);
+       }
+}
+
+static void test_get_interval_id(void **state)
+{
+       assert_get_interval_id(0, Gkid(0, 0, 0));
+
+       assert_get_interval_id(gkdi_key_cycle_duration - 1, Gkid(0, 0, 0));
+
+       assert_get_interval_id(gkdi_key_cycle_duration, Gkid(0, 0, 1));
+
+       assert_get_interval_id(27 * gkdi_key_cycle_duration, Gkid(0, 0, 27));
+
+       assert_get_interval_id((gkdi_l2_key_iteration - 1) *
+                                      gkdi_key_cycle_duration,
+                              Gkid(0, 0, gkdi_l2_key_iteration - 1));
+
+       assert_get_interval_id(gkdi_l2_key_iteration * gkdi_key_cycle_duration,
+                              Gkid(0, 1, 0));
+
+       assert_get_interval_id(17 * gkdi_l2_key_iteration *
+                                      gkdi_key_cycle_duration,
+                              Gkid(0, 17, 0));
+
+       assert_get_interval_id(((gkdi_l1_key_iteration - 1) *
+                                       gkdi_l2_key_iteration +
+                               3) * gkdi_key_cycle_duration,
+                              Gkid(0, gkdi_l1_key_iteration - 1, 3));
+
+       assert_get_interval_id(gkdi_l1_key_iteration * gkdi_l2_key_iteration *
+                                      gkdi_key_cycle_duration,
+                              Gkid(1, 0, 0));
+
+       assert_get_interval_id(((1234 * gkdi_l1_key_iteration + 8) *
+                                       gkdi_l2_key_iteration +
+                               13) * gkdi_key_cycle_duration,
+                              Gkid(1234, 8, 13));
+
+       assert_get_interval_id(INT64_MAX, Gkid(25019, 31, 29));
+
+       assert_get_interval_id(UINT64_MAX, Gkid(50039, 31, 27));
+}
+
+static void test_get_key_start_time(void **state)
+{
+       NTTIME start_time = 0;
+       bool ok;
+
+       /* Try passing an invalid GKID. */
+       ok = gkdi_get_key_start_time(invalid_gkid, &start_time);
+       assert_false(ok);
+
+       /* Try passing an L1 GKID rather than an L2 GKID. */
+       ok = gkdi_get_key_start_time(Gkid(0, 0, -1), &start_time);
+       assert_false(ok);
+
+       /* Test some L2 GKIDs. */
+
+       ok = gkdi_get_key_start_time(Gkid(0, 0, 0), &start_time);
+       assert_true(ok);
+       assert_int_equal(0, start_time);
+
+       ok = gkdi_get_key_start_time(Gkid(0, 0, 1), &start_time);
+       assert_true(ok);
+       assert_int_equal(gkdi_key_cycle_duration, start_time);
+
+       ok = gkdi_get_key_start_time(Gkid(123, 18, 2), &start_time);
+       assert_true(ok);
+       assert_int_equal(126530 * gkdi_key_cycle_duration, start_time);
+
+       ok = gkdi_get_key_start_time(Gkid(25019, 31, 29), &start_time);
+       assert_true(ok);
+       assert_int_equal(25620477 * gkdi_key_cycle_duration, start_time);
+
+       ok = gkdi_get_key_start_time(Gkid(25019, 31, 30), &start_time);
+       assert_true(ok);
+       assert_int_equal(UINT64_C(25620478) * gkdi_key_cycle_duration,
+                        start_time);
+
+       ok = gkdi_get_key_start_time(Gkid(50039, 31, 27), &start_time);
+       assert_true(ok);
+       assert_int_equal(UINT64_C(51240955) * gkdi_key_cycle_duration,
+                        start_time);
+
+       /*
+        * Test GKIDs so high that their start times can’t be represented in
+        * NTTIME.
+        */
+
+       ok = gkdi_get_key_start_time(Gkid(50039, 31, 28), &start_time);
+       assert_false(ok);
+
+       ok = gkdi_get_key_start_time(Gkid(INT32_MAX, 31, 31), &start_time);
+       assert_false(ok);
+}
+
 int main(int argc, char *argv[])
 {
        const struct CMUnitTest tests[] = {
                cmocka_unit_test(test_password_based_on_key_id),
+               cmocka_unit_test(test_gkdi_rollover_interval),
+               cmocka_unit_test(test_get_interval_id),
+               cmocka_unit_test(test_get_key_start_time),
        };
 
        if (argc == 2) {