netfilter: nft_set_pipapo: release elements in clone only from destroy path
[sfrench/cifs-2.6.git] / net / netfilter / nft_set_pipapo.c
index aa1d9e93a9a04859d48e417501c7f9e889187400..df8de50902463738642d4d24b59f12b17b5ff726 100644 (file)
  *
  * Return: -1 on no match, bit position on 'match_only', 0 otherwise.
  */
-int pipapo_refill(unsigned long *map, int len, int rules, unsigned long *dst,
-                 union nft_pipapo_map_bucket *mt, bool match_only)
+int pipapo_refill(unsigned long *map, unsigned int len, unsigned int rules,
+                 unsigned long *dst,
+                 const union nft_pipapo_map_bucket *mt, bool match_only)
 {
        unsigned long bitset;
-       int k, ret = -1;
+       unsigned int k;
+       int ret = -1;
 
        for (k = 0; k < len; k++) {
                bitset = map[k];
@@ -412,9 +414,9 @@ bool nft_pipapo_lookup(const struct net *net, const struct nft_set *set,
        struct nft_pipapo_scratch *scratch;
        unsigned long *res_map, *fill_map;
        u8 genmask = nft_genmask_cur(net);
+       const struct nft_pipapo_match *m;
+       const struct nft_pipapo_field *f;
        const u8 *rp = (const u8 *)key;
-       struct nft_pipapo_match *m;
-       struct nft_pipapo_field *f;
        bool map_index;
        int i;
 
@@ -505,6 +507,7 @@ out:
  * @data:      Key data to be matched against existing elements
  * @genmask:   If set, check that element is active in given genmask
  * @tstamp:    timestamp to check for expired elements
+ * @gfp:       the type of memory to allocate (see kmalloc).
  *
  * This is essentially the same as the lookup function, except that it matches
  * key data against the uncommitted copy and doesn't use preallocated maps for
@@ -515,22 +518,26 @@ out:
 static struct nft_pipapo_elem *pipapo_get(const struct net *net,
                                          const struct nft_set *set,
                                          const u8 *data, u8 genmask,
-                                         u64 tstamp)
+                                         u64 tstamp, gfp_t gfp)
 {
        struct nft_pipapo_elem *ret = ERR_PTR(-ENOENT);
        struct nft_pipapo *priv = nft_set_priv(set);
-       struct nft_pipapo_match *m = priv->clone;
        unsigned long *res_map, *fill_map = NULL;
-       struct nft_pipapo_field *f;
+       const struct nft_pipapo_match *m;
+       const struct nft_pipapo_field *f;
        int i;
 
-       res_map = kmalloc_array(m->bsize_max, sizeof(*res_map), GFP_ATOMIC);
+       m = priv->clone;
+       if (m->bsize_max == 0)
+               return ret;
+
+       res_map = kmalloc_array(m->bsize_max, sizeof(*res_map), gfp);
        if (!res_map) {
                ret = ERR_PTR(-ENOMEM);
                goto out;
        }
 
-       fill_map = kcalloc(m->bsize_max, sizeof(*res_map), GFP_ATOMIC);
+       fill_map = kcalloc(m->bsize_max, sizeof(*res_map), gfp);
        if (!fill_map) {
                ret = ERR_PTR(-ENOMEM);
                goto out;
@@ -608,13 +615,73 @@ nft_pipapo_get(const struct net *net, const struct nft_set *set,
        struct nft_pipapo_elem *e;
 
        e = pipapo_get(net, set, (const u8 *)elem->key.val.data,
-                      nft_genmask_cur(net), get_jiffies_64());
+                      nft_genmask_cur(net), get_jiffies_64(),
+                      GFP_ATOMIC);
        if (IS_ERR(e))
                return ERR_CAST(e);
 
        return &e->priv;
 }
 
+/**
+ * pipapo_realloc_mt() - Reallocate mapping table if needed upon resize
+ * @f:         Field containing mapping table
+ * @old_rules: Amount of existing mapped rules
+ * @rules:     Amount of new rules to map
+ *
+ * Return: 0 on success, negative error code on failure.
+ */
+static int pipapo_realloc_mt(struct nft_pipapo_field *f,
+                            unsigned int old_rules, unsigned int rules)
+{
+       union nft_pipapo_map_bucket *new_mt = NULL, *old_mt = f->mt;
+       const unsigned int extra = PAGE_SIZE / sizeof(*new_mt);
+       unsigned int rules_alloc = rules;
+
+       might_sleep();
+
+       if (unlikely(rules == 0))
+               goto out_free;
+
+       /* growing and enough space left, no action needed */
+       if (rules > old_rules && f->rules_alloc > rules)
+               return 0;
+
+       /* downsize and extra slack has not grown too large */
+       if (rules < old_rules) {
+               unsigned int remove = f->rules_alloc - rules;
+
+               if (remove < (2u * extra))
+                       return 0;
+       }
+
+       /* If set needs more than one page of memory for rules then
+        * allocate another extra page to avoid frequent reallocation.
+        */
+       if (rules > extra &&
+           check_add_overflow(rules, extra, &rules_alloc))
+               return -EOVERFLOW;
+
+       new_mt = kvmalloc_array(rules_alloc, sizeof(*new_mt), GFP_KERNEL);
+       if (!new_mt)
+               return -ENOMEM;
+
+       if (old_mt)
+               memcpy(new_mt, old_mt, min(old_rules, rules) * sizeof(*new_mt));
+
+       if (rules > old_rules) {
+               memset(new_mt + old_rules, 0,
+                      (rules - old_rules) * sizeof(*new_mt));
+       }
+out_free:
+       f->rules_alloc = rules_alloc;
+       f->mt = new_mt;
+
+       kvfree(old_mt);
+
+       return 0;
+}
+
 /**
  * pipapo_resize() - Resize lookup or mapping table, or both
  * @f:         Field containing lookup and mapping tables
@@ -627,12 +694,15 @@ nft_pipapo_get(const struct net *net, const struct nft_set *set,
  *
  * Return: 0 on success, -ENOMEM on allocation failure.
  */
-static int pipapo_resize(struct nft_pipapo_field *f, int old_rules, int rules)
+static int pipapo_resize(struct nft_pipapo_field *f,
+                        unsigned int old_rules, unsigned int rules)
 {
        long *new_lt = NULL, *new_p, *old_lt = f->lt, *old_p;
-       union nft_pipapo_map_bucket *new_mt, *old_mt = f->mt;
-       size_t new_bucket_size, copy;
-       int group, bucket;
+       unsigned int new_bucket_size, copy;
+       int group, bucket, err;
+
+       if (rules >= NFT_PIPAPO_RULE0_MAX)
+               return -ENOSPC;
 
        new_bucket_size = DIV_ROUND_UP(rules, BITS_PER_LONG);
 #ifdef NFT_PIPAPO_ALIGN
@@ -672,27 +742,18 @@ static int pipapo_resize(struct nft_pipapo_field *f, int old_rules, int rules)
        }
 
 mt:
-       new_mt = kvmalloc(rules * sizeof(*new_mt), GFP_KERNEL);
-       if (!new_mt) {
+       err = pipapo_realloc_mt(f, old_rules, rules);
+       if (err) {
                kvfree(new_lt);
-               return -ENOMEM;
-       }
-
-       memcpy(new_mt, f->mt, min(old_rules, rules) * sizeof(*new_mt));
-       if (rules > old_rules) {
-               memset(new_mt + old_rules, 0,
-                      (rules - old_rules) * sizeof(*new_mt));
+               return err;
        }
 
        if (new_lt) {
                f->bsize = new_bucket_size;
-               NFT_PIPAPO_LT_ASSIGN(f, new_lt);
+               f->lt = new_lt;
                kvfree(old_lt);
        }
 
-       f->mt = new_mt;
-       kvfree(old_mt);
-
        return 0;
 }
 
@@ -843,8 +904,8 @@ static void pipapo_lt_8b_to_4b(int old_groups, int bsize,
  */
 static void pipapo_lt_bits_adjust(struct nft_pipapo_field *f)
 {
+       unsigned int groups, bb;
        unsigned long *new_lt;
-       int groups, bb;
        size_t lt_size;
 
        lt_size = f->groups * NFT_PIPAPO_BUCKETS(f->bb) * f->bsize *
@@ -894,7 +955,7 @@ static void pipapo_lt_bits_adjust(struct nft_pipapo_field *f)
        f->groups = groups;
        f->bb = bb;
        kvfree(f->lt);
-       NFT_PIPAPO_LT_ASSIGN(f, new_lt);
+       f->lt = new_lt;
 }
 
 /**
@@ -911,7 +972,7 @@ static void pipapo_lt_bits_adjust(struct nft_pipapo_field *f)
 static int pipapo_insert(struct nft_pipapo_field *f, const uint8_t *k,
                         int mask_bits)
 {
-       int rule = f->rules, group, ret, bit_offset = 0;
+       unsigned int rule = f->rules, group, ret, bit_offset = 0;
 
        ret = pipapo_resize(f, f->rules, f->rules + 1);
        if (ret)
@@ -1216,7 +1277,7 @@ static int nft_pipapo_insert(const struct net *net, const struct nft_set *set,
        else
                end = start;
 
-       dup = pipapo_get(net, set, start, genmask, tstamp);
+       dup = pipapo_get(net, set, start, genmask, tstamp, GFP_KERNEL);
        if (!IS_ERR(dup)) {
                /* Check if we already have the same exact entry */
                const struct nft_data *dup_key, *dup_end;
@@ -1238,7 +1299,8 @@ static int nft_pipapo_insert(const struct net *net, const struct nft_set *set,
 
        if (PTR_ERR(dup) == -ENOENT) {
                /* Look for partially overlapping entries */
-               dup = pipapo_get(net, set, end, nft_genmask_next(net), tstamp);
+               dup = pipapo_get(net, set, end, nft_genmask_next(net), tstamp,
+                                GFP_KERNEL);
        }
 
        if (PTR_ERR(dup) != -ENOENT) {
@@ -1251,8 +1313,14 @@ static int nft_pipapo_insert(const struct net *net, const struct nft_set *set,
        /* Validate */
        start_p = start;
        end_p = end;
+
+       /* some helpers return -1, or 0 >= for valid rule pos,
+        * so we cannot support more than INT_MAX rules at this time.
+        */
+       BUILD_BUG_ON(NFT_PIPAPO_RULE0_MAX > INT_MAX);
+
        nft_pipapo_for_each_field(f, i, m) {
-               if (f->rules >= (unsigned long)NFT_PIPAPO_RULE0_MAX)
+               if (f->rules >= NFT_PIPAPO_RULE0_MAX)
                        return -ENOSPC;
 
                if (memcmp(start_p, end_p,
@@ -1358,18 +1426,25 @@ static struct nft_pipapo_match *pipapo_clone(struct nft_pipapo_match *old)
                if (!new_lt)
                        goto out_lt;
 
-               NFT_PIPAPO_LT_ASSIGN(dst, new_lt);
+               dst->lt = new_lt;
 
                memcpy(NFT_PIPAPO_LT_ALIGN(new_lt),
                       NFT_PIPAPO_LT_ALIGN(src->lt),
                       src->bsize * sizeof(*dst->lt) *
                       src->groups * NFT_PIPAPO_BUCKETS(src->bb));
 
-               dst->mt = kvmalloc(src->rules * sizeof(*src->mt), GFP_KERNEL);
-               if (!dst->mt)
-                       goto out_mt;
+               if (src->rules > 0) {
+                       dst->mt = kvmalloc_array(src->rules_alloc,
+                                                sizeof(*src->mt), GFP_KERNEL);
+                       if (!dst->mt)
+                               goto out_mt;
+
+                       memcpy(dst->mt, src->mt, src->rules * sizeof(*src->mt));
+               } else {
+                       dst->mt = NULL;
+                       dst->rules_alloc = 0;
+               }
 
-               memcpy(dst->mt, src->mt, src->rules * sizeof(*src->mt));
                src++;
                dst++;
        }
@@ -1423,10 +1498,10 @@ out_scratch:
  *
  * Return: Number of rules that originated from the same entry as @first.
  */
-static int pipapo_rules_same_key(struct nft_pipapo_field *f, int first)
+static unsigned int pipapo_rules_same_key(struct nft_pipapo_field *f, unsigned int first)
 {
        struct nft_pipapo_elem *e = NULL; /* Keep gcc happy */
-       int r;
+       unsigned int r;
 
        for (r = first; r < f->rules; r++) {
                if (r != first && e != f->mt[r].e)
@@ -1479,8 +1554,9 @@ static int pipapo_rules_same_key(struct nft_pipapo_field *f, int first)
  *                        0      1      2
  *  element pointers:  0x42   0x42   0x44
  */
-static void pipapo_unmap(union nft_pipapo_map_bucket *mt, int rules,
-                        int start, int n, int to_offset, bool is_last)
+static void pipapo_unmap(union nft_pipapo_map_bucket *mt, unsigned int rules,
+                        unsigned int start, unsigned int n,
+                        unsigned int to_offset, bool is_last)
 {
        int i;
 
@@ -1586,8 +1662,8 @@ static void pipapo_gc(struct nft_set *set, struct nft_pipapo_match *m)
 {
        struct nft_pipapo *priv = nft_set_priv(set);
        struct net *net = read_pnet(&set->net);
+       unsigned int rules_f0, first_rule = 0;
        u64 tstamp = nft_net_tstamp(net);
-       int rules_f0, first_rule = 0;
        struct nft_pipapo_elem *e;
        struct nft_trans_gc *gc;
 
@@ -1597,8 +1673,8 @@ static void pipapo_gc(struct nft_set *set, struct nft_pipapo_match *m)
 
        while ((rules_f0 = pipapo_rules_same_key(m->f, first_rule))) {
                union nft_pipapo_map_bucket rulemap[NFT_PIPAPO_MAX_FIELDS];
-               struct nft_pipapo_field *f;
-               int i, start, rules_fx;
+               const struct nft_pipapo_field *f;
+               unsigned int i, start, rules_fx;
 
                start = first_rule;
                rules_fx = rules_f0;
@@ -1792,7 +1868,8 @@ static void *pipapo_deactivate(const struct net *net, const struct nft_set *set,
 {
        struct nft_pipapo_elem *e;
 
-       e = pipapo_get(net, set, data, nft_genmask_next(net), nft_net_tstamp(net));
+       e = pipapo_get(net, set, data, nft_genmask_next(net),
+                      nft_net_tstamp(net), GFP_KERNEL);
        if (IS_ERR(e))
                return NULL;
 
@@ -1976,7 +2053,7 @@ static void nft_pipapo_remove(const struct net *net, const struct nft_set *set,
 {
        struct nft_pipapo *priv = nft_set_priv(set);
        struct nft_pipapo_match *m = priv->clone;
-       int rules_f0, first_rule = 0;
+       unsigned int rules_f0, first_rule = 0;
        struct nft_pipapo_elem *e;
        const u8 *data;
 
@@ -2039,9 +2116,9 @@ static void nft_pipapo_walk(const struct nft_ctx *ctx, struct nft_set *set,
 {
        struct nft_pipapo *priv = nft_set_priv(set);
        struct net *net = read_pnet(&set->net);
-       struct nft_pipapo_match *m;
-       struct nft_pipapo_field *f;
-       int i, r;
+       const struct nft_pipapo_match *m;
+       const struct nft_pipapo_field *f;
+       unsigned int i, r;
 
        rcu_read_lock();
        if (iter->genmask == nft_genmask_cur(net))
@@ -2145,6 +2222,9 @@ static int nft_pipapo_init(const struct nft_set *set,
 
        field_count = desc->field_count ? : 1;
 
+       BUILD_BUG_ON(NFT_PIPAPO_MAX_FIELDS > 255);
+       BUILD_BUG_ON(NFT_PIPAPO_MAX_FIELDS != NFT_REG32_COUNT);
+
        if (field_count > NFT_PIPAPO_MAX_FIELDS)
                return -EINVAL;
 
@@ -2166,7 +2246,11 @@ static int nft_pipapo_init(const struct nft_set *set,
        rcu_head_init(&m->rcu);
 
        nft_pipapo_for_each_field(f, i, m) {
-               int len = desc->field_len[i] ? : set->klen;
+               unsigned int len = desc->field_len[i] ? : set->klen;
+
+               /* f->groups is u8 */
+               BUILD_BUG_ON((NFT_PIPAPO_MAX_BYTES *
+                             BITS_PER_BYTE / NFT_PIPAPO_GROUP_BITS_LARGE_SET) >= 256);
 
                f->bb = NFT_PIPAPO_GROUP_BITS_INIT;
                f->groups = len * NFT_PIPAPO_GROUPS_PER_BYTE(f);
@@ -2175,7 +2259,8 @@ static int nft_pipapo_init(const struct nft_set *set,
 
                f->bsize = 0;
                f->rules = 0;
-               NFT_PIPAPO_LT_ASSIGN(f, NULL);
+               f->rules_alloc = 0;
+               f->lt = NULL;
                f->mt = NULL;
        }
 
@@ -2211,7 +2296,7 @@ static void nft_set_pipapo_match_destroy(const struct nft_ctx *ctx,
                                         struct nft_pipapo_match *m)
 {
        struct nft_pipapo_field *f;
-       int i, r;
+       unsigned int i, r;
 
        for (i = 0, f = m->f; i < m->field_count - 1; i++, f++)
                ;
@@ -2244,8 +2329,6 @@ static void nft_pipapo_destroy(const struct nft_ctx *ctx,
        if (m) {
                rcu_barrier();
 
-               nft_set_pipapo_match_destroy(ctx, set, m);
-
                for_each_possible_cpu(cpu)
                        pipapo_free_scratch(m, cpu);
                free_percpu(m->scratch);
@@ -2257,8 +2340,7 @@ static void nft_pipapo_destroy(const struct nft_ctx *ctx,
        if (priv->clone) {
                m = priv->clone;
 
-               if (priv->dirty)
-                       nft_set_pipapo_match_destroy(ctx, set, m);
+               nft_set_pipapo_match_destroy(ctx, set, m);
 
                for_each_possible_cpu(cpu)
                        pipapo_free_scratch(priv->clone, cpu);