bpf,x86: Simplify computing label offsets
commit dceba0817ca329868a15e2e1dd46eb6340b69206 upstream. Take an idea from the 32bit JIT, which uses the multi-pass nature of the JIT to compute the instruction offsets on a prior pass in order to compute the relative jump offsets on a later pass. Application to the x86_64 JIT is slightly more involved because the offsets depend on program variables (such as callee_regs_used and stack_depth) and hence the computed offsets need to be kept in the context of the JIT. This removes, IMO quite fragile, code that hard-codes the offsets and tries to compute the length of variable parts of it. Convert both emit_bpf_tail_call_*() functions which have an out: label at the end. Additionally emit_bpt_tail_call_direct() also has a poke table entry, for which it computes the offset from the end (and thus already relies on the previous pass to have computed addrs[i]), also convert this to be a forward based offset. Signed-off-by: Peter Zijlstra (Intel) <peterz@infradead.org> Reviewed-by: Borislav Petkov <bp@suse.de> Acked-by: Alexei Starovoitov <ast@kernel.org> Acked-by: Josh Poimboeuf <jpoimboe@redhat.com> Tested-by: Alexei Starovoitov <ast@kernel.org> Link: https://lore.kernel.org/r/20211026120310.552304864@infradead.org Signed-off-by: Thadeu Lima de Souza Cascardo <cascardo@canonical.com> [bwh: Backported to 5.10: keep the cnt variable in emit_bpf_tail_call_{,in}direct()] Signed-off-by: Ben Hutchings <ben@decadent.org.uk> Signed-off-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
This commit is contained in:

committed by
Greg Kroah-Hartman

parent
38a80a3ca2
commit
1713e5c4f8
@@ -212,6 +212,14 @@ static void jit_fill_hole(void *area, unsigned int size)
|
|||||||
|
|
||||||
struct jit_context {
|
struct jit_context {
|
||||||
int cleanup_addr; /* Epilogue code offset */
|
int cleanup_addr; /* Epilogue code offset */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Program specific offsets of labels in the code; these rely on the
|
||||||
|
* JIT doing at least 2 passes, recording the position on the first
|
||||||
|
* pass, only to generate the correct offset on the second pass.
|
||||||
|
*/
|
||||||
|
int tail_call_direct_label;
|
||||||
|
int tail_call_indirect_label;
|
||||||
};
|
};
|
||||||
|
|
||||||
/* Maximum number of bytes emitted while JITing one eBPF insn */
|
/* Maximum number of bytes emitted while JITing one eBPF insn */
|
||||||
@@ -371,22 +379,6 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
|
|||||||
return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true);
|
return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
static int get_pop_bytes(bool *callee_regs_used)
|
|
||||||
{
|
|
||||||
int bytes = 0;
|
|
||||||
|
|
||||||
if (callee_regs_used[3])
|
|
||||||
bytes += 2;
|
|
||||||
if (callee_regs_used[2])
|
|
||||||
bytes += 2;
|
|
||||||
if (callee_regs_used[1])
|
|
||||||
bytes += 2;
|
|
||||||
if (callee_regs_used[0])
|
|
||||||
bytes += 1;
|
|
||||||
|
|
||||||
return bytes;
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Generate the following code:
|
* Generate the following code:
|
||||||
*
|
*
|
||||||
@@ -402,30 +394,12 @@ static int get_pop_bytes(bool *callee_regs_used)
|
|||||||
* out:
|
* out:
|
||||||
*/
|
*/
|
||||||
static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
|
static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
|
||||||
u32 stack_depth)
|
u32 stack_depth, u8 *ip,
|
||||||
|
struct jit_context *ctx)
|
||||||
{
|
{
|
||||||
int tcc_off = -4 - round_up(stack_depth, 8);
|
int tcc_off = -4 - round_up(stack_depth, 8);
|
||||||
u8 *prog = *pprog;
|
u8 *prog = *pprog, *start = *pprog;
|
||||||
int pop_bytes = 0;
|
int cnt = 0, offset;
|
||||||
int off1 = 42;
|
|
||||||
int off2 = 31;
|
|
||||||
int off3 = 9;
|
|
||||||
int cnt = 0;
|
|
||||||
|
|
||||||
/* count the additional bytes used for popping callee regs from stack
|
|
||||||
* that need to be taken into account for each of the offsets that
|
|
||||||
* are used for bailing out of the tail call
|
|
||||||
*/
|
|
||||||
pop_bytes = get_pop_bytes(callee_regs_used);
|
|
||||||
off1 += pop_bytes;
|
|
||||||
off2 += pop_bytes;
|
|
||||||
off3 += pop_bytes;
|
|
||||||
|
|
||||||
if (stack_depth) {
|
|
||||||
off1 += 7;
|
|
||||||
off2 += 7;
|
|
||||||
off3 += 7;
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* rdi - pointer to ctx
|
* rdi - pointer to ctx
|
||||||
@@ -440,8 +414,9 @@ static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
|
|||||||
EMIT2(0x89, 0xD2); /* mov edx, edx */
|
EMIT2(0x89, 0xD2); /* mov edx, edx */
|
||||||
EMIT3(0x39, 0x56, /* cmp dword ptr [rsi + 16], edx */
|
EMIT3(0x39, 0x56, /* cmp dword ptr [rsi + 16], edx */
|
||||||
offsetof(struct bpf_array, map.max_entries));
|
offsetof(struct bpf_array, map.max_entries));
|
||||||
#define OFFSET1 (off1 + RETPOLINE_RCX_BPF_JIT_SIZE) /* Number of bytes to jump */
|
|
||||||
EMIT2(X86_JBE, OFFSET1); /* jbe out */
|
offset = ctx->tail_call_indirect_label - (prog + 2 - start);
|
||||||
|
EMIT2(X86_JBE, offset); /* jbe out */
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
|
* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
|
||||||
@@ -449,8 +424,9 @@ static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
|
|||||||
*/
|
*/
|
||||||
EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */
|
EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */
|
||||||
EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */
|
EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */
|
||||||
#define OFFSET2 (off2 + RETPOLINE_RCX_BPF_JIT_SIZE)
|
|
||||||
EMIT2(X86_JA, OFFSET2); /* ja out */
|
offset = ctx->tail_call_indirect_label - (prog + 2 - start);
|
||||||
|
EMIT2(X86_JA, offset); /* ja out */
|
||||||
EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */
|
EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */
|
||||||
EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */
|
EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */
|
||||||
|
|
||||||
@@ -463,12 +439,11 @@ static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
|
|||||||
* goto out;
|
* goto out;
|
||||||
*/
|
*/
|
||||||
EMIT3(0x48, 0x85, 0xC9); /* test rcx,rcx */
|
EMIT3(0x48, 0x85, 0xC9); /* test rcx,rcx */
|
||||||
#define OFFSET3 (off3 + RETPOLINE_RCX_BPF_JIT_SIZE)
|
|
||||||
EMIT2(X86_JE, OFFSET3); /* je out */
|
|
||||||
|
|
||||||
*pprog = prog;
|
offset = ctx->tail_call_indirect_label - (prog + 2 - start);
|
||||||
pop_callee_regs(pprog, callee_regs_used);
|
EMIT2(X86_JE, offset); /* je out */
|
||||||
prog = *pprog;
|
|
||||||
|
pop_callee_regs(&prog, callee_regs_used);
|
||||||
|
|
||||||
EMIT1(0x58); /* pop rax */
|
EMIT1(0x58); /* pop rax */
|
||||||
if (stack_depth)
|
if (stack_depth)
|
||||||
@@ -488,39 +463,18 @@ static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
|
|||||||
RETPOLINE_RCX_BPF_JIT();
|
RETPOLINE_RCX_BPF_JIT();
|
||||||
|
|
||||||
/* out: */
|
/* out: */
|
||||||
|
ctx->tail_call_indirect_label = prog - start;
|
||||||
*pprog = prog;
|
*pprog = prog;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
|
static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
|
||||||
u8 **pprog, int addr, u8 *image,
|
u8 **pprog, u8 *ip,
|
||||||
bool *callee_regs_used, u32 stack_depth)
|
bool *callee_regs_used, u32 stack_depth,
|
||||||
|
struct jit_context *ctx)
|
||||||
{
|
{
|
||||||
int tcc_off = -4 - round_up(stack_depth, 8);
|
int tcc_off = -4 - round_up(stack_depth, 8);
|
||||||
u8 *prog = *pprog;
|
u8 *prog = *pprog, *start = *pprog;
|
||||||
int pop_bytes = 0;
|
int cnt = 0, offset;
|
||||||
int off1 = 20;
|
|
||||||
int poke_off;
|
|
||||||
int cnt = 0;
|
|
||||||
|
|
||||||
/* count the additional bytes used for popping callee regs to stack
|
|
||||||
* that need to be taken into account for jump offset that is used for
|
|
||||||
* bailing out from of the tail call when limit is reached
|
|
||||||
*/
|
|
||||||
pop_bytes = get_pop_bytes(callee_regs_used);
|
|
||||||
off1 += pop_bytes;
|
|
||||||
|
|
||||||
/*
|
|
||||||
* total bytes for:
|
|
||||||
* - nop5/ jmpq $off
|
|
||||||
* - pop callee regs
|
|
||||||
* - sub rsp, $val if depth > 0
|
|
||||||
* - pop rax
|
|
||||||
*/
|
|
||||||
poke_off = X86_PATCH_SIZE + pop_bytes + 1;
|
|
||||||
if (stack_depth) {
|
|
||||||
poke_off += 7;
|
|
||||||
off1 += 7;
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
|
* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
|
||||||
@@ -528,28 +482,30 @@ static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
|
|||||||
*/
|
*/
|
||||||
EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */
|
EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */
|
||||||
EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */
|
EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */
|
||||||
EMIT2(X86_JA, off1); /* ja out */
|
|
||||||
|
offset = ctx->tail_call_direct_label - (prog + 2 - start);
|
||||||
|
EMIT2(X86_JA, offset); /* ja out */
|
||||||
EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */
|
EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */
|
||||||
EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */
|
EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */
|
||||||
|
|
||||||
poke->tailcall_bypass = image + (addr - poke_off - X86_PATCH_SIZE);
|
poke->tailcall_bypass = ip + (prog - start);
|
||||||
poke->adj_off = X86_TAIL_CALL_OFFSET;
|
poke->adj_off = X86_TAIL_CALL_OFFSET;
|
||||||
poke->tailcall_target = image + (addr - X86_PATCH_SIZE);
|
poke->tailcall_target = ip + ctx->tail_call_direct_label - X86_PATCH_SIZE;
|
||||||
poke->bypass_addr = (u8 *)poke->tailcall_target + X86_PATCH_SIZE;
|
poke->bypass_addr = (u8 *)poke->tailcall_target + X86_PATCH_SIZE;
|
||||||
|
|
||||||
emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE,
|
emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE,
|
||||||
poke->tailcall_bypass);
|
poke->tailcall_bypass);
|
||||||
|
|
||||||
*pprog = prog;
|
pop_callee_regs(&prog, callee_regs_used);
|
||||||
pop_callee_regs(pprog, callee_regs_used);
|
|
||||||
prog = *pprog;
|
|
||||||
EMIT1(0x58); /* pop rax */
|
EMIT1(0x58); /* pop rax */
|
||||||
if (stack_depth)
|
if (stack_depth)
|
||||||
EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8));
|
EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8));
|
||||||
|
|
||||||
memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE);
|
memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE);
|
||||||
prog += X86_PATCH_SIZE;
|
prog += X86_PATCH_SIZE;
|
||||||
|
|
||||||
/* out: */
|
/* out: */
|
||||||
|
ctx->tail_call_direct_label = prog - start;
|
||||||
|
|
||||||
*pprog = prog;
|
*pprog = prog;
|
||||||
}
|
}
|
||||||
@@ -1274,13 +1230,16 @@ xadd: if (is_imm8(insn->off))
|
|||||||
case BPF_JMP | BPF_TAIL_CALL:
|
case BPF_JMP | BPF_TAIL_CALL:
|
||||||
if (imm32)
|
if (imm32)
|
||||||
emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1],
|
emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1],
|
||||||
&prog, addrs[i], image,
|
&prog, image + addrs[i - 1],
|
||||||
callee_regs_used,
|
callee_regs_used,
|
||||||
bpf_prog->aux->stack_depth);
|
bpf_prog->aux->stack_depth,
|
||||||
|
ctx);
|
||||||
else
|
else
|
||||||
emit_bpf_tail_call_indirect(&prog,
|
emit_bpf_tail_call_indirect(&prog,
|
||||||
callee_regs_used,
|
callee_regs_used,
|
||||||
bpf_prog->aux->stack_depth);
|
bpf_prog->aux->stack_depth,
|
||||||
|
image + addrs[i - 1],
|
||||||
|
ctx);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
/* cond jump */
|
/* cond jump */
|
||||||
|
Reference in New Issue
Block a user