fs: smb: common: add missing MODULE_DESCRIPTION() macros
[sfrench/cifs-2.6.git] / arch / riscv / net / bpf_jit_comp64.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* BPF JIT compiler for RV64G
3  *
4  * Copyright(c) 2019 Björn Töpel <bjorn.topel@gmail.com>
5  *
6  */
7
8 #include <linux/bitfield.h>
9 #include <linux/bpf.h>
10 #include <linux/filter.h>
11 #include <linux/memory.h>
12 #include <linux/stop_machine.h>
13 #include <asm/patch.h>
14 #include <asm/cfi.h>
15 #include "bpf_jit.h"
16
17 #define RV_FENTRY_NINSNS 2
18
19 #define RV_REG_TCC RV_REG_A6
20 #define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
21
22 static const int regmap[] = {
23         [BPF_REG_0] =   RV_REG_A5,
24         [BPF_REG_1] =   RV_REG_A0,
25         [BPF_REG_2] =   RV_REG_A1,
26         [BPF_REG_3] =   RV_REG_A2,
27         [BPF_REG_4] =   RV_REG_A3,
28         [BPF_REG_5] =   RV_REG_A4,
29         [BPF_REG_6] =   RV_REG_S1,
30         [BPF_REG_7] =   RV_REG_S2,
31         [BPF_REG_8] =   RV_REG_S3,
32         [BPF_REG_9] =   RV_REG_S4,
33         [BPF_REG_FP] =  RV_REG_S5,
34         [BPF_REG_AX] =  RV_REG_T0,
35 };
36
37 static const int pt_regmap[] = {
38         [RV_REG_A0] = offsetof(struct pt_regs, a0),
39         [RV_REG_A1] = offsetof(struct pt_regs, a1),
40         [RV_REG_A2] = offsetof(struct pt_regs, a2),
41         [RV_REG_A3] = offsetof(struct pt_regs, a3),
42         [RV_REG_A4] = offsetof(struct pt_regs, a4),
43         [RV_REG_A5] = offsetof(struct pt_regs, a5),
44         [RV_REG_S1] = offsetof(struct pt_regs, s1),
45         [RV_REG_S2] = offsetof(struct pt_regs, s2),
46         [RV_REG_S3] = offsetof(struct pt_regs, s3),
47         [RV_REG_S4] = offsetof(struct pt_regs, s4),
48         [RV_REG_S5] = offsetof(struct pt_regs, s5),
49         [RV_REG_T0] = offsetof(struct pt_regs, t0),
50 };
51
52 enum {
53         RV_CTX_F_SEEN_TAIL_CALL =       0,
54         RV_CTX_F_SEEN_CALL =            RV_REG_RA,
55         RV_CTX_F_SEEN_S1 =              RV_REG_S1,
56         RV_CTX_F_SEEN_S2 =              RV_REG_S2,
57         RV_CTX_F_SEEN_S3 =              RV_REG_S3,
58         RV_CTX_F_SEEN_S4 =              RV_REG_S4,
59         RV_CTX_F_SEEN_S5 =              RV_REG_S5,
60         RV_CTX_F_SEEN_S6 =              RV_REG_S6,
61 };
62
63 static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
64 {
65         u8 reg = regmap[bpf_reg];
66
67         switch (reg) {
68         case RV_CTX_F_SEEN_S1:
69         case RV_CTX_F_SEEN_S2:
70         case RV_CTX_F_SEEN_S3:
71         case RV_CTX_F_SEEN_S4:
72         case RV_CTX_F_SEEN_S5:
73         case RV_CTX_F_SEEN_S6:
74                 __set_bit(reg, &ctx->flags);
75         }
76         return reg;
77 };
78
79 static bool seen_reg(int reg, struct rv_jit_context *ctx)
80 {
81         switch (reg) {
82         case RV_CTX_F_SEEN_CALL:
83         case RV_CTX_F_SEEN_S1:
84         case RV_CTX_F_SEEN_S2:
85         case RV_CTX_F_SEEN_S3:
86         case RV_CTX_F_SEEN_S4:
87         case RV_CTX_F_SEEN_S5:
88         case RV_CTX_F_SEEN_S6:
89                 return test_bit(reg, &ctx->flags);
90         }
91         return false;
92 }
93
94 static void mark_fp(struct rv_jit_context *ctx)
95 {
96         __set_bit(RV_CTX_F_SEEN_S5, &ctx->flags);
97 }
98
99 static void mark_call(struct rv_jit_context *ctx)
100 {
101         __set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
102 }
103
104 static bool seen_call(struct rv_jit_context *ctx)
105 {
106         return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
107 }
108
109 static void mark_tail_call(struct rv_jit_context *ctx)
110 {
111         __set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
112 }
113
114 static bool seen_tail_call(struct rv_jit_context *ctx)
115 {
116         return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
117 }
118
119 static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
120 {
121         mark_tail_call(ctx);
122
123         if (seen_call(ctx)) {
124                 __set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
125                 return RV_REG_S6;
126         }
127         return RV_REG_A6;
128 }
129
130 static bool is_32b_int(s64 val)
131 {
132         return -(1L << 31) <= val && val < (1L << 31);
133 }
134
135 static bool in_auipc_jalr_range(s64 val)
136 {
137         /*
138          * auipc+jalr can reach any signed PC-relative offset in the range
139          * [-2^31 - 2^11, 2^31 - 2^11).
140          */
141         return (-(1L << 31) - (1L << 11)) <= val &&
142                 val < ((1L << 31) - (1L << 11));
143 }
144
145 /* Modify rd pointer to alternate reg to avoid corrupting original reg */
146 static void emit_sextw_alt(u8 *rd, u8 ra, struct rv_jit_context *ctx)
147 {
148         emit_sextw(ra, *rd, ctx);
149         *rd = ra;
150 }
151
152 static void emit_zextw_alt(u8 *rd, u8 ra, struct rv_jit_context *ctx)
153 {
154         emit_zextw(ra, *rd, ctx);
155         *rd = ra;
156 }
157
158 /* Emit fixed-length instructions for address */
159 static int emit_addr(u8 rd, u64 addr, bool extra_pass, struct rv_jit_context *ctx)
160 {
161         /*
162          * Use the ro_insns(RX) to calculate the offset as the BPF program will
163          * finally run from this memory region.
164          */
165         u64 ip = (u64)(ctx->ro_insns + ctx->ninsns);
166         s64 off = addr - ip;
167         s64 upper = (off + (1 << 11)) >> 12;
168         s64 lower = off & 0xfff;
169
170         if (extra_pass && !in_auipc_jalr_range(off)) {
171                 pr_err("bpf-jit: target offset 0x%llx is out of range\n", off);
172                 return -ERANGE;
173         }
174
175         emit(rv_auipc(rd, upper), ctx);
176         emit(rv_addi(rd, rd, lower), ctx);
177         return 0;
178 }
179
180 /* Emit variable-length instructions for 32-bit and 64-bit imm */
181 static void emit_imm(u8 rd, s64 val, struct rv_jit_context *ctx)
182 {
183         /* Note that the immediate from the add is sign-extended,
184          * which means that we need to compensate this by adding 2^12,
185          * when the 12th bit is set. A simpler way of doing this, and
186          * getting rid of the check, is to just add 2**11 before the
187          * shift. The "Loading a 32-Bit constant" example from the
188          * "Computer Organization and Design, RISC-V edition" book by
189          * Patterson/Hennessy highlights this fact.
190          *
191          * This also means that we need to process LSB to MSB.
192          */
193         s64 upper = (val + (1 << 11)) >> 12;
194         /* Sign-extend lower 12 bits to 64 bits since immediates for li, addiw,
195          * and addi are signed and RVC checks will perform signed comparisons.
196          */
197         s64 lower = ((val & 0xfff) << 52) >> 52;
198         int shift;
199
200         if (is_32b_int(val)) {
201                 if (upper)
202                         emit_lui(rd, upper, ctx);
203
204                 if (!upper) {
205                         emit_li(rd, lower, ctx);
206                         return;
207                 }
208
209                 emit_addiw(rd, rd, lower, ctx);
210                 return;
211         }
212
213         shift = __ffs(upper);
214         upper >>= shift;
215         shift += 12;
216
217         emit_imm(rd, upper, ctx);
218
219         emit_slli(rd, rd, shift, ctx);
220         if (lower)
221                 emit_addi(rd, rd, lower, ctx);
222 }
223
224 static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
225 {
226         int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 8;
227
228         if (seen_reg(RV_REG_RA, ctx)) {
229                 emit_ld(RV_REG_RA, store_offset, RV_REG_SP, ctx);
230                 store_offset -= 8;
231         }
232         emit_ld(RV_REG_FP, store_offset, RV_REG_SP, ctx);
233         store_offset -= 8;
234         if (seen_reg(RV_REG_S1, ctx)) {
235                 emit_ld(RV_REG_S1, store_offset, RV_REG_SP, ctx);
236                 store_offset -= 8;
237         }
238         if (seen_reg(RV_REG_S2, ctx)) {
239                 emit_ld(RV_REG_S2, store_offset, RV_REG_SP, ctx);
240                 store_offset -= 8;
241         }
242         if (seen_reg(RV_REG_S3, ctx)) {
243                 emit_ld(RV_REG_S3, store_offset, RV_REG_SP, ctx);
244                 store_offset -= 8;
245         }
246         if (seen_reg(RV_REG_S4, ctx)) {
247                 emit_ld(RV_REG_S4, store_offset, RV_REG_SP, ctx);
248                 store_offset -= 8;
249         }
250         if (seen_reg(RV_REG_S5, ctx)) {
251                 emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
252                 store_offset -= 8;
253         }
254         if (seen_reg(RV_REG_S6, ctx)) {
255                 emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
256                 store_offset -= 8;
257         }
258
259         emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
260         /* Set return value. */
261         if (!is_tail_call)
262                 emit_addiw(RV_REG_A0, RV_REG_A5, 0, ctx);
263         emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
264                   is_tail_call ? (RV_FENTRY_NINSNS + 1) * 4 : 0, /* skip reserved nops and TCC init */
265                   ctx);
266 }
267
268 static void emit_bcc(u8 cond, u8 rd, u8 rs, int rvoff,
269                      struct rv_jit_context *ctx)
270 {
271         switch (cond) {
272         case BPF_JEQ:
273                 emit(rv_beq(rd, rs, rvoff >> 1), ctx);
274                 return;
275         case BPF_JGT:
276                 emit(rv_bltu(rs, rd, rvoff >> 1), ctx);
277                 return;
278         case BPF_JLT:
279                 emit(rv_bltu(rd, rs, rvoff >> 1), ctx);
280                 return;
281         case BPF_JGE:
282                 emit(rv_bgeu(rd, rs, rvoff >> 1), ctx);
283                 return;
284         case BPF_JLE:
285                 emit(rv_bgeu(rs, rd, rvoff >> 1), ctx);
286                 return;
287         case BPF_JNE:
288                 emit(rv_bne(rd, rs, rvoff >> 1), ctx);
289                 return;
290         case BPF_JSGT:
291                 emit(rv_blt(rs, rd, rvoff >> 1), ctx);
292                 return;
293         case BPF_JSLT:
294                 emit(rv_blt(rd, rs, rvoff >> 1), ctx);
295                 return;
296         case BPF_JSGE:
297                 emit(rv_bge(rd, rs, rvoff >> 1), ctx);
298                 return;
299         case BPF_JSLE:
300                 emit(rv_bge(rs, rd, rvoff >> 1), ctx);
301         }
302 }
303
304 static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff,
305                         struct rv_jit_context *ctx)
306 {
307         s64 upper, lower;
308
309         if (is_13b_int(rvoff)) {
310                 emit_bcc(cond, rd, rs, rvoff, ctx);
311                 return;
312         }
313
314         /* Adjust for jal */
315         rvoff -= 4;
316
317         /* Transform, e.g.:
318          *   bne rd,rs,foo
319          * to
320          *   beq rd,rs,<.L1>
321          *   (auipc foo)
322          *   jal(r) foo
323          * .L1
324          */
325         cond = invert_bpf_cond(cond);
326         if (is_21b_int(rvoff)) {
327                 emit_bcc(cond, rd, rs, 8, ctx);
328                 emit(rv_jal(RV_REG_ZERO, rvoff >> 1), ctx);
329                 return;
330         }
331
332         /* 32b No need for an additional rvoff adjustment, since we
333          * get that from the auipc at PC', where PC = PC' + 4.
334          */
335         upper = (rvoff + (1 << 11)) >> 12;
336         lower = rvoff & 0xfff;
337
338         emit_bcc(cond, rd, rs, 12, ctx);
339         emit(rv_auipc(RV_REG_T1, upper), ctx);
340         emit(rv_jalr(RV_REG_ZERO, RV_REG_T1, lower), ctx);
341 }
342
343 static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
344 {
345         int tc_ninsn, off, start_insn = ctx->ninsns;
346         u8 tcc = rv_tail_call_reg(ctx);
347
348         /* a0: &ctx
349          * a1: &array
350          * a2: index
351          *
352          * if (index >= array->map.max_entries)
353          *      goto out;
354          */
355         tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
356                    ctx->offset[0];
357         emit_zextw(RV_REG_A2, RV_REG_A2, ctx);
358
359         off = offsetof(struct bpf_array, map.max_entries);
360         if (is_12b_check(off, insn))
361                 return -1;
362         emit(rv_lwu(RV_REG_T1, off, RV_REG_A1), ctx);
363         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
364         emit_branch(BPF_JGE, RV_REG_A2, RV_REG_T1, off, ctx);
365
366         /* if (--TCC < 0)
367          *     goto out;
368          */
369         emit_addi(RV_REG_TCC, tcc, -1, ctx);
370         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
371         emit_branch(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);
372
373         /* prog = array->ptrs[index];
374          * if (!prog)
375          *     goto out;
376          */
377         emit_slli(RV_REG_T2, RV_REG_A2, 3, ctx);
378         emit_add(RV_REG_T2, RV_REG_T2, RV_REG_A1, ctx);
379         off = offsetof(struct bpf_array, ptrs);
380         if (is_12b_check(off, insn))
381                 return -1;
382         emit_ld(RV_REG_T2, off, RV_REG_T2, ctx);
383         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
384         emit_branch(BPF_JEQ, RV_REG_T2, RV_REG_ZERO, off, ctx);
385
386         /* goto *(prog->bpf_func + 4); */
387         off = offsetof(struct bpf_prog, bpf_func);
388         if (is_12b_check(off, insn))
389                 return -1;
390         emit_ld(RV_REG_T3, off, RV_REG_T2, ctx);
391         __build_epilogue(true, ctx);
392         return 0;
393 }
394
395 static void init_regs(u8 *rd, u8 *rs, const struct bpf_insn *insn,
396                       struct rv_jit_context *ctx)
397 {
398         u8 code = insn->code;
399
400         switch (code) {
401         case BPF_JMP | BPF_JA:
402         case BPF_JMP | BPF_CALL:
403         case BPF_JMP | BPF_EXIT:
404         case BPF_JMP | BPF_TAIL_CALL:
405                 break;
406         default:
407                 *rd = bpf_to_rv_reg(insn->dst_reg, ctx);
408         }
409
410         if (code & (BPF_ALU | BPF_X) || code & (BPF_ALU64 | BPF_X) ||
411             code & (BPF_JMP | BPF_X) || code & (BPF_JMP32 | BPF_X) ||
412             code & BPF_LDX || code & BPF_STX)
413                 *rs = bpf_to_rv_reg(insn->src_reg, ctx);
414 }
415
416 static int emit_jump_and_link(u8 rd, s64 rvoff, bool fixed_addr,
417                               struct rv_jit_context *ctx)
418 {
419         s64 upper, lower;
420
421         if (rvoff && fixed_addr && is_21b_int(rvoff)) {
422                 emit(rv_jal(rd, rvoff >> 1), ctx);
423                 return 0;
424         } else if (in_auipc_jalr_range(rvoff)) {
425                 upper = (rvoff + (1 << 11)) >> 12;
426                 lower = rvoff & 0xfff;
427                 emit(rv_auipc(RV_REG_T1, upper), ctx);
428                 emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
429                 return 0;
430         }
431
432         pr_err("bpf-jit: target offset 0x%llx is out of range\n", rvoff);
433         return -ERANGE;
434 }
435
436 static bool is_signed_bpf_cond(u8 cond)
437 {
438         return cond == BPF_JSGT || cond == BPF_JSLT ||
439                 cond == BPF_JSGE || cond == BPF_JSLE;
440 }
441
442 static int emit_call(u64 addr, bool fixed_addr, struct rv_jit_context *ctx)
443 {
444         s64 off = 0;
445         u64 ip;
446
447         if (addr && ctx->insns && ctx->ro_insns) {
448                 /*
449                  * Use the ro_insns(RX) to calculate the offset as the BPF
450                  * program will finally run from this memory region.
451                  */
452                 ip = (u64)(long)(ctx->ro_insns + ctx->ninsns);
453                 off = addr - ip;
454         }
455
456         return emit_jump_and_link(RV_REG_RA, off, fixed_addr, ctx);
457 }
458
459 static inline void emit_kcfi(u32 hash, struct rv_jit_context *ctx)
460 {
461         if (IS_ENABLED(CONFIG_CFI_CLANG))
462                 emit(hash, ctx);
463 }
464
465 static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64,
466                         struct rv_jit_context *ctx)
467 {
468         u8 r0;
469         int jmp_offset;
470
471         if (off) {
472                 if (is_12b_int(off)) {
473                         emit_addi(RV_REG_T1, rd, off, ctx);
474                 } else {
475                         emit_imm(RV_REG_T1, off, ctx);
476                         emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
477                 }
478                 rd = RV_REG_T1;
479         }
480
481         switch (imm) {
482         /* lock *(u32/u64 *)(dst_reg + off16) <op>= src_reg */
483         case BPF_ADD:
484                 emit(is64 ? rv_amoadd_d(RV_REG_ZERO, rs, rd, 0, 0) :
485                      rv_amoadd_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
486                 break;
487         case BPF_AND:
488                 emit(is64 ? rv_amoand_d(RV_REG_ZERO, rs, rd, 0, 0) :
489                      rv_amoand_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
490                 break;
491         case BPF_OR:
492                 emit(is64 ? rv_amoor_d(RV_REG_ZERO, rs, rd, 0, 0) :
493                      rv_amoor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
494                 break;
495         case BPF_XOR:
496                 emit(is64 ? rv_amoxor_d(RV_REG_ZERO, rs, rd, 0, 0) :
497                      rv_amoxor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
498                 break;
499         /* src_reg = atomic_fetch_<op>(dst_reg + off16, src_reg) */
500         case BPF_ADD | BPF_FETCH:
501                 emit(is64 ? rv_amoadd_d(rs, rs, rd, 0, 0) :
502                      rv_amoadd_w(rs, rs, rd, 0, 0), ctx);
503                 if (!is64)
504                         emit_zextw(rs, rs, ctx);
505                 break;
506         case BPF_AND | BPF_FETCH:
507                 emit(is64 ? rv_amoand_d(rs, rs, rd, 0, 0) :
508                      rv_amoand_w(rs, rs, rd, 0, 0), ctx);
509                 if (!is64)
510                         emit_zextw(rs, rs, ctx);
511                 break;
512         case BPF_OR | BPF_FETCH:
513                 emit(is64 ? rv_amoor_d(rs, rs, rd, 0, 0) :
514                      rv_amoor_w(rs, rs, rd, 0, 0), ctx);
515                 if (!is64)
516                         emit_zextw(rs, rs, ctx);
517                 break;
518         case BPF_XOR | BPF_FETCH:
519                 emit(is64 ? rv_amoxor_d(rs, rs, rd, 0, 0) :
520                      rv_amoxor_w(rs, rs, rd, 0, 0), ctx);
521                 if (!is64)
522                         emit_zextw(rs, rs, ctx);
523                 break;
524         /* src_reg = atomic_xchg(dst_reg + off16, src_reg); */
525         case BPF_XCHG:
526                 emit(is64 ? rv_amoswap_d(rs, rs, rd, 0, 0) :
527                      rv_amoswap_w(rs, rs, rd, 0, 0), ctx);
528                 if (!is64)
529                         emit_zextw(rs, rs, ctx);
530                 break;
531         /* r0 = atomic_cmpxchg(dst_reg + off16, r0, src_reg); */
532         case BPF_CMPXCHG:
533                 r0 = bpf_to_rv_reg(BPF_REG_0, ctx);
534                 emit(is64 ? rv_addi(RV_REG_T2, r0, 0) :
535                      rv_addiw(RV_REG_T2, r0, 0), ctx);
536                 emit(is64 ? rv_lr_d(r0, 0, rd, 0, 0) :
537                      rv_lr_w(r0, 0, rd, 0, 0), ctx);
538                 jmp_offset = ninsns_rvoff(8);
539                 emit(rv_bne(RV_REG_T2, r0, jmp_offset >> 1), ctx);
540                 emit(is64 ? rv_sc_d(RV_REG_T3, rs, rd, 0, 0) :
541                      rv_sc_w(RV_REG_T3, rs, rd, 0, 0), ctx);
542                 jmp_offset = ninsns_rvoff(-6);
543                 emit(rv_bne(RV_REG_T3, 0, jmp_offset >> 1), ctx);
544                 emit(rv_fence(0x3, 0x3), ctx);
545                 break;
546         }
547 }
548
549 #define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
550 #define BPF_FIXUP_REG_MASK      GENMASK(31, 27)
551
552 bool ex_handler_bpf(const struct exception_table_entry *ex,
553                     struct pt_regs *regs)
554 {
555         off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
556         int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
557
558         *(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
559         regs->epc = (unsigned long)&ex->fixup - offset;
560
561         return true;
562 }
563
564 /* For accesses to BTF pointers, add an entry to the exception table */
565 static int add_exception_handler(const struct bpf_insn *insn,
566                                  struct rv_jit_context *ctx,
567                                  int dst_reg, int insn_len)
568 {
569         struct exception_table_entry *ex;
570         unsigned long pc;
571         off_t ins_offset;
572         off_t fixup_offset;
573
574         if (!ctx->insns || !ctx->ro_insns || !ctx->prog->aux->extable ||
575             (BPF_MODE(insn->code) != BPF_PROBE_MEM && BPF_MODE(insn->code) != BPF_PROBE_MEMSX))
576                 return 0;
577
578         if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
579                 return -EINVAL;
580
581         if (WARN_ON_ONCE(insn_len > ctx->ninsns))
582                 return -EINVAL;
583
584         if (WARN_ON_ONCE(!rvc_enabled() && insn_len == 1))
585                 return -EINVAL;
586
587         ex = &ctx->prog->aux->extable[ctx->nexentries];
588         pc = (unsigned long)&ctx->ro_insns[ctx->ninsns - insn_len];
589
590         /*
591          * This is the relative offset of the instruction that may fault from
592          * the exception table itself. This will be written to the exception
593          * table and if this instruction faults, the destination register will
594          * be set to '0' and the execution will jump to the next instruction.
595          */
596         ins_offset = pc - (long)&ex->insn;
597         if (WARN_ON_ONCE(ins_offset >= 0 || ins_offset < INT_MIN))
598                 return -ERANGE;
599
600         /*
601          * Since the extable follows the program, the fixup offset is always
602          * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
603          * to keep things simple, and put the destination register in the upper
604          * bits. We don't need to worry about buildtime or runtime sort
605          * modifying the upper bits because the table is already sorted, and
606          * isn't part of the main exception table.
607          *
608          * The fixup_offset is set to the next instruction from the instruction
609          * that may fault. The execution will jump to this after handling the
610          * fault.
611          */
612         fixup_offset = (long)&ex->fixup - (pc + insn_len * sizeof(u16));
613         if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, fixup_offset))
614                 return -ERANGE;
615
616         /*
617          * The offsets above have been calculated using the RO buffer but we
618          * need to use the R/W buffer for writes.
619          * switch ex to rw buffer for writing.
620          */
621         ex = (void *)ctx->insns + ((void *)ex - (void *)ctx->ro_insns);
622
623         ex->insn = ins_offset;
624
625         ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, fixup_offset) |
626                 FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
627         ex->type = EX_TYPE_BPF;
628
629         ctx->nexentries++;
630         return 0;
631 }
632
633 static int gen_jump_or_nops(void *target, void *ip, u32 *insns, bool is_call)
634 {
635         s64 rvoff;
636         struct rv_jit_context ctx;
637
638         ctx.ninsns = 0;
639         ctx.insns = (u16 *)insns;
640
641         if (!target) {
642                 emit(rv_nop(), &ctx);
643                 emit(rv_nop(), &ctx);
644                 return 0;
645         }
646
647         rvoff = (s64)(target - ip);
648         return emit_jump_and_link(is_call ? RV_REG_T0 : RV_REG_ZERO, rvoff, false, &ctx);
649 }
650
651 int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
652                        void *old_addr, void *new_addr)
653 {
654         u32 old_insns[RV_FENTRY_NINSNS], new_insns[RV_FENTRY_NINSNS];
655         bool is_call = poke_type == BPF_MOD_CALL;
656         int ret;
657
658         if (!is_kernel_text((unsigned long)ip) &&
659             !is_bpf_text_address((unsigned long)ip))
660                 return -ENOTSUPP;
661
662         ret = gen_jump_or_nops(old_addr, ip, old_insns, is_call);
663         if (ret)
664                 return ret;
665
666         if (memcmp(ip, old_insns, RV_FENTRY_NINSNS * 4))
667                 return -EFAULT;
668
669         ret = gen_jump_or_nops(new_addr, ip, new_insns, is_call);
670         if (ret)
671                 return ret;
672
673         cpus_read_lock();
674         mutex_lock(&text_mutex);
675         if (memcmp(ip, new_insns, RV_FENTRY_NINSNS * 4))
676                 ret = patch_text(ip, new_insns, RV_FENTRY_NINSNS);
677         mutex_unlock(&text_mutex);
678         cpus_read_unlock();
679
680         return ret;
681 }
682
683 static void store_args(int nregs, int args_off, struct rv_jit_context *ctx)
684 {
685         int i;
686
687         for (i = 0; i < nregs; i++) {
688                 emit_sd(RV_REG_FP, -args_off, RV_REG_A0 + i, ctx);
689                 args_off -= 8;
690         }
691 }
692
693 static void restore_args(int nregs, int args_off, struct rv_jit_context *ctx)
694 {
695         int i;
696
697         for (i = 0; i < nregs; i++) {
698                 emit_ld(RV_REG_A0 + i, -args_off, RV_REG_FP, ctx);
699                 args_off -= 8;
700         }
701 }
702
703 static int invoke_bpf_prog(struct bpf_tramp_link *l, int args_off, int retval_off,
704                            int run_ctx_off, bool save_ret, struct rv_jit_context *ctx)
705 {
706         int ret, branch_off;
707         struct bpf_prog *p = l->link.prog;
708         int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
709
710         if (l->cookie) {
711                 emit_imm(RV_REG_T1, l->cookie, ctx);
712                 emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_T1, ctx);
713         } else {
714                 emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_ZERO, ctx);
715         }
716
717         /* arg1: prog */
718         emit_imm(RV_REG_A0, (const s64)p, ctx);
719         /* arg2: &run_ctx */
720         emit_addi(RV_REG_A1, RV_REG_FP, -run_ctx_off, ctx);
721         ret = emit_call((const u64)bpf_trampoline_enter(p), true, ctx);
722         if (ret)
723                 return ret;
724
725         /* store prog start time */
726         emit_mv(RV_REG_S1, RV_REG_A0, ctx);
727
728         /* if (__bpf_prog_enter(prog) == 0)
729          *      goto skip_exec_of_prog;
730          */
731         branch_off = ctx->ninsns;
732         /* nop reserved for conditional jump */
733         emit(rv_nop(), ctx);
734
735         /* arg1: &args_off */
736         emit_addi(RV_REG_A0, RV_REG_FP, -args_off, ctx);
737         if (!p->jited)
738                 /* arg2: progs[i]->insnsi for interpreter */
739                 emit_imm(RV_REG_A1, (const s64)p->insnsi, ctx);
740         ret = emit_call((const u64)p->bpf_func, true, ctx);
741         if (ret)
742                 return ret;
743
744         if (save_ret) {
745                 emit_sd(RV_REG_FP, -retval_off, RV_REG_A0, ctx);
746                 emit_sd(RV_REG_FP, -(retval_off - 8), regmap[BPF_REG_0], ctx);
747         }
748
749         /* update branch with beqz */
750         if (ctx->insns) {
751                 int offset = ninsns_rvoff(ctx->ninsns - branch_off);
752                 u32 insn = rv_beq(RV_REG_A0, RV_REG_ZERO, offset >> 1);
753                 *(u32 *)(ctx->insns + branch_off) = insn;
754         }
755
756         /* arg1: prog */
757         emit_imm(RV_REG_A0, (const s64)p, ctx);
758         /* arg2: prog start time */
759         emit_mv(RV_REG_A1, RV_REG_S1, ctx);
760         /* arg3: &run_ctx */
761         emit_addi(RV_REG_A2, RV_REG_FP, -run_ctx_off, ctx);
762         ret = emit_call((const u64)bpf_trampoline_exit(p), true, ctx);
763
764         return ret;
765 }
766
767 static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
768                                          const struct btf_func_model *m,
769                                          struct bpf_tramp_links *tlinks,
770                                          void *func_addr, u32 flags,
771                                          struct rv_jit_context *ctx)
772 {
773         int i, ret, offset;
774         int *branches_off = NULL;
775         int stack_size = 0, nregs = m->nr_args;
776         int retval_off, args_off, nregs_off, ip_off, run_ctx_off, sreg_off;
777         struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
778         struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
779         struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
780         bool is_struct_ops = flags & BPF_TRAMP_F_INDIRECT;
781         void *orig_call = func_addr;
782         bool save_ret;
783         u32 insn;
784
785         /* Two types of generated trampoline stack layout:
786          *
787          * 1. trampoline called from function entry
788          * --------------------------------------
789          * FP + 8           [ RA to parent func ] return address to parent
790          *                                        function
791          * FP + 0           [ FP of parent func ] frame pointer of parent
792          *                                        function
793          * FP - 8           [ T0 to traced func ] return address of traced
794          *                                        function
795          * FP - 16          [ FP of traced func ] frame pointer of traced
796          *                                        function
797          * --------------------------------------
798          *
799          * 2. trampoline called directly
800          * --------------------------------------
801          * FP - 8           [ RA to caller func ] return address to caller
802          *                                        function
803          * FP - 16          [ FP of caller func ] frame pointer of caller
804          *                                        function
805          * --------------------------------------
806          *
807          * FP - retval_off  [ return value      ] BPF_TRAMP_F_CALL_ORIG or
808          *                                        BPF_TRAMP_F_RET_FENTRY_RET
809          *                  [ argN              ]
810          *                  [ ...               ]
811          * FP - args_off    [ arg1              ]
812          *
813          * FP - nregs_off   [ regs count        ]
814          *
815          * FP - ip_off      [ traced func       ] BPF_TRAMP_F_IP_ARG
816          *
817          * FP - run_ctx_off [ bpf_tramp_run_ctx ]
818          *
819          * FP - sreg_off    [ callee saved reg  ]
820          *
821          *                  [ pads              ] pads for 16 bytes alignment
822          */
823
824         if (flags & (BPF_TRAMP_F_ORIG_STACK | BPF_TRAMP_F_SHARE_IPMODIFY))
825                 return -ENOTSUPP;
826
827         /* extra regiters for struct arguments */
828         for (i = 0; i < m->nr_args; i++)
829                 if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG)
830                         nregs += round_up(m->arg_size[i], 8) / 8 - 1;
831
832         /* 8 arguments passed by registers */
833         if (nregs > 8)
834                 return -ENOTSUPP;
835
836         /* room of trampoline frame to store return address and frame pointer */
837         stack_size += 16;
838
839         save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
840         if (save_ret) {
841                 stack_size += 16; /* Save both A5 (BPF R0) and A0 */
842                 retval_off = stack_size;
843         }
844
845         stack_size += nregs * 8;
846         args_off = stack_size;
847
848         stack_size += 8;
849         nregs_off = stack_size;
850
851         if (flags & BPF_TRAMP_F_IP_ARG) {
852                 stack_size += 8;
853                 ip_off = stack_size;
854         }
855
856         stack_size += round_up(sizeof(struct bpf_tramp_run_ctx), 8);
857         run_ctx_off = stack_size;
858
859         stack_size += 8;
860         sreg_off = stack_size;
861
862         stack_size = round_up(stack_size, 16);
863
864         if (!is_struct_ops) {
865                 /* For the trampoline called from function entry,
866                  * the frame of traced function and the frame of
867                  * trampoline need to be considered.
868                  */
869                 emit_addi(RV_REG_SP, RV_REG_SP, -16, ctx);
870                 emit_sd(RV_REG_SP, 8, RV_REG_RA, ctx);
871                 emit_sd(RV_REG_SP, 0, RV_REG_FP, ctx);
872                 emit_addi(RV_REG_FP, RV_REG_SP, 16, ctx);
873
874                 emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
875                 emit_sd(RV_REG_SP, stack_size - 8, RV_REG_T0, ctx);
876                 emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
877                 emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
878         } else {
879                 /* emit kcfi hash */
880                 emit_kcfi(cfi_get_func_hash(func_addr), ctx);
881                 /* For the trampoline called directly, just handle
882                  * the frame of trampoline.
883                  */
884                 emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
885                 emit_sd(RV_REG_SP, stack_size - 8, RV_REG_RA, ctx);
886                 emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
887                 emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
888         }
889
890         /* callee saved register S1 to pass start time */
891         emit_sd(RV_REG_FP, -sreg_off, RV_REG_S1, ctx);
892
893         /* store ip address of the traced function */
894         if (flags & BPF_TRAMP_F_IP_ARG) {
895                 emit_imm(RV_REG_T1, (const s64)func_addr, ctx);
896                 emit_sd(RV_REG_FP, -ip_off, RV_REG_T1, ctx);
897         }
898
899         emit_li(RV_REG_T1, nregs, ctx);
900         emit_sd(RV_REG_FP, -nregs_off, RV_REG_T1, ctx);
901
902         store_args(nregs, args_off, ctx);
903
904         /* skip to actual body of traced function */
905         if (flags & BPF_TRAMP_F_SKIP_FRAME)
906                 orig_call += RV_FENTRY_NINSNS * 4;
907
908         if (flags & BPF_TRAMP_F_CALL_ORIG) {
909                 emit_imm(RV_REG_A0, (const s64)im, ctx);
910                 ret = emit_call((const u64)__bpf_tramp_enter, true, ctx);
911                 if (ret)
912                         return ret;
913         }
914
915         for (i = 0; i < fentry->nr_links; i++) {
916                 ret = invoke_bpf_prog(fentry->links[i], args_off, retval_off, run_ctx_off,
917                                       flags & BPF_TRAMP_F_RET_FENTRY_RET, ctx);
918                 if (ret)
919                         return ret;
920         }
921
922         if (fmod_ret->nr_links) {
923                 branches_off = kcalloc(fmod_ret->nr_links, sizeof(int), GFP_KERNEL);
924                 if (!branches_off)
925                         return -ENOMEM;
926
927                 /* cleanup to avoid garbage return value confusion */
928                 emit_sd(RV_REG_FP, -retval_off, RV_REG_ZERO, ctx);
929                 for (i = 0; i < fmod_ret->nr_links; i++) {
930                         ret = invoke_bpf_prog(fmod_ret->links[i], args_off, retval_off,
931                                               run_ctx_off, true, ctx);
932                         if (ret)
933                                 goto out;
934                         emit_ld(RV_REG_T1, -retval_off, RV_REG_FP, ctx);
935                         branches_off[i] = ctx->ninsns;
936                         /* nop reserved for conditional jump */
937                         emit(rv_nop(), ctx);
938                 }
939         }
940
941         if (flags & BPF_TRAMP_F_CALL_ORIG) {
942                 restore_args(nregs, args_off, ctx);
943                 ret = emit_call((const u64)orig_call, true, ctx);
944                 if (ret)
945                         goto out;
946                 emit_sd(RV_REG_FP, -retval_off, RV_REG_A0, ctx);
947                 emit_sd(RV_REG_FP, -(retval_off - 8), regmap[BPF_REG_0], ctx);
948                 im->ip_after_call = ctx->insns + ctx->ninsns;
949                 /* 2 nops reserved for auipc+jalr pair */
950                 emit(rv_nop(), ctx);
951                 emit(rv_nop(), ctx);
952         }
953
954         /* update branches saved in invoke_bpf_mod_ret with bnez */
955         for (i = 0; ctx->insns && i < fmod_ret->nr_links; i++) {
956                 offset = ninsns_rvoff(ctx->ninsns - branches_off[i]);
957                 insn = rv_bne(RV_REG_T1, RV_REG_ZERO, offset >> 1);
958                 *(u32 *)(ctx->insns + branches_off[i]) = insn;
959         }
960
961         for (i = 0; i < fexit->nr_links; i++) {
962                 ret = invoke_bpf_prog(fexit->links[i], args_off, retval_off,
963                                       run_ctx_off, false, ctx);
964                 if (ret)
965                         goto out;
966         }
967
968         if (flags & BPF_TRAMP_F_CALL_ORIG) {
969                 im->ip_epilogue = ctx->insns + ctx->ninsns;
970                 emit_imm(RV_REG_A0, (const s64)im, ctx);
971                 ret = emit_call((const u64)__bpf_tramp_exit, true, ctx);
972                 if (ret)
973                         goto out;
974         }
975
976         if (flags & BPF_TRAMP_F_RESTORE_REGS)
977                 restore_args(nregs, args_off, ctx);
978
979         if (save_ret) {
980                 emit_ld(RV_REG_A0, -retval_off, RV_REG_FP, ctx);
981                 emit_ld(regmap[BPF_REG_0], -(retval_off - 8), RV_REG_FP, ctx);
982         }
983
984         emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx);
985
986         if (!is_struct_ops) {
987                 /* trampoline called from function entry */
988                 emit_ld(RV_REG_T0, stack_size - 8, RV_REG_SP, ctx);
989                 emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
990                 emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
991
992                 emit_ld(RV_REG_RA, 8, RV_REG_SP, ctx);
993                 emit_ld(RV_REG_FP, 0, RV_REG_SP, ctx);
994                 emit_addi(RV_REG_SP, RV_REG_SP, 16, ctx);
995
996                 if (flags & BPF_TRAMP_F_SKIP_FRAME)
997                         /* return to parent function */
998                         emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
999                 else
1000                         /* return to traced function */
1001                         emit_jalr(RV_REG_ZERO, RV_REG_T0, 0, ctx);
1002         } else {
1003                 /* trampoline called directly */
1004                 emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx);
1005                 emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
1006                 emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
1007
1008                 emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
1009         }
1010
1011         ret = ctx->ninsns;
1012 out:
1013         kfree(branches_off);
1014         return ret;
1015 }
1016
1017 int arch_bpf_trampoline_size(const struct btf_func_model *m, u32 flags,
1018                              struct bpf_tramp_links *tlinks, void *func_addr)
1019 {
1020         struct bpf_tramp_image im;
1021         struct rv_jit_context ctx;
1022         int ret;
1023
1024         ctx.ninsns = 0;
1025         ctx.insns = NULL;
1026         ctx.ro_insns = NULL;
1027         ret = __arch_prepare_bpf_trampoline(&im, m, tlinks, func_addr, flags, &ctx);
1028
1029         return ret < 0 ? ret : ninsns_rvoff(ctx.ninsns);
1030 }
1031
1032 int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image,
1033                                 void *image_end, const struct btf_func_model *m,
1034                                 u32 flags, struct bpf_tramp_links *tlinks,
1035                                 void *func_addr)
1036 {
1037         int ret;
1038         struct rv_jit_context ctx;
1039
1040         ctx.ninsns = 0;
1041         /*
1042          * The bpf_int_jit_compile() uses a RW buffer (ctx.insns) to write the
1043          * JITed instructions and later copies it to a RX region (ctx.ro_insns).
1044          * It also uses ctx.ro_insns to calculate offsets for jumps etc. As the
1045          * trampoline image uses the same memory area for writing and execution,
1046          * both ctx.insns and ctx.ro_insns can be set to image.
1047          */
1048         ctx.insns = image;
1049         ctx.ro_insns = image;
1050         ret = __arch_prepare_bpf_trampoline(im, m, tlinks, func_addr, flags, &ctx);
1051         if (ret < 0)
1052                 return ret;
1053
1054         bpf_flush_icache(ctx.insns, ctx.insns + ctx.ninsns);
1055
1056         return ninsns_rvoff(ret);
1057 }
1058
1059 int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
1060                       bool extra_pass)
1061 {
1062         bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
1063                     BPF_CLASS(insn->code) == BPF_JMP;
1064         int s, e, rvoff, ret, i = insn - ctx->prog->insnsi;
1065         struct bpf_prog_aux *aux = ctx->prog->aux;
1066         u8 rd = -1, rs = -1, code = insn->code;
1067         s16 off = insn->off;
1068         s32 imm = insn->imm;
1069
1070         init_regs(&rd, &rs, insn, ctx);
1071
1072         switch (code) {
1073         /* dst = src */
1074         case BPF_ALU | BPF_MOV | BPF_X:
1075         case BPF_ALU64 | BPF_MOV | BPF_X:
1076                 if (imm == 1) {
1077                         /* Special mov32 for zext */
1078                         emit_zextw(rd, rd, ctx);
1079                         break;
1080                 }
1081                 switch (insn->off) {
1082                 case 0:
1083                         emit_mv(rd, rs, ctx);
1084                         break;
1085                 case 8:
1086                         emit_sextb(rd, rs, ctx);
1087                         break;
1088                 case 16:
1089                         emit_sexth(rd, rs, ctx);
1090                         break;
1091                 case 32:
1092                         emit_sextw(rd, rs, ctx);
1093                         break;
1094                 }
1095                 if (!is64 && !aux->verifier_zext)
1096                         emit_zextw(rd, rd, ctx);
1097                 break;
1098
1099         /* dst = dst OP src */
1100         case BPF_ALU | BPF_ADD | BPF_X:
1101         case BPF_ALU64 | BPF_ADD | BPF_X:
1102                 emit_add(rd, rd, rs, ctx);
1103                 if (!is64 && !aux->verifier_zext)
1104                         emit_zextw(rd, rd, ctx);
1105                 break;
1106         case BPF_ALU | BPF_SUB | BPF_X:
1107         case BPF_ALU64 | BPF_SUB | BPF_X:
1108                 if (is64)
1109                         emit_sub(rd, rd, rs, ctx);
1110                 else
1111                         emit_subw(rd, rd, rs, ctx);
1112
1113                 if (!is64 && !aux->verifier_zext)
1114                         emit_zextw(rd, rd, ctx);
1115                 break;
1116         case BPF_ALU | BPF_AND | BPF_X:
1117         case BPF_ALU64 | BPF_AND | BPF_X:
1118                 emit_and(rd, rd, rs, ctx);
1119                 if (!is64 && !aux->verifier_zext)
1120                         emit_zextw(rd, rd, ctx);
1121                 break;
1122         case BPF_ALU | BPF_OR | BPF_X:
1123         case BPF_ALU64 | BPF_OR | BPF_X:
1124                 emit_or(rd, rd, rs, ctx);
1125                 if (!is64 && !aux->verifier_zext)
1126                         emit_zextw(rd, rd, ctx);
1127                 break;
1128         case BPF_ALU | BPF_XOR | BPF_X:
1129         case BPF_ALU64 | BPF_XOR | BPF_X:
1130                 emit_xor(rd, rd, rs, ctx);
1131                 if (!is64 && !aux->verifier_zext)
1132                         emit_zextw(rd, rd, ctx);
1133                 break;
1134         case BPF_ALU | BPF_MUL | BPF_X:
1135         case BPF_ALU64 | BPF_MUL | BPF_X:
1136                 emit(is64 ? rv_mul(rd, rd, rs) : rv_mulw(rd, rd, rs), ctx);
1137                 if (!is64 && !aux->verifier_zext)
1138                         emit_zextw(rd, rd, ctx);
1139                 break;
1140         case BPF_ALU | BPF_DIV | BPF_X:
1141         case BPF_ALU64 | BPF_DIV | BPF_X:
1142                 if (off)
1143                         emit(is64 ? rv_div(rd, rd, rs) : rv_divw(rd, rd, rs), ctx);
1144                 else
1145                         emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx);
1146                 if (!is64 && !aux->verifier_zext)
1147                         emit_zextw(rd, rd, ctx);
1148                 break;
1149         case BPF_ALU | BPF_MOD | BPF_X:
1150         case BPF_ALU64 | BPF_MOD | BPF_X:
1151                 if (off)
1152                         emit(is64 ? rv_rem(rd, rd, rs) : rv_remw(rd, rd, rs), ctx);
1153                 else
1154                         emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx);
1155                 if (!is64 && !aux->verifier_zext)
1156                         emit_zextw(rd, rd, ctx);
1157                 break;
1158         case BPF_ALU | BPF_LSH | BPF_X:
1159         case BPF_ALU64 | BPF_LSH | BPF_X:
1160                 emit(is64 ? rv_sll(rd, rd, rs) : rv_sllw(rd, rd, rs), ctx);
1161                 if (!is64 && !aux->verifier_zext)
1162                         emit_zextw(rd, rd, ctx);
1163                 break;
1164         case BPF_ALU | BPF_RSH | BPF_X:
1165         case BPF_ALU64 | BPF_RSH | BPF_X:
1166                 emit(is64 ? rv_srl(rd, rd, rs) : rv_srlw(rd, rd, rs), ctx);
1167                 if (!is64 && !aux->verifier_zext)
1168                         emit_zextw(rd, rd, ctx);
1169                 break;
1170         case BPF_ALU | BPF_ARSH | BPF_X:
1171         case BPF_ALU64 | BPF_ARSH | BPF_X:
1172                 emit(is64 ? rv_sra(rd, rd, rs) : rv_sraw(rd, rd, rs), ctx);
1173                 if (!is64 && !aux->verifier_zext)
1174                         emit_zextw(rd, rd, ctx);
1175                 break;
1176
1177         /* dst = -dst */
1178         case BPF_ALU | BPF_NEG:
1179         case BPF_ALU64 | BPF_NEG:
1180                 emit_sub(rd, RV_REG_ZERO, rd, ctx);
1181                 if (!is64 && !aux->verifier_zext)
1182                         emit_zextw(rd, rd, ctx);
1183                 break;
1184
1185         /* dst = BSWAP##imm(dst) */
1186         case BPF_ALU | BPF_END | BPF_FROM_LE:
1187                 switch (imm) {
1188                 case 16:
1189                         emit_zexth(rd, rd, ctx);
1190                         break;
1191                 case 32:
1192                         if (!aux->verifier_zext)
1193                                 emit_zextw(rd, rd, ctx);
1194                         break;
1195                 case 64:
1196                         /* Do nothing */
1197                         break;
1198                 }
1199                 break;
1200         case BPF_ALU | BPF_END | BPF_FROM_BE:
1201         case BPF_ALU64 | BPF_END | BPF_FROM_LE:
1202                 emit_bswap(rd, imm, ctx);
1203                 break;
1204
1205         /* dst = imm */
1206         case BPF_ALU | BPF_MOV | BPF_K:
1207         case BPF_ALU64 | BPF_MOV | BPF_K:
1208                 emit_imm(rd, imm, ctx);
1209                 if (!is64 && !aux->verifier_zext)
1210                         emit_zextw(rd, rd, ctx);
1211                 break;
1212
1213         /* dst = dst OP imm */
1214         case BPF_ALU | BPF_ADD | BPF_K:
1215         case BPF_ALU64 | BPF_ADD | BPF_K:
1216                 if (is_12b_int(imm)) {
1217                         emit_addi(rd, rd, imm, ctx);
1218                 } else {
1219                         emit_imm(RV_REG_T1, imm, ctx);
1220                         emit_add(rd, rd, RV_REG_T1, ctx);
1221                 }
1222                 if (!is64 && !aux->verifier_zext)
1223                         emit_zextw(rd, rd, ctx);
1224                 break;
1225         case BPF_ALU | BPF_SUB | BPF_K:
1226         case BPF_ALU64 | BPF_SUB | BPF_K:
1227                 if (is_12b_int(-imm)) {
1228                         emit_addi(rd, rd, -imm, ctx);
1229                 } else {
1230                         emit_imm(RV_REG_T1, imm, ctx);
1231                         emit_sub(rd, rd, RV_REG_T1, ctx);
1232                 }
1233                 if (!is64 && !aux->verifier_zext)
1234                         emit_zextw(rd, rd, ctx);
1235                 break;
1236         case BPF_ALU | BPF_AND | BPF_K:
1237         case BPF_ALU64 | BPF_AND | BPF_K:
1238                 if (is_12b_int(imm)) {
1239                         emit_andi(rd, rd, imm, ctx);
1240                 } else {
1241                         emit_imm(RV_REG_T1, imm, ctx);
1242                         emit_and(rd, rd, RV_REG_T1, ctx);
1243                 }
1244                 if (!is64 && !aux->verifier_zext)
1245                         emit_zextw(rd, rd, ctx);
1246                 break;
1247         case BPF_ALU | BPF_OR | BPF_K:
1248         case BPF_ALU64 | BPF_OR | BPF_K:
1249                 if (is_12b_int(imm)) {
1250                         emit(rv_ori(rd, rd, imm), ctx);
1251                 } else {
1252                         emit_imm(RV_REG_T1, imm, ctx);
1253                         emit_or(rd, rd, RV_REG_T1, ctx);
1254                 }
1255                 if (!is64 && !aux->verifier_zext)
1256                         emit_zextw(rd, rd, ctx);
1257                 break;
1258         case BPF_ALU | BPF_XOR | BPF_K:
1259         case BPF_ALU64 | BPF_XOR | BPF_K:
1260                 if (is_12b_int(imm)) {
1261                         emit(rv_xori(rd, rd, imm), ctx);
1262                 } else {
1263                         emit_imm(RV_REG_T1, imm, ctx);
1264                         emit_xor(rd, rd, RV_REG_T1, ctx);
1265                 }
1266                 if (!is64 && !aux->verifier_zext)
1267                         emit_zextw(rd, rd, ctx);
1268                 break;
1269         case BPF_ALU | BPF_MUL | BPF_K:
1270         case BPF_ALU64 | BPF_MUL | BPF_K:
1271                 emit_imm(RV_REG_T1, imm, ctx);
1272                 emit(is64 ? rv_mul(rd, rd, RV_REG_T1) :
1273                      rv_mulw(rd, rd, RV_REG_T1), ctx);
1274                 if (!is64 && !aux->verifier_zext)
1275                         emit_zextw(rd, rd, ctx);
1276                 break;
1277         case BPF_ALU | BPF_DIV | BPF_K:
1278         case BPF_ALU64 | BPF_DIV | BPF_K:
1279                 emit_imm(RV_REG_T1, imm, ctx);
1280                 if (off)
1281                         emit(is64 ? rv_div(rd, rd, RV_REG_T1) :
1282                              rv_divw(rd, rd, RV_REG_T1), ctx);
1283                 else
1284                         emit(is64 ? rv_divu(rd, rd, RV_REG_T1) :
1285                              rv_divuw(rd, rd, RV_REG_T1), ctx);
1286                 if (!is64 && !aux->verifier_zext)
1287                         emit_zextw(rd, rd, ctx);
1288                 break;
1289         case BPF_ALU | BPF_MOD | BPF_K:
1290         case BPF_ALU64 | BPF_MOD | BPF_K:
1291                 emit_imm(RV_REG_T1, imm, ctx);
1292                 if (off)
1293                         emit(is64 ? rv_rem(rd, rd, RV_REG_T1) :
1294                              rv_remw(rd, rd, RV_REG_T1), ctx);
1295                 else
1296                         emit(is64 ? rv_remu(rd, rd, RV_REG_T1) :
1297                              rv_remuw(rd, rd, RV_REG_T1), ctx);
1298                 if (!is64 && !aux->verifier_zext)
1299                         emit_zextw(rd, rd, ctx);
1300                 break;
1301         case BPF_ALU | BPF_LSH | BPF_K:
1302         case BPF_ALU64 | BPF_LSH | BPF_K:
1303                 emit_slli(rd, rd, imm, ctx);
1304
1305                 if (!is64 && !aux->verifier_zext)
1306                         emit_zextw(rd, rd, ctx);
1307                 break;
1308         case BPF_ALU | BPF_RSH | BPF_K:
1309         case BPF_ALU64 | BPF_RSH | BPF_K:
1310                 if (is64)
1311                         emit_srli(rd, rd, imm, ctx);
1312                 else
1313                         emit(rv_srliw(rd, rd, imm), ctx);
1314
1315                 if (!is64 && !aux->verifier_zext)
1316                         emit_zextw(rd, rd, ctx);
1317                 break;
1318         case BPF_ALU | BPF_ARSH | BPF_K:
1319         case BPF_ALU64 | BPF_ARSH | BPF_K:
1320                 if (is64)
1321                         emit_srai(rd, rd, imm, ctx);
1322                 else
1323                         emit(rv_sraiw(rd, rd, imm), ctx);
1324
1325                 if (!is64 && !aux->verifier_zext)
1326                         emit_zextw(rd, rd, ctx);
1327                 break;
1328
1329         /* JUMP off */
1330         case BPF_JMP | BPF_JA:
1331         case BPF_JMP32 | BPF_JA:
1332                 if (BPF_CLASS(code) == BPF_JMP)
1333                         rvoff = rv_offset(i, off, ctx);
1334                 else
1335                         rvoff = rv_offset(i, imm, ctx);
1336                 ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1337                 if (ret)
1338                         return ret;
1339                 break;
1340
1341         /* IF (dst COND src) JUMP off */
1342         case BPF_JMP | BPF_JEQ | BPF_X:
1343         case BPF_JMP32 | BPF_JEQ | BPF_X:
1344         case BPF_JMP | BPF_JGT | BPF_X:
1345         case BPF_JMP32 | BPF_JGT | BPF_X:
1346         case BPF_JMP | BPF_JLT | BPF_X:
1347         case BPF_JMP32 | BPF_JLT | BPF_X:
1348         case BPF_JMP | BPF_JGE | BPF_X:
1349         case BPF_JMP32 | BPF_JGE | BPF_X:
1350         case BPF_JMP | BPF_JLE | BPF_X:
1351         case BPF_JMP32 | BPF_JLE | BPF_X:
1352         case BPF_JMP | BPF_JNE | BPF_X:
1353         case BPF_JMP32 | BPF_JNE | BPF_X:
1354         case BPF_JMP | BPF_JSGT | BPF_X:
1355         case BPF_JMP32 | BPF_JSGT | BPF_X:
1356         case BPF_JMP | BPF_JSLT | BPF_X:
1357         case BPF_JMP32 | BPF_JSLT | BPF_X:
1358         case BPF_JMP | BPF_JSGE | BPF_X:
1359         case BPF_JMP32 | BPF_JSGE | BPF_X:
1360         case BPF_JMP | BPF_JSLE | BPF_X:
1361         case BPF_JMP32 | BPF_JSLE | BPF_X:
1362         case BPF_JMP | BPF_JSET | BPF_X:
1363         case BPF_JMP32 | BPF_JSET | BPF_X:
1364                 rvoff = rv_offset(i, off, ctx);
1365                 if (!is64) {
1366                         s = ctx->ninsns;
1367                         if (is_signed_bpf_cond(BPF_OP(code))) {
1368                                 emit_sextw_alt(&rs, RV_REG_T1, ctx);
1369                                 emit_sextw_alt(&rd, RV_REG_T2, ctx);
1370                         } else {
1371                                 emit_zextw_alt(&rs, RV_REG_T1, ctx);
1372                                 emit_zextw_alt(&rd, RV_REG_T2, ctx);
1373                         }
1374                         e = ctx->ninsns;
1375
1376                         /* Adjust for extra insns */
1377                         rvoff -= ninsns_rvoff(e - s);
1378                 }
1379
1380                 if (BPF_OP(code) == BPF_JSET) {
1381                         /* Adjust for and */
1382                         rvoff -= 4;
1383                         emit_and(RV_REG_T1, rd, rs, ctx);
1384                         emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
1385                 } else {
1386                         emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1387                 }
1388                 break;
1389
1390         /* IF (dst COND imm) JUMP off */
1391         case BPF_JMP | BPF_JEQ | BPF_K:
1392         case BPF_JMP32 | BPF_JEQ | BPF_K:
1393         case BPF_JMP | BPF_JGT | BPF_K:
1394         case BPF_JMP32 | BPF_JGT | BPF_K:
1395         case BPF_JMP | BPF_JLT | BPF_K:
1396         case BPF_JMP32 | BPF_JLT | BPF_K:
1397         case BPF_JMP | BPF_JGE | BPF_K:
1398         case BPF_JMP32 | BPF_JGE | BPF_K:
1399         case BPF_JMP | BPF_JLE | BPF_K:
1400         case BPF_JMP32 | BPF_JLE | BPF_K:
1401         case BPF_JMP | BPF_JNE | BPF_K:
1402         case BPF_JMP32 | BPF_JNE | BPF_K:
1403         case BPF_JMP | BPF_JSGT | BPF_K:
1404         case BPF_JMP32 | BPF_JSGT | BPF_K:
1405         case BPF_JMP | BPF_JSLT | BPF_K:
1406         case BPF_JMP32 | BPF_JSLT | BPF_K:
1407         case BPF_JMP | BPF_JSGE | BPF_K:
1408         case BPF_JMP32 | BPF_JSGE | BPF_K:
1409         case BPF_JMP | BPF_JSLE | BPF_K:
1410         case BPF_JMP32 | BPF_JSLE | BPF_K:
1411                 rvoff = rv_offset(i, off, ctx);
1412                 s = ctx->ninsns;
1413                 if (imm)
1414                         emit_imm(RV_REG_T1, imm, ctx);
1415                 rs = imm ? RV_REG_T1 : RV_REG_ZERO;
1416                 if (!is64) {
1417                         if (is_signed_bpf_cond(BPF_OP(code))) {
1418                                 emit_sextw_alt(&rd, RV_REG_T2, ctx);
1419                                 /* rs has been sign extended */
1420                         } else {
1421                                 emit_zextw_alt(&rd, RV_REG_T2, ctx);
1422                                 if (imm)
1423                                         emit_zextw(rs, rs, ctx);
1424                         }
1425                 }
1426                 e = ctx->ninsns;
1427
1428                 /* Adjust for extra insns */
1429                 rvoff -= ninsns_rvoff(e - s);
1430                 emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1431                 break;
1432
1433         case BPF_JMP | BPF_JSET | BPF_K:
1434         case BPF_JMP32 | BPF_JSET | BPF_K:
1435                 rvoff = rv_offset(i, off, ctx);
1436                 s = ctx->ninsns;
1437                 if (is_12b_int(imm)) {
1438                         emit_andi(RV_REG_T1, rd, imm, ctx);
1439                 } else {
1440                         emit_imm(RV_REG_T1, imm, ctx);
1441                         emit_and(RV_REG_T1, rd, RV_REG_T1, ctx);
1442                 }
1443                 /* For jset32, we should clear the upper 32 bits of t1, but
1444                  * sign-extension is sufficient here and saves one instruction,
1445                  * as t1 is used only in comparison against zero.
1446                  */
1447                 if (!is64 && imm < 0)
1448                         emit_sextw(RV_REG_T1, RV_REG_T1, ctx);
1449                 e = ctx->ninsns;
1450                 rvoff -= ninsns_rvoff(e - s);
1451                 emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
1452                 break;
1453
1454         /* function call */
1455         case BPF_JMP | BPF_CALL:
1456         {
1457                 bool fixed_addr;
1458                 u64 addr;
1459
1460                 mark_call(ctx);
1461                 ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
1462                                             &addr, &fixed_addr);
1463                 if (ret < 0)
1464                         return ret;
1465
1466                 if (insn->src_reg == BPF_PSEUDO_KFUNC_CALL) {
1467                         const struct btf_func_model *fm;
1468                         int idx;
1469
1470                         fm = bpf_jit_find_kfunc_model(ctx->prog, insn);
1471                         if (!fm)
1472                                 return -EINVAL;
1473
1474                         for (idx = 0; idx < fm->nr_args; idx++) {
1475                                 u8 reg = bpf_to_rv_reg(BPF_REG_1 + idx, ctx);
1476
1477                                 if (fm->arg_size[idx] == sizeof(int))
1478                                         emit_sextw(reg, reg, ctx);
1479                         }
1480                 }
1481
1482                 ret = emit_call(addr, fixed_addr, ctx);
1483                 if (ret)
1484                         return ret;
1485
1486                 if (insn->src_reg != BPF_PSEUDO_CALL)
1487                         emit_mv(bpf_to_rv_reg(BPF_REG_0, ctx), RV_REG_A0, ctx);
1488                 break;
1489         }
1490         /* tail call */
1491         case BPF_JMP | BPF_TAIL_CALL:
1492                 if (emit_bpf_tail_call(i, ctx))
1493                         return -1;
1494                 break;
1495
1496         /* function return */
1497         case BPF_JMP | BPF_EXIT:
1498                 if (i == ctx->prog->len - 1)
1499                         break;
1500
1501                 rvoff = epilogue_offset(ctx);
1502                 ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1503                 if (ret)
1504                         return ret;
1505                 break;
1506
1507         /* dst = imm64 */
1508         case BPF_LD | BPF_IMM | BPF_DW:
1509         {
1510                 struct bpf_insn insn1 = insn[1];
1511                 u64 imm64;
1512
1513                 imm64 = (u64)insn1.imm << 32 | (u32)imm;
1514                 if (bpf_pseudo_func(insn)) {
1515                         /* fixed-length insns for extra jit pass */
1516                         ret = emit_addr(rd, imm64, extra_pass, ctx);
1517                         if (ret)
1518                                 return ret;
1519                 } else {
1520                         emit_imm(rd, imm64, ctx);
1521                 }
1522
1523                 return 1;
1524         }
1525
1526         /* LDX: dst = *(unsigned size *)(src + off) */
1527         case BPF_LDX | BPF_MEM | BPF_B:
1528         case BPF_LDX | BPF_MEM | BPF_H:
1529         case BPF_LDX | BPF_MEM | BPF_W:
1530         case BPF_LDX | BPF_MEM | BPF_DW:
1531         case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1532         case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1533         case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1534         case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1535         /* LDSX: dst = *(signed size *)(src + off) */
1536         case BPF_LDX | BPF_MEMSX | BPF_B:
1537         case BPF_LDX | BPF_MEMSX | BPF_H:
1538         case BPF_LDX | BPF_MEMSX | BPF_W:
1539         case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
1540         case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
1541         case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
1542         {
1543                 int insn_len, insns_start;
1544                 bool sign_ext;
1545
1546                 sign_ext = BPF_MODE(insn->code) == BPF_MEMSX ||
1547                            BPF_MODE(insn->code) == BPF_PROBE_MEMSX;
1548
1549                 switch (BPF_SIZE(code)) {
1550                 case BPF_B:
1551                         if (is_12b_int(off)) {
1552                                 insns_start = ctx->ninsns;
1553                                 if (sign_ext)
1554                                         emit(rv_lb(rd, off, rs), ctx);
1555                                 else
1556                                         emit(rv_lbu(rd, off, rs), ctx);
1557                                 insn_len = ctx->ninsns - insns_start;
1558                                 break;
1559                         }
1560
1561                         emit_imm(RV_REG_T1, off, ctx);
1562                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1563                         insns_start = ctx->ninsns;
1564                         if (sign_ext)
1565                                 emit(rv_lb(rd, 0, RV_REG_T1), ctx);
1566                         else
1567                                 emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
1568                         insn_len = ctx->ninsns - insns_start;
1569                         break;
1570                 case BPF_H:
1571                         if (is_12b_int(off)) {
1572                                 insns_start = ctx->ninsns;
1573                                 if (sign_ext)
1574                                         emit(rv_lh(rd, off, rs), ctx);
1575                                 else
1576                                         emit(rv_lhu(rd, off, rs), ctx);
1577                                 insn_len = ctx->ninsns - insns_start;
1578                                 break;
1579                         }
1580
1581                         emit_imm(RV_REG_T1, off, ctx);
1582                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1583                         insns_start = ctx->ninsns;
1584                         if (sign_ext)
1585                                 emit(rv_lh(rd, 0, RV_REG_T1), ctx);
1586                         else
1587                                 emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
1588                         insn_len = ctx->ninsns - insns_start;
1589                         break;
1590                 case BPF_W:
1591                         if (is_12b_int(off)) {
1592                                 insns_start = ctx->ninsns;
1593                                 if (sign_ext)
1594                                         emit(rv_lw(rd, off, rs), ctx);
1595                                 else
1596                                         emit(rv_lwu(rd, off, rs), ctx);
1597                                 insn_len = ctx->ninsns - insns_start;
1598                                 break;
1599                         }
1600
1601                         emit_imm(RV_REG_T1, off, ctx);
1602                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1603                         insns_start = ctx->ninsns;
1604                         if (sign_ext)
1605                                 emit(rv_lw(rd, 0, RV_REG_T1), ctx);
1606                         else
1607                                 emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
1608                         insn_len = ctx->ninsns - insns_start;
1609                         break;
1610                 case BPF_DW:
1611                         if (is_12b_int(off)) {
1612                                 insns_start = ctx->ninsns;
1613                                 emit_ld(rd, off, rs, ctx);
1614                                 insn_len = ctx->ninsns - insns_start;
1615                                 break;
1616                         }
1617
1618                         emit_imm(RV_REG_T1, off, ctx);
1619                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1620                         insns_start = ctx->ninsns;
1621                         emit_ld(rd, 0, RV_REG_T1, ctx);
1622                         insn_len = ctx->ninsns - insns_start;
1623                         break;
1624                 }
1625
1626                 ret = add_exception_handler(insn, ctx, rd, insn_len);
1627                 if (ret)
1628                         return ret;
1629
1630                 if (BPF_SIZE(code) != BPF_DW && insn_is_zext(&insn[1]))
1631                         return 1;
1632                 break;
1633         }
1634         /* speculation barrier */
1635         case BPF_ST | BPF_NOSPEC:
1636                 break;
1637
1638         /* ST: *(size *)(dst + off) = imm */
1639         case BPF_ST | BPF_MEM | BPF_B:
1640                 emit_imm(RV_REG_T1, imm, ctx);
1641                 if (is_12b_int(off)) {
1642                         emit(rv_sb(rd, off, RV_REG_T1), ctx);
1643                         break;
1644                 }
1645
1646                 emit_imm(RV_REG_T2, off, ctx);
1647                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1648                 emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
1649                 break;
1650
1651         case BPF_ST | BPF_MEM | BPF_H:
1652                 emit_imm(RV_REG_T1, imm, ctx);
1653                 if (is_12b_int(off)) {
1654                         emit(rv_sh(rd, off, RV_REG_T1), ctx);
1655                         break;
1656                 }
1657
1658                 emit_imm(RV_REG_T2, off, ctx);
1659                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1660                 emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
1661                 break;
1662         case BPF_ST | BPF_MEM | BPF_W:
1663                 emit_imm(RV_REG_T1, imm, ctx);
1664                 if (is_12b_int(off)) {
1665                         emit_sw(rd, off, RV_REG_T1, ctx);
1666                         break;
1667                 }
1668
1669                 emit_imm(RV_REG_T2, off, ctx);
1670                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1671                 emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
1672                 break;
1673         case BPF_ST | BPF_MEM | BPF_DW:
1674                 emit_imm(RV_REG_T1, imm, ctx);
1675                 if (is_12b_int(off)) {
1676                         emit_sd(rd, off, RV_REG_T1, ctx);
1677                         break;
1678                 }
1679
1680                 emit_imm(RV_REG_T2, off, ctx);
1681                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1682                 emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
1683                 break;
1684
1685         /* STX: *(size *)(dst + off) = src */
1686         case BPF_STX | BPF_MEM | BPF_B:
1687                 if (is_12b_int(off)) {
1688                         emit(rv_sb(rd, off, rs), ctx);
1689                         break;
1690                 }
1691
1692                 emit_imm(RV_REG_T1, off, ctx);
1693                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1694                 emit(rv_sb(RV_REG_T1, 0, rs), ctx);
1695                 break;
1696         case BPF_STX | BPF_MEM | BPF_H:
1697                 if (is_12b_int(off)) {
1698                         emit(rv_sh(rd, off, rs), ctx);
1699                         break;
1700                 }
1701
1702                 emit_imm(RV_REG_T1, off, ctx);
1703                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1704                 emit(rv_sh(RV_REG_T1, 0, rs), ctx);
1705                 break;
1706         case BPF_STX | BPF_MEM | BPF_W:
1707                 if (is_12b_int(off)) {
1708                         emit_sw(rd, off, rs, ctx);
1709                         break;
1710                 }
1711
1712                 emit_imm(RV_REG_T1, off, ctx);
1713                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1714                 emit_sw(RV_REG_T1, 0, rs, ctx);
1715                 break;
1716         case BPF_STX | BPF_MEM | BPF_DW:
1717                 if (is_12b_int(off)) {
1718                         emit_sd(rd, off, rs, ctx);
1719                         break;
1720                 }
1721
1722                 emit_imm(RV_REG_T1, off, ctx);
1723                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1724                 emit_sd(RV_REG_T1, 0, rs, ctx);
1725                 break;
1726         case BPF_STX | BPF_ATOMIC | BPF_W:
1727         case BPF_STX | BPF_ATOMIC | BPF_DW:
1728                 emit_atomic(rd, rs, off, imm,
1729                             BPF_SIZE(code) == BPF_DW, ctx);
1730                 break;
1731         default:
1732                 pr_err("bpf-jit: unknown opcode %02x\n", code);
1733                 return -EINVAL;
1734         }
1735
1736         return 0;
1737 }
1738
1739 void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog)
1740 {
1741         int i, stack_adjust = 0, store_offset, bpf_stack_adjust;
1742
1743         bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
1744         if (bpf_stack_adjust)
1745                 mark_fp(ctx);
1746
1747         if (seen_reg(RV_REG_RA, ctx))
1748                 stack_adjust += 8;
1749         stack_adjust += 8; /* RV_REG_FP */
1750         if (seen_reg(RV_REG_S1, ctx))
1751                 stack_adjust += 8;
1752         if (seen_reg(RV_REG_S2, ctx))
1753                 stack_adjust += 8;
1754         if (seen_reg(RV_REG_S3, ctx))
1755                 stack_adjust += 8;
1756         if (seen_reg(RV_REG_S4, ctx))
1757                 stack_adjust += 8;
1758         if (seen_reg(RV_REG_S5, ctx))
1759                 stack_adjust += 8;
1760         if (seen_reg(RV_REG_S6, ctx))
1761                 stack_adjust += 8;
1762
1763         stack_adjust = round_up(stack_adjust, 16);
1764         stack_adjust += bpf_stack_adjust;
1765
1766         store_offset = stack_adjust - 8;
1767
1768         /* emit kcfi type preamble immediately before the  first insn */
1769         emit_kcfi(is_subprog ? cfi_bpf_subprog_hash : cfi_bpf_hash, ctx);
1770
1771         /* nops reserved for auipc+jalr pair */
1772         for (i = 0; i < RV_FENTRY_NINSNS; i++)
1773                 emit(rv_nop(), ctx);
1774
1775         /* First instruction is always setting the tail-call-counter
1776          * (TCC) register. This instruction is skipped for tail calls.
1777          * Force using a 4-byte (non-compressed) instruction.
1778          */
1779         emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
1780
1781         emit_addi(RV_REG_SP, RV_REG_SP, -stack_adjust, ctx);
1782
1783         if (seen_reg(RV_REG_RA, ctx)) {
1784                 emit_sd(RV_REG_SP, store_offset, RV_REG_RA, ctx);
1785                 store_offset -= 8;
1786         }
1787         emit_sd(RV_REG_SP, store_offset, RV_REG_FP, ctx);
1788         store_offset -= 8;
1789         if (seen_reg(RV_REG_S1, ctx)) {
1790                 emit_sd(RV_REG_SP, store_offset, RV_REG_S1, ctx);
1791                 store_offset -= 8;
1792         }
1793         if (seen_reg(RV_REG_S2, ctx)) {
1794                 emit_sd(RV_REG_SP, store_offset, RV_REG_S2, ctx);
1795                 store_offset -= 8;
1796         }
1797         if (seen_reg(RV_REG_S3, ctx)) {
1798                 emit_sd(RV_REG_SP, store_offset, RV_REG_S3, ctx);
1799                 store_offset -= 8;
1800         }
1801         if (seen_reg(RV_REG_S4, ctx)) {
1802                 emit_sd(RV_REG_SP, store_offset, RV_REG_S4, ctx);
1803                 store_offset -= 8;
1804         }
1805         if (seen_reg(RV_REG_S5, ctx)) {
1806                 emit_sd(RV_REG_SP, store_offset, RV_REG_S5, ctx);
1807                 store_offset -= 8;
1808         }
1809         if (seen_reg(RV_REG_S6, ctx)) {
1810                 emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
1811                 store_offset -= 8;
1812         }
1813
1814         emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
1815
1816         if (bpf_stack_adjust)
1817                 emit_addi(RV_REG_S5, RV_REG_SP, bpf_stack_adjust, ctx);
1818
1819         /* Program contains calls and tail calls, so RV_REG_TCC need
1820          * to be saved across calls.
1821          */
1822         if (seen_tail_call(ctx) && seen_call(ctx))
1823                 emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
1824
1825         ctx->stack_size = stack_adjust;
1826 }
1827
1828 void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
1829 {
1830         __build_epilogue(false, ctx);
1831 }
1832
1833 bool bpf_jit_supports_kfunc_call(void)
1834 {
1835         return true;
1836 }
1837
1838 bool bpf_jit_supports_ptr_xchg(void)
1839 {
1840         return true;
1841 }