linux_dsm_epyc7002/arch/x86/net/bpf_jit_comp.c
Alexei Starovoitov 60b58afc96 bpf: fix net.core.bpf_jit_enable race
global bpf_jit_enable variable is tested multiple times in JITs,
blinding and verifier core. The malicious root can try to toggle
it while loading the programs. This race condition was accounted
for and there should be no issues, but it's safer to avoid
this race condition.

Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Acked-by: Daniel Borkmann <daniel@iogearbox.net>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
2017-12-17 20:34:36 +01:00

1208 lines
31 KiB
C

/* bpf_jit_comp.c : BPF JIT compiler
*
* Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com)
* Internal BPF Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* as published by the Free Software Foundation; version 2
* of the License.
*/
#include <linux/netdevice.h>
#include <linux/filter.h>
#include <linux/if_vlan.h>
#include <asm/cacheflush.h>
#include <asm/set_memory.h>
#include <linux/bpf.h>
/* JIT control knob. Within this file it is only read by
 * bpf_int_jit_compile(), where a value greater than 1 triggers a dump of
 * the generated image through bpf_jit_dump().
 */
int bpf_jit_enable __read_mostly;
/*
 * Packet load helpers implemented in assembly (arch/x86/net/bpf_jit.S).
 * do_jit() emits direct calls to them for BPF_LD | BPF_ABS and
 * BPF_LD | BPF_IND instructions; CHOOSE_LOAD_FUNC() picks the plain,
 * _positive_offset or _negative_offset flavour from the immediate.
 */
extern u8 sk_load_word[], sk_load_half[], sk_load_byte[];
extern u8 sk_load_word_positive_offset[], sk_load_half_positive_offset[];
extern u8 sk_load_byte_positive_offset[];
extern u8 sk_load_word_negative_offset[], sk_load_half_negative_offset[];
extern u8 sk_load_byte_negative_offset[];
/* Store the low-order bytes of 'bytes' at 'ptr' and return the advanced
 * write cursor (ptr + len).
 *
 * A length of 1 or 2 stores exactly that many bytes.  Any other length
 * (3 and 4 in practice) is written with a single 32-bit store followed by
 * a compiler barrier, so a 3-byte emit also writes the byte after it; the
 * returned cursor still only moves forward by 'len'.
 */
static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
{
	switch (len) {
	case 1:
		*ptr = bytes;
		break;
	case 2:
		*(u16 *)ptr = bytes;
		break;
	default:
		*(u32 *)ptr = bytes;
		barrier();
		break;
	}

	return ptr + len;
}
/* Code emission helpers built on emit_code().
 *
 * Every macro here expects two locals in the calling function:
 *   prog - u8 * write cursor, advanced past the emitted bytes
 *   cnt  - running count of bytes emitted, incremented by the length
 *
 * EMIT1..EMIT4 pack their byte arguments little-endian into one value, so
 * b1 is the first byte in memory.  The EMITn_off32 variants emit the n
 * opcode bytes and then a separate 4-byte value 'off'.
 */
#define EMIT(bytes, len) \
do { prog = emit_code(prog, bytes, len); cnt += len; } while (0)
#define EMIT1(b1) EMIT(b1, 1)
#define EMIT2(b1, b2) EMIT((b1) + ((b2) << 8), 2)
#define EMIT3(b1, b2, b3) EMIT((b1) + ((b2) << 8) + ((b3) << 16), 3)
#define EMIT4(b1, b2, b3, b4) EMIT((b1) + ((b2) << 8) + ((b3) << 16) + ((b4) << 24), 4)
#define EMIT1_off32(b1, off) \
do {EMIT1(b1); EMIT(off, 4); } while (0)
#define EMIT2_off32(b1, b2, off) \
do {EMIT2(b1, b2); EMIT(off, 4); } while (0)
#define EMIT3_off32(b1, b2, b3, off) \
do {EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
#define EMIT4_off32(b1, b2, b3, b4, off) \
do {EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)
/* True when 'value' fits in a signed 8-bit immediate (-128..127). */
static bool is_imm8(int value)
{
	if (value < -128)
		return false;

	return value < 128;
}
/* True when 'value' survives a round trip through a signed 32-bit
 * integer, i.e. it can be encoded as a sign-extended imm32/rel32.
 */
static bool is_simm32(s64 value)
{
	const s32 truncated = (s32)value;

	return (s64)truncated == value;
}
/* mov dst, src
 *
 * Emits a 3-byte register-to-register move (opcode 0x89 behind a
 * 0x48-based prefix from add_2mod()).  Nothing is emitted when DST and
 * SRC name the same BPF register.  Relies on the 'prog' and 'cnt' locals
 * used by EMIT3().
 */
#define EMIT_mov(DST, SRC) \
do {if (DST != SRC) \
EMIT3(add_2mod(0x48, DST, SRC), 0x89, add_2reg(0xC0, DST, SRC)); \
} while (0)
/* Number of immediate bytes to emit for a BPF store of size 'bpf_size'.
 *
 * BPF_DW maps to 4 because the store immediate is an imm32 even for a
 * 64-bit destination.  Unknown sizes yield 0.
 */
static int bpf_size_to_x86_bytes(int bpf_size)
{
	switch (bpf_size) {
	case BPF_B:
		return 1;
	case BPF_H:
		return 2;
	case BPF_W:
		return 4;
	case BPF_DW:
		return 4; /* imm32 */
	default:
		return 0;
	}
}
/* x86 conditional jump opcodes in their short form (target is . + s8).
 * For the near form (target is . + s32) add 0x10 to the opcode and
 * prefix it with 0x0f, as do_jit() does for out-of-range branches.
 */
#define X86_JB 0x72
#define X86_JAE 0x73
#define X86_JE 0x74
#define X86_JNE 0x75
#define X86_JBE 0x76
#define X86_JA 0x77
#define X86_JL 0x7C
#define X86_JGE 0x7D
#define X86_JLE 0x7E
#define X86_JG 0x7F
/* Flush the instruction cache for [start, end).
 *
 * The address limit is switched to KERNEL_DS around the flush and put
 * back afterwards; a write barrier is issued before flushing.
 */
static void bpf_flush_icache(void *start, void *end)
{
	const unsigned long first = (unsigned long)start;
	const unsigned long last = (unsigned long)end;
	mm_segment_t saved_fs;

	saved_fs = get_fs();
	set_fs(KERNEL_DS);

	smp_wmb();
	flush_icache_range(first, last);

	set_fs(saved_fs);
}
/* Select the sk_load_* helper variant for a constant offset K:
 *   K >= 0                   -> func##_positive_offset
 *   SKF_LL_OFF <= K < 0      -> func##_negative_offset
 *   K < SKF_LL_OFF           -> func (generic entry point)
 */
#define CHOOSE_LOAD_FUNC(K, func) \
((int)K < 0 ? ((int)K >= SKF_LL_OFF ? func##_negative_offset : func) : func##_positive_offset)
/* Pick a register index outside of the BPF range for JIT internal work.
 * reg2hex[] maps it to x64 r11.
 */
#define AUX_REG (MAX_BPF_JIT_REG + 1)
/* The following table maps BPF registers to x64 registers.
 *
 * Each entry holds only the 3-bit register field used in opcode/ModRM
 * bytes; registers r8..r15 share those values with rax..rdi and are told
 * apart by the extra prefix bit that is_ereg()/add_*mod() supply.
 *
 * x64 register r12 is unused, since if used as base address
 * register in load/store instructions, it always needs an
 * extra byte of encoding and is callee saved.
 *
 * r9 caches skb->len - skb->data_len
 * r10 caches skb->data, and used for blinding (if enabled)
 */
static const int reg2hex[] = {
[BPF_REG_0] = 0, /* rax */
[BPF_REG_1] = 7, /* rdi */
[BPF_REG_2] = 6, /* rsi */
[BPF_REG_3] = 2, /* rdx */
[BPF_REG_4] = 1, /* rcx */
[BPF_REG_5] = 0, /* r8 */
[BPF_REG_6] = 3, /* rbx callee saved */
[BPF_REG_7] = 5, /* r13 callee saved */
[BPF_REG_8] = 6, /* r14 callee saved */
[BPF_REG_9] = 7, /* r15 callee saved */
[BPF_REG_FP] = 5, /* rbp readonly */
[BPF_REG_AX] = 2, /* r10 temp register */
[AUX_REG] = 3, /* r11 temp register */
};
/* Report whether BPF register 'reg' is mapped onto one of x64 r8..r15.
 * Those registers need an extra prefix bit in their encoding, whereas
 * rax, rcx, ..., rbp use the shorter form.
 */
static bool is_ereg(u32 reg)
{
	const unsigned long extended_regs = BIT(BPF_REG_5) |
					    BIT(AUX_REG) |
					    BIT(BPF_REG_7) |
					    BIT(BPF_REG_8) |
					    BIT(BPF_REG_9) |
					    BIT(BPF_REG_AX);

	return (1 << reg) & extended_regs;
}
/* Return 'byte' with bit 0 set when 'reg' maps to one of x64 r8..r15,
 * otherwise return it unchanged.
 */
static u8 add_1mod(u8 byte, u32 reg)
{
	return is_ereg(reg) ? (byte | 1) : byte;
}
/* Two-register variant of add_1mod(): bit 0 is set when 'r1' is an
 * extended register (r8..r15) and bit 2 when 'r2' is.
 */
static u8 add_2mod(u8 byte, u32 r1, u32 r2)
{
	u8 ext_bits = 0;

	if (is_ereg(r1))
		ext_bits |= 1;
	if (is_ereg(r2))
		ext_bits |= 4;

	return byte | ext_bits;
}
/* Fold the reg2hex[] code of 'dst_reg' into the x64 opcode 'byte'. */
static u8 add_1reg(u8 byte, u32 dst_reg)
{
	const int dst_code = reg2hex[dst_reg];

	return byte + dst_code;
}
/* Fold both register codes into the x64 opcode 'byte': 'dst_reg' goes
 * into the low three bits and 'src_reg' into the next three.
 */
static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg)
{
	const int dst_code = reg2hex[dst_reg];
	const int src_code = reg2hex[src_reg] << 3;

	return byte + dst_code + src_code;
}
/* Pad 'size' bytes starting at 'area' with int3 (0xcc) instructions. */
static void jit_fill_hole(void *area, unsigned int size)
{
	const int int3_opcode = 0xcc;

	memset(area, int3_opcode, size);
}
/* State carried across do_jit() passes for one program. */
struct jit_context {
/* Image offset of the epilogue: recorded by do_jit() at the first
 * BPF_EXIT and used as the jump target for later exits and for the
 * divide-by-zero bail-out.
 */
int cleanup_addr; /* epilogue code offset */
/* Set once a BPF_LD | BPF_ABS/BPF_IND instruction has been translated. */
bool seen_ld_abs;
/* Set once an instruction uses BPF_REG_AX as source or destination. */
bool seen_ax_reg;
};
/* maximum number of bytes emitted while JITing one eBPF insn */
#define BPF_MAX_INSN_SIZE 128
/* extra slack appended to do_jit()'s per-instruction scratch buffer */
#define BPF_INSN_SAFETY 64
/* Bytes reserved at the bottom of the frame for JIT bookkeeping.
 * emit_prologue() additionally stores the zeroed tail call counter at
 * offset 32 of this area.
 */
#define AUX_STACK_SPACE \
(32 /* space for rbx, r13, r14, r15 */ + \
8 /* space for skb_copy_bits() buffer */)
/* Exact byte length of the code produced by emit_prologue(); checked
 * there with BUILD_BUG_ON() and used by emit_bpf_tail_call() to jump
 * past the target program's prologue.
 */
#define PROLOGUE_SIZE 37
/* Emit the x64 prologue of a BPF program at *pprog and advance *pprog.
 *
 * 'stack_depth' is the program's BPF stack usage; it is rounded up to a
 * multiple of 8 and AUX_STACK_SPACE is added on top.  The emitted code is
 * exactly PROLOGUE_SIZE bytes (checked at build time) so that
 * bpf_tail_call can skip it when jumping into another program.
 */
static void emit_prologue(u8 **pprog, u32 stack_depth)
{
	const u32 frame_size = round_up(stack_depth, 8) + AUX_STACK_SPACE;
	u8 *prog = *pprog;
	int cnt = 0;

	/* Frame setup */
	EMIT1(0x55);				/* push rbp */
	EMIT3(0x48, 0x89, 0xE5);		/* mov rbp, rsp */
	EMIT3_off32(0x48, 0x81, 0xEC, frame_size); /* sub rsp, frame_size */
	EMIT4(0x48, 0x83, 0xED, AUX_STACK_SPACE); /* sub rbp, AUX_STACK_SPACE */

	/* Spill callee saved registers into the auxiliary area.
	 *
	 * Every classic BPF filter uses R6 (rbx).  bpf_convert_filter() maps
	 * classic register X to R7 and uses R8 as a temporary, so tcpdump
	 * style filters need R7 (r13) and R8 (r14) preserved as well.  The
	 * R9 (r15) spill could be conditional, but bpf_jit.S has a single
	 * 'bpf_error' return path and the cost of one more store only shows
	 * in synthetic filters, so it is done unconditionally.
	 */
	EMIT4(0x48, 0x89, 0x5D, 0);	/* mov qword ptr [rbp+0], rbx */
	EMIT4(0x4C, 0x89, 0x6D, 8);	/* mov qword ptr [rbp+8], r13 */
	EMIT4(0x4C, 0x89, 0x75, 16);	/* mov qword ptr [rbp+16], r14 */
	EMIT4(0x4C, 0x89, 0x7D, 24);	/* mov qword ptr [rbp+24], r15 */

	/* Reset tail_call_cnt for eBPF tail calls: zero rax (a 32-bit xor
	 * zero-extends) and store it in the counter slot.
	 */
	EMIT2(0x31, 0xc0);		/* xor eax, eax */
	EMIT4(0x48, 0x89, 0x45, 32);	/* mov qword ptr [rbp+32], rax */

	BUILD_BUG_ON(cnt != PROLOGUE_SIZE);

	*pprog = prog;
}
/* Emit the inline expansion of
 *   bpf_tail_call(void *ctx, struct bpf_array *array, u64 index)
 * at *pprog and advance *pprog.  The generated code behaves like:
 *
 *   if (index >= array->map.max_entries)
 *     goto out;
 *   if (tail_call_cnt > MAX_TAIL_CALL_CNT)
 *     goto out;
 *   tail_call_cnt++;
 *   prog = array->ptrs[index];
 *   if (prog == NULL)
 *     goto out;
 *   goto *(prog->bpf_func + prologue_size);
 * out:
 *
 * The three forward branches use hard-coded 8-bit displacements
 * (OFFSET1..OFFSET3) that are verified against the emitted byte counts
 * at build time.
 */
static void emit_bpf_tail_call(u8 **pprog)
{
	int after_bounds_jmp, after_count_jmp, after_null_jmp;
	u8 *prog = *pprog;
	int cnt = 0;

	/* On entry:
	 *   rdi - pointer to ctx
	 *   rsi - pointer to bpf_array
	 *   rdx - index in bpf_array
	 */

	/* Bounds check: index >= array->map.max_entries -> out */
	EMIT2(0x89, 0xD2);			/* mov edx, edx */
	EMIT3(0x39, 0x56,			/* cmp dword ptr [rsi + off8], edx */
	      offsetof(struct bpf_array, map.max_entries));
#define OFFSET1 43 /* number of bytes to jump */
	EMIT2(X86_JBE, OFFSET1);		/* jbe out */
	after_bounds_jmp = cnt;

	/* Tail call limit: tail_call_cnt > MAX_TAIL_CALL_CNT -> out */
	EMIT2_off32(0x8B, 0x85, 36);		/* mov eax, dword ptr [rbp + 36] */
	EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);	/* cmp eax, MAX_TAIL_CALL_CNT */
#define OFFSET2 32
	EMIT2(X86_JA, OFFSET2);			/* ja out */
	after_count_jmp = cnt;
	EMIT3(0x83, 0xC0, 0x01);		/* add eax, 1 */
	EMIT2_off32(0x89, 0x85, 36);		/* mov dword ptr [rbp + 36], eax */

	/* prog = array->ptrs[index] */
	EMIT4_off32(0x48, 0x8B, 0x84, 0xD6,	/* mov rax, [rsi + rdx * 8 + off32] */
		    offsetof(struct bpf_array, ptrs));

	/* Empty slot: prog == NULL -> out */
	EMIT3(0x48, 0x85, 0xC0);		/* test rax, rax */
#define OFFSET3 10
	EMIT2(X86_JE, OFFSET3);			/* je out */
	after_null_jmp = cnt;

	/* goto *(prog->bpf_func + prologue_size) */
	EMIT4(0x48, 0x8B, 0x40,			/* mov rax, qword ptr [rax + off8] */
	      offsetof(struct bpf_prog, bpf_func));
	EMIT4(0x48, 0x83, 0xC0, PROLOGUE_SIZE);	/* add rax, prologue_size */

	/* Ready to enter the next BPF program:
	 *   rdi == ctx (1st arg)
	 *   rax == prog->bpf_func + prologue_size
	 */
	EMIT2(0xFF, 0xE0);			/* jmp rax */

	/* out: */
	BUILD_BUG_ON(cnt - after_bounds_jmp != OFFSET1);
	BUILD_BUG_ON(cnt - after_count_jmp != OFFSET2);
	BUILD_BUG_ON(cnt - after_null_jmp != OFFSET3);

	*pprog = prog;
}
/* Emit code at *pprog that caches skb fields in scratch registers, with
 * the skb pointer taken from rdi, and advance *pprog:
 *
 *   r9d = skb->len - skb->data_len (headlen)
 *   r10 = skb->data
 */
static void emit_load_skb_data_hlen(u8 **pprog)
{
	u8 *prog = *pprog;
	int cnt = 0;

	/* mov r9d, dword ptr [rdi + offsetof(len)] */
	EMIT3_off32(0x44, 0x8b, 0x8f,
		    offsetof(struct sk_buff, len));

	/* sub r9d, dword ptr [rdi + offsetof(data_len)] */
	EMIT3_off32(0x44, 0x2b, 0x8f,
		    offsetof(struct sk_buff, data_len));

	/* mov r10, qword ptr [rdi + offsetof(data)] */
	EMIT3_off32(0x4c, 0x8b, 0x97,
		    offsetof(struct sk_buff, data));

	*pprog = prog;
}
/* Run one JIT pass over 'bpf_prog', translating each eBPF instruction
 * into x86-64 machine code.
 *
 * @bpf_prog:   program to translate
 * @addrs:      per-instruction table; on entry it holds the end offsets
 *              computed by the previous pass (or the caller's estimate),
 *              which are used to compute jump displacements.  addrs[i] is
 *              rewritten with the offset just past instruction i.
 * @image:      destination buffer, or NULL for a sizing-only pass
 * @oldproglen: program length from the previous pass; 0 on the first
 *              pass, which makes the pass conservatively assume that
 *              LD_ABS/LD_IND and BPF_REG_AX are used
 * @ctx:        state shared between passes (epilogue offset, seen flags)
 *
 * Each instruction is first assembled into a small on-stack buffer and
 * then copied to @image (if any).  The prologue is assembled together
 * with the first instruction.
 *
 * Returns the total length of the generated code, or a negative errno:
 * -EINVAL for an unknown opcode or unreachable call target, -EFAULT for
 * internal consistency failures (oversized instruction, image overflow,
 * jump displacement out of range).
 */
static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
		  int oldproglen, struct jit_context *ctx)
{
	struct bpf_insn *insn = bpf_prog->insnsi;
	int insn_cnt = bpf_prog->len;
	bool seen_ld_abs = ctx->seen_ld_abs | (oldproglen == 0);
	bool seen_ax_reg = ctx->seen_ax_reg | (oldproglen == 0);
	bool seen_exit = false;
	u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
	int i, cnt = 0;
	int proglen = 0;
	u8 *prog = temp;

	emit_prologue(&prog, bpf_prog->aux->stack_depth);

	if (seen_ld_abs)
		emit_load_skb_data_hlen(&prog);

	for (i = 0; i < insn_cnt; i++, insn++) {
		const s32 imm32 = insn->imm;
		u32 dst_reg = insn->dst_reg;
		u32 src_reg = insn->src_reg;
		u8 b1 = 0, b2 = 0, b3 = 0;
		s64 jmp_offset;
		u8 jmp_cond;
		/* only meaningful inside the BPF_CALL case when seen_ld_abs */
		bool reload_skb_data = false;
		int ilen;
		u8 *func;

		if (dst_reg == BPF_REG_AX || src_reg == BPF_REG_AX)
			ctx->seen_ax_reg = seen_ax_reg = true;

		switch (insn->code) {
			/* ALU */
		case BPF_ALU | BPF_ADD | BPF_X:
		case BPF_ALU | BPF_SUB | BPF_X:
		case BPF_ALU | BPF_AND | BPF_X:
		case BPF_ALU | BPF_OR | BPF_X:
		case BPF_ALU | BPF_XOR | BPF_X:
		case BPF_ALU64 | BPF_ADD | BPF_X:
		case BPF_ALU64 | BPF_SUB | BPF_X:
		case BPF_ALU64 | BPF_AND | BPF_X:
		case BPF_ALU64 | BPF_OR | BPF_X:
		case BPF_ALU64 | BPF_XOR | BPF_X:
			switch (BPF_OP(insn->code)) {
			case BPF_ADD: b2 = 0x01; break;
			case BPF_SUB: b2 = 0x29; break;
			case BPF_AND: b2 = 0x21; break;
			case BPF_OR: b2 = 0x09; break;
			case BPF_XOR: b2 = 0x31; break;
			}
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_2mod(0x48, dst_reg, src_reg));
			else if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT1(add_2mod(0x40, dst_reg, src_reg));
			EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
			break;

			/* mov dst, src */
		case BPF_ALU64 | BPF_MOV | BPF_X:
			EMIT_mov(dst_reg, src_reg);
			break;

			/* mov32 dst, src */
		case BPF_ALU | BPF_MOV | BPF_X:
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT1(add_2mod(0x40, dst_reg, src_reg));
			EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
			break;

			/* neg dst */
		case BPF_ALU | BPF_NEG:
		case BPF_ALU64 | BPF_NEG:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));
			EMIT2(0xF7, add_1reg(0xD8, dst_reg));
			break;

		case BPF_ALU | BPF_ADD | BPF_K:
		case BPF_ALU | BPF_SUB | BPF_K:
		case BPF_ALU | BPF_AND | BPF_K:
		case BPF_ALU | BPF_OR | BPF_K:
		case BPF_ALU | BPF_XOR | BPF_K:
		case BPF_ALU64 | BPF_ADD | BPF_K:
		case BPF_ALU64 | BPF_SUB | BPF_K:
		case BPF_ALU64 | BPF_AND | BPF_K:
		case BPF_ALU64 | BPF_OR | BPF_K:
		case BPF_ALU64 | BPF_XOR | BPF_K:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			switch (BPF_OP(insn->code)) {
			case BPF_ADD: b3 = 0xC0; break;
			case BPF_SUB: b3 = 0xE8; break;
			case BPF_AND: b3 = 0xE0; break;
			case BPF_OR: b3 = 0xC8; break;
			case BPF_XOR: b3 = 0xF0; break;
			}

			if (is_imm8(imm32))
				EMIT3(0x83, add_1reg(b3, dst_reg), imm32);
			else
				EMIT2_off32(0x81, add_1reg(b3, dst_reg), imm32);
			break;

		case BPF_ALU64 | BPF_MOV | BPF_K:
			/* optimization: if imm32 is positive,
			 * use 'mov eax, imm32' (which zero-extends imm32)
			 * to save 2 bytes
			 */
			if (imm32 < 0) {
				/* 'mov rax, imm32' sign extends imm32 */
				b1 = add_1mod(0x48, dst_reg);
				b2 = 0xC7;
				b3 = 0xC0;
				EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
				break;
			}
			/* fall through: non-negative imm32 uses the 32-bit mov */

		case BPF_ALU | BPF_MOV | BPF_K:
			/* optimization: if imm32 is zero, use 'xor <dst>,<dst>'
			 * to save 3 bytes.
			 */
			if (imm32 == 0) {
				if (is_ereg(dst_reg))
					EMIT1(add_2mod(0x40, dst_reg, dst_reg));
				b2 = 0x31; /* xor */
				b3 = 0xC0;
				EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
				break;
			}

			/* mov %eax, imm32 */
			if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));
			EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
			break;

		case BPF_LD | BPF_IMM | BPF_DW:
			/* optimization: if imm64 is zero, use 'xor <dst>,<dst>'
			 * to save 7 bytes.
			 */
			if (insn[0].imm == 0 && insn[1].imm == 0) {
				b1 = add_2mod(0x48, dst_reg, dst_reg);
				b2 = 0x31; /* xor */
				b3 = 0xC0;
				EMIT3(b1, b2, add_2reg(b3, dst_reg, dst_reg));

				/* this insn occupies two slots */
				insn++;
				i++;
				break;
			}

			/* movabsq %rax, imm64 */
			EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
			EMIT(insn[0].imm, 4);
			EMIT(insn[1].imm, 4);

			/* this insn occupies two slots */
			insn++;
			i++;
			break;

			/* dst %= src, dst /= src, dst %= imm32, dst /= imm32 */
		case BPF_ALU | BPF_MOD | BPF_X:
		case BPF_ALU | BPF_DIV | BPF_X:
		case BPF_ALU | BPF_MOD | BPF_K:
		case BPF_ALU | BPF_DIV | BPF_K:
		case BPF_ALU64 | BPF_MOD | BPF_X:
		case BPF_ALU64 | BPF_DIV | BPF_X:
		case BPF_ALU64 | BPF_MOD | BPF_K:
		case BPF_ALU64 | BPF_DIV | BPF_K:
			EMIT1(0x50); /* push rax */
			EMIT1(0x52); /* push rdx */

			if (BPF_SRC(insn->code) == BPF_X)
				/* mov r11, src_reg */
				EMIT_mov(AUX_REG, src_reg);
			else
				/* mov r11, imm32 */
				EMIT3_off32(0x49, 0xC7, 0xC3, imm32);

			/* mov rax, dst_reg */
			EMIT_mov(BPF_REG_0, dst_reg);

			/* xor edx, edx
			 * equivalent to 'xor rdx, rdx', but one byte less
			 */
			EMIT2(0x31, 0xd2);

			if (BPF_SRC(insn->code) == BPF_X) {
				/* if (src_reg == 0) return 0
				 *
				 * The zero test must have the same width as
				 * the divide below: a 32-bit div only looks
				 * at r11d, so a source with zero low bits and
				 * non-zero high bits has to be caught here.
				 * Both encodings are 4 bytes long, keeping
				 * the fixed offsets below valid.
				 */
				if (BPF_CLASS(insn->code) == BPF_ALU64)
					/* cmp r11, 0 */
					EMIT4(0x49, 0x83, 0xFB, 0x00);
				else
					/* cmp r11d, 0 */
					EMIT4(0x41, 0x83, 0xFB, 0x00);

				/* jne .+9 (skip over pop, pop, xor and jmp) */
				EMIT2(X86_JNE, 1 + 1 + 2 + 5);
				EMIT1(0x5A); /* pop rdx */
				EMIT1(0x58); /* pop rax */
				EMIT2(0x31, 0xc0); /* xor eax, eax */

				/* jmp cleanup_addr
				 * addrs[i] - 11, because there are 11 bytes
				 * after this insn: div, mov, pop, pop, mov
				 */
				jmp_offset = ctx->cleanup_addr - (addrs[i] - 11);
				EMIT1_off32(0xE9, jmp_offset);
			}

			if (BPF_CLASS(insn->code) == BPF_ALU64)
				/* div r11 */
				EMIT3(0x49, 0xF7, 0xF3);
			else
				/* div r11d */
				EMIT3(0x41, 0xF7, 0xF3);

			if (BPF_OP(insn->code) == BPF_MOD)
				/* mov r11, rdx */
				EMIT3(0x49, 0x89, 0xD3);
			else
				/* mov r11, rax */
				EMIT3(0x49, 0x89, 0xC3);

			EMIT1(0x5A); /* pop rdx */
			EMIT1(0x58); /* pop rax */

			/* mov dst_reg, r11 */
			EMIT_mov(dst_reg, AUX_REG);
			break;

		case BPF_ALU | BPF_MUL | BPF_K:
		case BPF_ALU | BPF_MUL | BPF_X:
		case BPF_ALU64 | BPF_MUL | BPF_K:
		case BPF_ALU64 | BPF_MUL | BPF_X:
			EMIT1(0x50); /* push rax */
			EMIT1(0x52); /* push rdx */

			/* mov r11, dst_reg */
			EMIT_mov(AUX_REG, dst_reg);

			if (BPF_SRC(insn->code) == BPF_X)
				/* mov rax, src_reg */
				EMIT_mov(BPF_REG_0, src_reg);
			else
				/* mov rax, imm32 */
				EMIT3_off32(0x48, 0xC7, 0xC0, imm32);

			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, AUX_REG));
			else if (is_ereg(AUX_REG))
				EMIT1(add_1mod(0x40, AUX_REG));
			/* mul(q) r11 */
			EMIT2(0xF7, add_1reg(0xE0, AUX_REG));

			/* mov r11, rax */
			EMIT_mov(AUX_REG, BPF_REG_0);

			EMIT1(0x5A); /* pop rdx */
			EMIT1(0x58); /* pop rax */

			/* mov dst_reg, r11 */
			EMIT_mov(dst_reg, AUX_REG);
			break;

			/* shifts */
		case BPF_ALU | BPF_LSH | BPF_K:
		case BPF_ALU | BPF_RSH | BPF_K:
		case BPF_ALU | BPF_ARSH | BPF_K:
		case BPF_ALU64 | BPF_LSH | BPF_K:
		case BPF_ALU64 | BPF_RSH | BPF_K:
		case BPF_ALU64 | BPF_ARSH | BPF_K:
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			switch (BPF_OP(insn->code)) {
			case BPF_LSH: b3 = 0xE0; break;
			case BPF_RSH: b3 = 0xE8; break;
			case BPF_ARSH: b3 = 0xF8; break;
			}
			EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
			break;

		case BPF_ALU | BPF_LSH | BPF_X:
		case BPF_ALU | BPF_RSH | BPF_X:
		case BPF_ALU | BPF_ARSH | BPF_X:
		case BPF_ALU64 | BPF_LSH | BPF_X:
		case BPF_ALU64 | BPF_RSH | BPF_X:
		case BPF_ALU64 | BPF_ARSH | BPF_X:

			/* check for bad case when dst_reg == rcx */
			if (dst_reg == BPF_REG_4) {
				/* mov r11, dst_reg */
				EMIT_mov(AUX_REG, dst_reg);
				dst_reg = AUX_REG;
			}

			if (src_reg != BPF_REG_4) { /* common case */
				EMIT1(0x51); /* push rcx */

				/* mov rcx, src_reg */
				EMIT_mov(BPF_REG_4, src_reg);
			}

			/* shl %rax, %cl | shr %rax, %cl | sar %rax, %cl */
			if (BPF_CLASS(insn->code) == BPF_ALU64)
				EMIT1(add_1mod(0x48, dst_reg));
			else if (is_ereg(dst_reg))
				EMIT1(add_1mod(0x40, dst_reg));

			switch (BPF_OP(insn->code)) {
			case BPF_LSH: b3 = 0xE0; break;
			case BPF_RSH: b3 = 0xE8; break;
			case BPF_ARSH: b3 = 0xF8; break;
			}
			EMIT2(0xD3, add_1reg(b3, dst_reg));

			if (src_reg != BPF_REG_4)
				EMIT1(0x59); /* pop rcx */

			if (insn->dst_reg == BPF_REG_4)
				/* mov dst_reg, r11 */
				EMIT_mov(insn->dst_reg, AUX_REG);
			break;

		case BPF_ALU | BPF_END | BPF_FROM_BE:
			switch (imm32) {
			case 16:
				/* emit 'ror %ax, 8' to swap lower 2 bytes */
				EMIT1(0x66);
				if (is_ereg(dst_reg))
					EMIT1(0x41);
				EMIT3(0xC1, add_1reg(0xC8, dst_reg), 8);

				/* emit 'movzwl eax, ax' */
				if (is_ereg(dst_reg))
					EMIT3(0x45, 0x0F, 0xB7);
				else
					EMIT2(0x0F, 0xB7);
				EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 32:
				/* emit 'bswap eax' to swap lower 4 bytes */
				if (is_ereg(dst_reg))
					EMIT2(0x41, 0x0F);
				else
					EMIT1(0x0F);
				EMIT1(add_1reg(0xC8, dst_reg));
				break;
			case 64:
				/* emit 'bswap rax' to swap 8 bytes */
				EMIT3(add_1mod(0x48, dst_reg), 0x0F,
				      add_1reg(0xC8, dst_reg));
				break;
			}
			break;

		case BPF_ALU | BPF_END | BPF_FROM_LE:
			switch (imm32) {
			case 16:
				/* emit 'movzwl eax, ax' to zero extend 16-bit
				 * into 64 bit
				 */
				if (is_ereg(dst_reg))
					EMIT3(0x45, 0x0F, 0xB7);
				else
					EMIT2(0x0F, 0xB7);
				EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 32:
				/* emit 'mov eax, eax' to clear upper 32-bits */
				if (is_ereg(dst_reg))
					EMIT1(0x45);
				EMIT2(0x89, add_2reg(0xC0, dst_reg, dst_reg));
				break;
			case 64:
				/* nop */
				break;
			}
			break;

			/* ST: *(u8*)(dst_reg + off) = imm */
		case BPF_ST | BPF_MEM | BPF_B:
			if (is_ereg(dst_reg))
				EMIT2(0x41, 0xC6);
			else
				EMIT1(0xC6);
			goto st;
		case BPF_ST | BPF_MEM | BPF_H:
			if (is_ereg(dst_reg))
				EMIT3(0x66, 0x41, 0xC7);
			else
				EMIT2(0x66, 0xC7);
			goto st;
		case BPF_ST | BPF_MEM | BPF_W:
			if (is_ereg(dst_reg))
				EMIT2(0x41, 0xC7);
			else
				EMIT1(0xC7);
			goto st;
		case BPF_ST | BPF_MEM | BPF_DW:
			EMIT2(add_1mod(0x48, dst_reg), 0xC7);

st:			if (is_imm8(insn->off))
				EMIT2(add_1reg(0x40, dst_reg), insn->off);
			else
				EMIT1_off32(add_1reg(0x80, dst_reg), insn->off);

			EMIT(imm32, bpf_size_to_x86_bytes(BPF_SIZE(insn->code)));
			break;

			/* STX: *(u8*)(dst_reg + off) = src_reg */
		case BPF_STX | BPF_MEM | BPF_B:
			/* emit 'mov byte ptr [rax + off], al' */
			if (is_ereg(dst_reg) || is_ereg(src_reg) ||
			    /* have to add extra byte for x86 SIL, DIL regs */
			    src_reg == BPF_REG_1 || src_reg == BPF_REG_2)
				EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
			else
				EMIT1(0x88);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_H:
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT3(0x66, add_2mod(0x40, dst_reg, src_reg), 0x89);
			else
				EMIT2(0x66, 0x89);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_W:
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x89);
			else
				EMIT1(0x89);
			goto stx;
		case BPF_STX | BPF_MEM | BPF_DW:
			EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89);
stx:			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
					    insn->off);
			break;

			/* LDX: dst_reg = *(u8*)(src_reg + off) */
		case BPF_LDX | BPF_MEM | BPF_B:
			/* emit 'movzx rax, byte ptr [rax + off]' */
			EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_H:
			/* emit 'movzx rax, word ptr [rax + off]' */
			EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_W:
			/* emit 'mov eax, dword ptr [rax+0x14]' */
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
			else
				EMIT1(0x8B);
			goto ldx;
		case BPF_LDX | BPF_MEM | BPF_DW:
			/* emit 'mov rax, qword ptr [rax+0x14]' */
			EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
ldx:			/* if insn->off == 0 we can save one extra byte, but
			 * special case of x86 r13 which always needs an offset
			 * is not worth the hassle
			 */
			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, src_reg, dst_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, src_reg, dst_reg),
					    insn->off);
			break;

			/* STX XADD: lock *(u32*)(dst_reg + off) += src_reg */
		case BPF_STX | BPF_XADD | BPF_W:
			/* emit 'lock add dword ptr [rax + off], eax' */
			if (is_ereg(dst_reg) || is_ereg(src_reg))
				EMIT3(0xF0, add_2mod(0x40, dst_reg, src_reg), 0x01);
			else
				EMIT2(0xF0, 0x01);
			goto xadd;
		case BPF_STX | BPF_XADD | BPF_DW:
			EMIT3(0xF0, add_2mod(0x48, dst_reg, src_reg), 0x01);
xadd:			if (is_imm8(insn->off))
				EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
			else
				EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
					    insn->off);
			break;

			/* call */
		case BPF_JMP | BPF_CALL:
			func = (u8 *) __bpf_call_base + imm32;
			jmp_offset = func - (image + addrs[i]);
			if (seen_ld_abs) {
				reload_skb_data = bpf_helper_changes_pkt_data(func);
				if (reload_skb_data) {
					EMIT1(0x57); /* push %rdi */
					jmp_offset += 22; /* pop, mov, sub, mov */
				} else {
					EMIT2(0x41, 0x52); /* push %r10 */
					EMIT2(0x41, 0x51); /* push %r9 */
					/* need to adjust jmp offset, since
					 * pop %r9, pop %r10 take 4 bytes after call insn
					 */
					jmp_offset += 4;
				}
			}
			if (!imm32 || !is_simm32(jmp_offset)) {
				pr_err("unsupported bpf func %d addr %p image %p\n",
				       imm32, func, image);
				return -EINVAL;
			}
			EMIT1_off32(0xE8, jmp_offset);
			if (seen_ld_abs) {
				if (reload_skb_data) {
					EMIT1(0x5F); /* pop %rdi */
					emit_load_skb_data_hlen(&prog);
				} else {
					EMIT2(0x41, 0x59); /* pop %r9 */
					EMIT2(0x41, 0x5A); /* pop %r10 */
				}
			}
			break;

		case BPF_JMP | BPF_TAIL_CALL:
			emit_bpf_tail_call(&prog);
			break;

			/* cond jump */
		case BPF_JMP | BPF_JEQ | BPF_X:
		case BPF_JMP | BPF_JNE | BPF_X:
		case BPF_JMP | BPF_JGT | BPF_X:
		case BPF_JMP | BPF_JLT | BPF_X:
		case BPF_JMP | BPF_JGE | BPF_X:
		case BPF_JMP | BPF_JLE | BPF_X:
		case BPF_JMP | BPF_JSGT | BPF_X:
		case BPF_JMP | BPF_JSLT | BPF_X:
		case BPF_JMP | BPF_JSGE | BPF_X:
		case BPF_JMP | BPF_JSLE | BPF_X:
			/* cmp dst_reg, src_reg */
			EMIT3(add_2mod(0x48, dst_reg, src_reg), 0x39,
			      add_2reg(0xC0, dst_reg, src_reg));
			goto emit_cond_jmp;

		case BPF_JMP | BPF_JSET | BPF_X:
			/* test dst_reg, src_reg */
			EMIT3(add_2mod(0x48, dst_reg, src_reg), 0x85,
			      add_2reg(0xC0, dst_reg, src_reg));
			goto emit_cond_jmp;

		case BPF_JMP | BPF_JSET | BPF_K:
			/* test dst_reg, imm32 */
			EMIT1(add_1mod(0x48, dst_reg));
			EMIT2_off32(0xF7, add_1reg(0xC0, dst_reg), imm32);
			goto emit_cond_jmp;

		case BPF_JMP | BPF_JEQ | BPF_K:
		case BPF_JMP | BPF_JNE | BPF_K:
		case BPF_JMP | BPF_JGT | BPF_K:
		case BPF_JMP | BPF_JLT | BPF_K:
		case BPF_JMP | BPF_JGE | BPF_K:
		case BPF_JMP | BPF_JLE | BPF_K:
		case BPF_JMP | BPF_JSGT | BPF_K:
		case BPF_JMP | BPF_JSLT | BPF_K:
		case BPF_JMP | BPF_JSGE | BPF_K:
		case BPF_JMP | BPF_JSLE | BPF_K:
			/* cmp dst_reg, imm8/32 */
			EMIT1(add_1mod(0x48, dst_reg));

			if (is_imm8(imm32))
				EMIT3(0x83, add_1reg(0xF8, dst_reg), imm32);
			else
				EMIT2_off32(0x81, add_1reg(0xF8, dst_reg), imm32);

emit_cond_jmp:		/* convert BPF opcode to x86 */
			switch (BPF_OP(insn->code)) {
			case BPF_JEQ:
				jmp_cond = X86_JE;
				break;
			case BPF_JSET:
			case BPF_JNE:
				jmp_cond = X86_JNE;
				break;
			case BPF_JGT:
				/* GT is unsigned '>', JA in x86 */
				jmp_cond = X86_JA;
				break;
			case BPF_JLT:
				/* LT is unsigned '<', JB in x86 */
				jmp_cond = X86_JB;
				break;
			case BPF_JGE:
				/* GE is unsigned '>=', JAE in x86 */
				jmp_cond = X86_JAE;
				break;
			case BPF_JLE:
				/* LE is unsigned '<=', JBE in x86 */
				jmp_cond = X86_JBE;
				break;
			case BPF_JSGT:
				/* signed '>', GT in x86 */
				jmp_cond = X86_JG;
				break;
			case BPF_JSLT:
				/* signed '<', LT in x86 */
				jmp_cond = X86_JL;
				break;
			case BPF_JSGE:
				/* signed '>=', GE in x86 */
				jmp_cond = X86_JGE;
				break;
			case BPF_JSLE:
				/* signed '<=', LE in x86 */
				jmp_cond = X86_JLE;
				break;
			default: /* to silence gcc warning */
				return -EFAULT;
			}
			jmp_offset = addrs[i + insn->off] - addrs[i];
			if (is_imm8(jmp_offset)) {
				EMIT2(jmp_cond, jmp_offset);
			} else if (is_simm32(jmp_offset)) {
				EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
			} else {
				pr_err("cond_jmp gen bug %llx\n", jmp_offset);
				return -EFAULT;
			}

			break;

		case BPF_JMP | BPF_JA:
			jmp_offset = addrs[i + insn->off] - addrs[i];
			if (!jmp_offset)
				/* optimize out nop jumps */
				break;
emit_jmp:
			if (is_imm8(jmp_offset)) {
				EMIT2(0xEB, jmp_offset);
			} else if (is_simm32(jmp_offset)) {
				EMIT1_off32(0xE9, jmp_offset);
			} else {
				pr_err("jmp gen bug %llx\n", jmp_offset);
				return -EFAULT;
			}
			break;

		case BPF_LD | BPF_IND | BPF_W:
			func = sk_load_word;
			goto common_load;
		case BPF_LD | BPF_ABS | BPF_W:
			func = CHOOSE_LOAD_FUNC(imm32, sk_load_word);
common_load:
			ctx->seen_ld_abs = seen_ld_abs = true;
			jmp_offset = func - (image + addrs[i]);
			if (!func || !is_simm32(jmp_offset)) {
				pr_err("unsupported bpf func %d addr %p image %p\n",
				       imm32, func, image);
				return -EINVAL;
			}
			if (BPF_MODE(insn->code) == BPF_ABS) {
				/* mov %esi, imm32 */
				EMIT1_off32(0xBE, imm32);
			} else {
				/* mov %rsi, src_reg */
				EMIT_mov(BPF_REG_2, src_reg);
				if (imm32) {
					if (is_imm8(imm32))
						/* add %esi, imm8 */
						EMIT3(0x83, 0xC6, imm32);
					else
						/* add %esi, imm32 */
						EMIT2_off32(0x81, 0xC6, imm32);
				}
			}
			/* skb pointer is in R6 (%rbx), it will be copied into
			 * %rdi if skb_copy_bits() call is necessary.
			 * sk_load_* helpers also use %r10 and %r9d.
			 * See bpf_jit.S
			 */
			if (seen_ax_reg)
				/* r10 = skb->data, mov %r10, off32(%rbx) */
				EMIT3_off32(0x4c, 0x8b, 0x93,
					    offsetof(struct sk_buff, data));
			EMIT1_off32(0xE8, jmp_offset); /* call */
			break;

		case BPF_LD | BPF_IND | BPF_H:
			func = sk_load_half;
			goto common_load;
		case BPF_LD | BPF_ABS | BPF_H:
			func = CHOOSE_LOAD_FUNC(imm32, sk_load_half);
			goto common_load;
		case BPF_LD | BPF_IND | BPF_B:
			func = sk_load_byte;
			goto common_load;
		case BPF_LD | BPF_ABS | BPF_B:
			func = CHOOSE_LOAD_FUNC(imm32, sk_load_byte);
			goto common_load;

		case BPF_JMP | BPF_EXIT:
			if (seen_exit) {
				/* later exits just jump to the one epilogue */
				jmp_offset = ctx->cleanup_addr - addrs[i];
				goto emit_jmp;
			}
			seen_exit = true;
			/* update cleanup_addr */
			ctx->cleanup_addr = proglen;
			/* mov rbx, qword ptr [rbp+0] */
			EMIT4(0x48, 0x8B, 0x5D, 0);
			/* mov r13, qword ptr [rbp+8] */
			EMIT4(0x4C, 0x8B, 0x6D, 8);
			/* mov r14, qword ptr [rbp+16] */
			EMIT4(0x4C, 0x8B, 0x75, 16);
			/* mov r15, qword ptr [rbp+24] */
			EMIT4(0x4C, 0x8B, 0x7D, 24);

			/* add rbp, AUX_STACK_SPACE */
			EMIT4(0x48, 0x83, 0xC5, AUX_STACK_SPACE);
			EMIT1(0xC9); /* leave */
			EMIT1(0xC3); /* ret */
			break;

		default:
			/* By design x64 JIT should support all BPF instructions
			 * This error will be seen if new instruction was added
			 * to interpreter, but not to JIT
			 * or if there is junk in bpf_prog
			 */
			pr_err("bpf_jit: unknown opcode %02x\n", insn->code);
			return -EINVAL;
		}

		ilen = prog - temp;
		if (ilen > BPF_MAX_INSN_SIZE) {
			pr_err("bpf_jit: fatal insn size error\n");
			return -EFAULT;
		}

		if (image) {
			if (unlikely(proglen + ilen > oldproglen)) {
				pr_err("bpf_jit: fatal error\n");
				return -EFAULT;
			}
			memcpy(image + proglen, temp, ilen);
		}
		proglen += ilen;
		addrs[i] = proglen;
		prog = temp;
	}
	return proglen;
}
/* JIT-compile 'prog' to native x86-64 code.
 *
 * Returns the program that should be run: the (possibly blinded) JITed
 * program on success, or the original 'prog' untouched whenever JITing
 * was not requested or fails, in which case the caller falls back to the
 * interpreter.
 *
 * do_jit() is run repeatedly until the generated length stops changing;
 * then an executable image is allocated and one final pass writes into
 * it.  Any failure after the image has been allocated releases it again.
 */
struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{
	struct bpf_binary_header *header = NULL;
	struct bpf_prog *tmp, *orig_prog = prog;
	int proglen, oldproglen = 0;
	struct jit_context ctx = {};
	bool tmp_blinded = false;
	u8 *image = NULL;
	int *addrs;
	int pass;
	int i;

	if (!prog->jit_requested)
		return orig_prog;

	tmp = bpf_jit_blind_constants(prog);
	/* If blinding was requested and we failed during blinding,
	 * we must fall back to the interpreter.
	 */
	if (IS_ERR(tmp))
		return orig_prog;
	if (tmp != prog) {
		tmp_blinded = true;
		prog = tmp;
	}

	/* kmalloc_array() checks the count * size multiplication */
	addrs = kmalloc_array(prog->len, sizeof(*addrs), GFP_KERNEL);
	if (!addrs) {
		prog = orig_prog;
		goto out;
	}

	/* Before first pass, make a rough estimation of addrs[]
	 * each bpf instruction is translated to less than 64 bytes
	 */
	for (proglen = 0, i = 0; i < prog->len; i++) {
		proglen += 64;
		addrs[i] = proglen;
	}
	ctx.cleanup_addr = proglen;

	/* JITed image shrinks with every pass and the loop iterates
	 * until the image stops shrinking. Very large bpf programs
	 * may converge on the last pass. In such case do one more
	 * pass to emit the final image
	 */
	for (pass = 0; pass < 10 || image; pass++) {
		proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
		if (proglen <= 0) {
out_image:
			/* drop a possibly allocated image before bailing out */
			image = NULL;
			if (header)
				bpf_jit_binary_free(header);
			prog = orig_prog;
			goto out_addrs;
		}
		if (image) {
			if (proglen != oldproglen) {
				pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
				       proglen, oldproglen);
				/* the image is unusable: free it as well */
				goto out_image;
			}
			break;
		}
		if (proglen == oldproglen) {
			header = bpf_jit_binary_alloc(proglen, &image,
						      1, jit_fill_hole);
			if (!header) {
				prog = orig_prog;
				goto out_addrs;
			}
		}
		oldproglen = proglen;
	}

	if (bpf_jit_enable > 1)
		bpf_jit_dump(prog->len, proglen, pass + 1, image);

	if (image) {
		bpf_flush_icache(header, image + proglen);
		bpf_jit_binary_lock_ro(header);
		prog->bpf_func = (void *)image;
		prog->jited = 1;
		prog->jited_len = proglen;
	} else {
		/* never converged within the pass limit */
		prog = orig_prog;
	}

out_addrs:
	kfree(addrs);
out:
	/* release whichever of the blinded/original pair is not returned */
	if (tmp_blinded)
		bpf_jit_prog_release_other(prog, prog == orig_prog ?
					   tmp : orig_prog);
	return prog;
}