diff options
Diffstat (limited to 'arch/x86/net/bpf_jit_comp.c')
-rw-r--r-- | arch/x86/net/bpf_jit_comp.c | 198 |
1 files changed, 120 insertions, 78 deletions
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c index 5e680e039d0e..7913440c0fd4 100644 --- a/arch/x86/net/bpf_jit_comp.c +++ b/arch/x86/net/bpf_jit_comp.c @@ -432,7 +432,7 @@ static void emit_return(u8 **pprog, u8 *ip) u8 *prog = *pprog; if (cpu_feature_enabled(X86_FEATURE_RETHUNK)) { - emit_jump(&prog, &__x86_return_thunk, ip); + emit_jump(&prog, x86_return_thunk, ip); } else { EMIT1(0xC3); /* ret */ if (IS_ENABLED(CONFIG_SLS)) @@ -893,6 +893,10 @@ static void emit_nops(u8 **pprog, int len) #define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp))) +/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */ +#define RESTORE_TAIL_CALL_CNT(stack) \ + EMIT3_off32(0x48, 0x8B, 0x85, -round_up(stack, 8) - 8) + static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image, int oldproglen, struct jit_context *ctx, bool jmp_padding) { @@ -1436,9 +1440,7 @@ st: if (is_imm8(insn->off)) case BPF_JMP | BPF_CALL: func = (u8 *) __bpf_call_base + imm32; if (tail_call_reachable) { - /* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */ - EMIT3_off32(0x48, 0x8B, 0x85, - -round_up(bpf_prog->aux->stack_depth, 8) - 8); + RESTORE_TAIL_CALL_CNT(bpf_prog->aux->stack_depth); if (!imm32 || emit_call(&prog, func, image + addrs[i - 1] + 7)) return -EINVAL; } else { @@ -1623,16 +1625,24 @@ emit_cond_jmp: /* Convert BPF opcode to x86 */ break; case BPF_JMP | BPF_JA: - if (insn->off == -1) - /* -1 jmp instructions will always jump - * backwards two bytes. Explicitly handling - * this case avoids wasting too many passes - * when there are long sequences of replaced - * dead code. - */ - jmp_offset = -2; - else - jmp_offset = addrs[i + insn->off] - addrs[i]; + case BPF_JMP32 | BPF_JA: + if (BPF_CLASS(insn->code) == BPF_JMP) { + if (insn->off == -1) + /* -1 jmp instructions will always jump + * backwards two bytes. Explicitly handling + * this case avoids wasting too many passes + * when there are long sequences of replaced + * dead code. + */ + jmp_offset = -2; + else + jmp_offset = addrs[i + insn->off] - addrs[i]; + } else { + if (insn->imm == -1) + jmp_offset = -2; + else + jmp_offset = addrs[i + insn->imm] - addrs[i]; + } if (!jmp_offset) { /* @@ -1750,63 +1760,37 @@ emit_jmp: return proglen; } -static void save_regs(const struct btf_func_model *m, u8 **prog, int nr_args, +static void save_regs(const struct btf_func_model *m, u8 **prog, int nr_regs, int stack_size) { - int i, j, arg_size, nr_regs; + int i; + /* Store function arguments to stack. * For a function that accepts two pointers the sequence will be: * mov QWORD PTR [rbp-0x10],rdi * mov QWORD PTR [rbp-0x8],rsi */ - for (i = 0, j = 0; i < min(nr_args, 6); i++) { - if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG) { - nr_regs = (m->arg_size[i] + 7) / 8; - arg_size = 8; - } else { - nr_regs = 1; - arg_size = m->arg_size[i]; - } - - while (nr_regs) { - emit_stx(prog, bytes_to_bpf_size(arg_size), - BPF_REG_FP, - j == 5 ? X86_REG_R9 : BPF_REG_1 + j, - -(stack_size - j * 8)); - nr_regs--; - j++; - } - } + for (i = 0; i < min(nr_regs, 6); i++) + emit_stx(prog, BPF_DW, BPF_REG_FP, + i == 5 ? X86_REG_R9 : BPF_REG_1 + i, + -(stack_size - i * 8)); } -static void restore_regs(const struct btf_func_model *m, u8 **prog, int nr_args, +static void restore_regs(const struct btf_func_model *m, u8 **prog, int nr_regs, int stack_size) { - int i, j, arg_size, nr_regs; + int i; /* Restore function arguments from stack. * For a function that accepts two pointers the sequence will be: * EMIT4(0x48, 0x8B, 0x7D, 0xF0); mov rdi,QWORD PTR [rbp-0x10] * EMIT4(0x48, 0x8B, 0x75, 0xF8); mov rsi,QWORD PTR [rbp-0x8] */ - for (i = 0, j = 0; i < min(nr_args, 6); i++) { - if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG) { - nr_regs = (m->arg_size[i] + 7) / 8; - arg_size = 8; - } else { - nr_regs = 1; - arg_size = m->arg_size[i]; - } - - while (nr_regs) { - emit_ldx(prog, bytes_to_bpf_size(arg_size), - j == 5 ? X86_REG_R9 : BPF_REG_1 + j, - BPF_REG_FP, - -(stack_size - j * 8)); - nr_regs--; - j++; - } - } + for (i = 0; i < min(nr_regs, 6); i++) + emit_ldx(prog, BPF_DW, + i == 5 ? X86_REG_R9 : BPF_REG_1 + i, + BPF_REG_FP, + -(stack_size - i * 8)); } static int invoke_bpf_prog(const struct btf_func_model *m, u8 **pprog, @@ -2031,8 +2015,8 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i struct bpf_tramp_links *tlinks, void *func_addr) { - int ret, i, nr_args = m->nr_args, extra_nregs = 0; - int regs_off, ip_off, args_off, stack_size = nr_args * 8, run_ctx_off; + int i, ret, nr_regs = m->nr_args, stack_size = 0; + int regs_off, nregs_off, ip_off, run_ctx_off; struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY]; struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT]; struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN]; @@ -2041,17 +2025,14 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i u8 *prog; bool save_ret; - /* x86-64 supports up to 6 arguments. 7+ can be added in the future */ - if (nr_args > 6) - return -ENOTSUPP; - - for (i = 0; i < MAX_BPF_FUNC_ARGS; i++) { + /* extra registers for struct arguments */ + for (i = 0; i < m->nr_args; i++) if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG) - extra_nregs += (m->arg_size[i] + 7) / 8 - 1; - } - if (nr_args + extra_nregs > 6) + nr_regs += (m->arg_size[i] + 7) / 8 - 1; + + /* x86-64 supports up to 6 arguments. 7+ can be added in the future */ + if (nr_regs > 6) return -ENOTSUPP; - stack_size += extra_nregs * 8; /* Generated trampoline stack layout: * @@ -2065,11 +2046,12 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i * [ ... ] * RBP - regs_off [ reg_arg1 ] program's ctx pointer * - * RBP - args_off [ arg regs count ] always + * RBP - nregs_off [ regs count ] always * * RBP - ip_off [ traced function ] BPF_TRAMP_F_IP_ARG flag * * RBP - run_ctx_off [ bpf_tramp_run_ctx ] + * RSP [ tail_call_cnt ] BPF_TRAMP_F_TAIL_CALL_CTX */ /* room for return value of orig_call or fentry prog */ @@ -2077,11 +2059,12 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i if (save_ret) stack_size += 8; + stack_size += nr_regs * 8; regs_off = stack_size; - /* args count */ + /* regs count */ stack_size += 8; - args_off = stack_size; + nregs_off = stack_size; if (flags & BPF_TRAMP_F_IP_ARG) stack_size += 8; /* room for IP address argument */ @@ -2106,14 +2089,16 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i EMIT1(0x55); /* push rbp */ EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */ EMIT4(0x48, 0x83, 0xEC, stack_size); /* sub rsp, stack_size */ + if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) + EMIT1(0x50); /* push rax */ EMIT1(0x53); /* push rbx */ /* Store number of argument registers of the traced function: - * mov rax, nr_args + extra_nregs - * mov QWORD PTR [rbp - args_off], rax + * mov rax, nr_regs + * mov QWORD PTR [rbp - nregs_off], rax */ - emit_mov_imm64(&prog, BPF_REG_0, 0, (u32) nr_args + extra_nregs); - emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -args_off); + emit_mov_imm64(&prog, BPF_REG_0, 0, (u32) nr_regs); + emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -nregs_off); if (flags & BPF_TRAMP_F_IP_ARG) { /* Store IP address of the traced function: @@ -2124,7 +2109,7 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -ip_off); } - save_regs(m, &prog, nr_args, regs_off); + save_regs(m, &prog, nr_regs, regs_off); if (flags & BPF_TRAMP_F_CALL_ORIG) { /* arg1: mov rdi, im */ @@ -2154,11 +2139,17 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i } if (flags & BPF_TRAMP_F_CALL_ORIG) { - restore_regs(m, &prog, nr_args, regs_off); + restore_regs(m, &prog, nr_regs, regs_off); + + if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) + /* Before calling the original function, restore the + * tail_call_cnt from stack to rax. + */ + RESTORE_TAIL_CALL_CNT(stack_size); if (flags & BPF_TRAMP_F_ORIG_STACK) { - emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, 8); - EMIT2(0xff, 0xd0); /* call *rax */ + emit_ldx(&prog, BPF_DW, BPF_REG_6, BPF_REG_FP, 8); + EMIT2(0xff, 0xd3); /* call *rbx */ } else { /* call original function */ if (emit_call(&prog, orig_call, prog)) { @@ -2195,7 +2186,7 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i } if (flags & BPF_TRAMP_F_RESTORE_REGS) - restore_regs(m, &prog, nr_args, regs_off); + restore_regs(m, &prog, nr_regs, regs_off); /* This needs to be done regardless. If there were fmod_ret programs, * the return value is only updated on the stack and still needs to be @@ -2209,7 +2200,12 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i ret = -EINVAL; goto cleanup; } - } + } else if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) + /* Before running the original function, restore the + * tail_call_cnt from stack to rax. + */ + RESTORE_TAIL_CALL_CNT(stack_size); + /* restore return value of orig_call or fentry prog back into RAX */ if (save_ret) emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, -8); @@ -2553,3 +2549,49 @@ void bpf_jit_free(struct bpf_prog *prog) bpf_prog_unlock_free(prog); } + +void bpf_arch_poke_desc_update(struct bpf_jit_poke_descriptor *poke, + struct bpf_prog *new, struct bpf_prog *old) +{ + u8 *old_addr, *new_addr, *old_bypass_addr; + int ret; + + old_bypass_addr = old ? NULL : poke->bypass_addr; + old_addr = old ? (u8 *)old->bpf_func + poke->adj_off : NULL; + new_addr = new ? (u8 *)new->bpf_func + poke->adj_off : NULL; + + /* + * On program loading or teardown, the program's kallsym entry + * might not be in place, so we use __bpf_arch_text_poke to skip + * the kallsyms check. + */ + if (new) { + ret = __bpf_arch_text_poke(poke->tailcall_target, + BPF_MOD_JUMP, + old_addr, new_addr); + BUG_ON(ret < 0); + if (!old) { + ret = __bpf_arch_text_poke(poke->tailcall_bypass, + BPF_MOD_JUMP, + poke->bypass_addr, + NULL); + BUG_ON(ret < 0); + } + } else { + ret = __bpf_arch_text_poke(poke->tailcall_bypass, + BPF_MOD_JUMP, + old_bypass_addr, + poke->bypass_addr); + BUG_ON(ret < 0); + /* let other CPUs finish the execution of program + * so that it will not possible to expose them + * to invalid nop, stack unwind, nop state + */ + if (!ret) + synchronize_rcu(); + ret = __bpf_arch_text_poke(poke->tailcall_target, + BPF_MOD_JUMP, + old_addr, NULL); + BUG_ON(ret < 0); + } +} |