#include "uaccess.h"

#include "errno.h"
#include "interrupts.h"
#include "hal/mm.h"

#include <stdint.h>

/* Global flag set by hal_cpu_detect_features() when SMAP is enabled in CR4. */
extern int g_smap_enabled;

/* STAC sets EFLAGS.AC and CLAC clears it, temporarily allowing supervisor
 * access to user pages while SMAP is enabled.
 * Encoded as raw bytes (0F 01 CB / 0F 01 CA) for compatibility with older
 * assemblers. Only executed when SMAP is actually enabled, since the opcodes
 * raise #UD on CPUs without SMAP support. */
static inline void stac(void) {
    if (g_smap_enabled)
        __asm__ volatile(".byte 0x0F, 0x01, 0xCB" ::: "memory");
}
static inline void clac(void) {
    if (g_smap_enabled)
        __asm__ volatile(".byte 0x0F, 0x01, 0xCA" ::: "memory");
}
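
/*
 * For context, a minimal sketch of how g_smap_enabled could be populated by
 * hal_cpu_detect_features(): check CPUID.(EAX=7,ECX=0):EBX bit 20 for SMAP
 * support and, if present, set CR4 bit 21. The real HAL code is not shown in
 * this file, so the helper below is illustrative only (it also skips the
 * max-leaf check that production code would do).
 */
#if 0 /* illustrative sketch, not part of the build */
static void sketch_detect_and_enable_smap(void) {
    uint32_t eax = 7, ebx, ecx = 0, edx;
    __asm__ volatile("cpuid"
                     : "+a"(eax), "=b"(ebx), "+c"(ecx), "=d"(edx));
    if (ebx & (1U << 20)) {              /* CPUID leaf 7 EBX.SMAP */
        uint32_t cr4;
        __asm__ volatile("mov %%cr4, %0" : "=r"(cr4));
        cr4 |= (1U << 21);               /* CR4.SMAP */
        __asm__ volatile("mov %0, %%cr4" :: "r"(cr4) : "memory");
        g_smap_enabled = 1;
    }
}
#endif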

static int x86_user_range_basic_ok(uintptr_t uaddr, size_t len) {
    if (len == 0) return 1;
    if (uaddr == 0) return 0;
    if (uaddr >= hal_mm_kernel_virt_base()) return 0;
    uintptr_t end = uaddr + len - 1;
    if (end < uaddr) return 0;
    if (end >= hal_mm_kernel_virt_base()) return 0;
    return 1;
}

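/* Fault-recovery state shared with the page-fault path. g_uaccess_active marks
 * that a user copy is in progress, g_uaccess_recover_eip holds the EIP to
 * resume at on a fault, and g_uaccess_faulted records that a fault was taken.
 * The state is a single global set, so only one user copy can be in flight at
 * a time. */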
static volatile int g_uaccess_active = 0;
static volatile int g_uaccess_faulted = 0;
static volatile uintptr_t g_uaccess_recover_eip = 0;

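/* Called from the page-fault handler. If the fault hit a user address while a
 * user copy was in progress, redirect the saved EIP to the copy's recovery
 * label and report success (1); otherwise return 0 and let normal fault
 * handling proceed. */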
int uaccess_try_recover(uintptr_t fault_addr, struct registers* regs) {
    if (!regs) return 0;
    if (g_uaccess_active == 0) return 0;
    if (g_uaccess_recover_eip == 0) return 0;

    // Only recover faults on user addresses; kernel faults should still panic.
    if (fault_addr >= hal_mm_kernel_virt_base()) return 0;

    g_uaccess_faulted = 1;
    regs->eip = (uint32_t)g_uaccess_recover_eip;
    return 1;
}
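
/*
 * For context, a sketch of the expected call site in the #PF handler
 * (illustrative; the real handler lives with the interrupt code and its shape
 * may differ): read the faulting address from CR2, then give
 * uaccess_try_recover() a chance to redirect EIP before the panic path runs.
 */
#if 0 /* illustrative sketch, not part of the build */
void sketch_page_fault_handler(struct registers* regs) {
    uintptr_t fault_addr;
    __asm__ volatile("mov %%cr2, %0" : "=r"(fault_addr));

    if (uaccess_try_recover(fault_addr, regs))
        return;  /* resume at the recovery label inside the copy routine */

    /* ... otherwise fall through to the usual panic / process-kill path ... */
}
#endif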

/* Walks the PAE paging structures for one user virtual address through the
 * kernel's fixed virtual-address windows onto the page directories
 * (0xFFFFC000 + pi * 0x1000) and page tables (0xFF800000 + pi * 0x200000 +
 * di * 0x1000), and reports whether the page is present, user-accessible,
 * and writable. Entry bits: 0x1 = Present, 0x2 = Read/Write, 0x4 = User. */
static int x86_user_page_writable_user(uintptr_t vaddr) {
    uint32_t pi = (vaddr >> 30) & 0x3;    /* PDPT index */
    uint32_t di = (vaddr >> 21) & 0x1FF;  /* page-directory index */
    uint32_t ti = (vaddr >> 12) & 0x1FF;  /* page-table index */

    volatile uint64_t* pd = (volatile uint64_t*)(uintptr_t)(0xFFFFC000U + pi * 0x1000U);
    uint64_t pde = pd[di];
    if (!(pde & 0x1)) return 0;           /* PDE not present */
    if (!(pde & 0x4)) return 0;           /* PDE not user-accessible */

    volatile uint64_t* pt = (volatile uint64_t*)(uintptr_t)(0xFF800000U + pi * 0x200000U + di * 0x1000U);
    uint64_t pte = pt[ti];
    if (!(pte & 0x1)) return 0;           /* PTE not present */
    if (!(pte & 0x4)) return 0;           /* PTE not user-accessible */
    if (!(pte & 0x2)) return 0;           /* PTE not writable */
    return 1;
}

/* Same walk as above, but only requires the page to be present and
 * user-accessible (no write permission needed), i.e. enough for a read. */
static int x86_user_page_present_and_user(uintptr_t vaddr) {
    uint32_t pi = (vaddr >> 30) & 0x3;    /* PDPT index */
    uint32_t di = (vaddr >> 21) & 0x1FF;  /* page-directory index */
    uint32_t ti = (vaddr >> 12) & 0x1FF;  /* page-table index */

    volatile uint64_t* pd = (volatile uint64_t*)(uintptr_t)(0xFFFFC000U + pi * 0x1000U);
    uint64_t pde = pd[di];
    if (!(pde & 0x1)) return 0;           /* PDE not present */
    if (!(pde & 0x4)) return 0;           /* PDE not user-accessible */

    volatile uint64_t* pt = (volatile uint64_t*)(uintptr_t)(0xFF800000U + pi * 0x200000U + di * 0x1000U);
    uint64_t pte = pt[ti];
    if (!(pte & 0x1)) return 0;           /* PTE not present */
    if (!(pte & 0x4)) return 0;           /* PTE not user-accessible */

    return 1;
}

static int x86_user_range_mapped_and_user(uintptr_t uaddr, size_t len) {
    if (!x86_user_range_basic_ok(uaddr, len)) return 0;
    if (len == 0) return 1;

    uintptr_t start = uaddr & ~(uintptr_t)0xFFF;
    uintptr_t end = (uaddr + len - 1) & ~(uintptr_t)0xFFF;
    for (uintptr_t va = start;; va += 0x1000) {
        if (!x86_user_page_present_and_user(va)) return 0;
        if (va == end) break;
    }
    return 1;
}

static int x86_user_range_writable_user(uintptr_t uaddr, size_t len) {
    if (!x86_user_range_basic_ok(uaddr, len)) return 0;
    if (len == 0) return 1;

    uintptr_t start = uaddr & ~(uintptr_t)0xFFF;
    uintptr_t end = (uaddr + len - 1) & ~(uintptr_t)0xFFF;
    for (uintptr_t va = start;; va += 0x1000) {
        if (!x86_user_page_writable_user(va)) return 0;
        if (va == end) break;
    }
    return 1;
}

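/* Returns non-zero if [user_ptr, user_ptr + len) lies entirely below the
 * kernel base and every page in the range is present and user-accessible. */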
int user_range_ok(const void* user_ptr, size_t len) {
    uintptr_t uaddr = (uintptr_t)user_ptr;
    return x86_user_range_mapped_and_user(uaddr, len);
}

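/* Byte-wise copy from user space. The address of the recovery label below is
 * published in g_uaccess_recover_eip before the loop; if the loop faults, the
 * page-fault handler rewrites EIP to that label (via uaccess_try_recover) and
 * the copy bails out with -EFAULT. */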
int copy_from_user(void* dst, const void* src_user, size_t len) {
    if (len == 0) return 0;
    if (!user_range_ok(src_user, len)) return -EFAULT;

    g_uaccess_faulted = 0;
    g_uaccess_recover_eip = (uintptr_t)&&uaccess_fault;
    g_uaccess_active = 1;

    stac();
    uintptr_t up = (uintptr_t)src_user;
    for (size_t i = 0; i < len; i++) {
        ((uint8_t*)dst)[i] = ((const volatile uint8_t*)up)[i];
    }
    clac();

    g_uaccess_active = 0;
    g_uaccess_recover_eip = 0;
    if (g_uaccess_faulted) return -EFAULT;
    return 0;

uaccess_fault:
    clac();
    g_uaccess_active = 0;
    g_uaccess_faulted = 0;
    g_uaccess_recover_eip = 0;
    return -EFAULT;
}

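/* Byte-wise copy to user space, using the same fault-recovery scheme as
 * copy_from_user() but with an additional writability check on the
 * destination range. */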
int copy_to_user(void* dst_user, const void* src, size_t len) {
    if (len == 0) return 0;

    if (!x86_user_range_writable_user((uintptr_t)dst_user, len)) return -EFAULT;

    g_uaccess_faulted = 0;
    g_uaccess_recover_eip = (uintptr_t)&&uaccess_fault2;
    g_uaccess_active = 1;

    stac();
    uintptr_t up = (uintptr_t)dst_user;
    for (size_t i = 0; i < len; i++) {
        ((volatile uint8_t*)up)[i] = ((const uint8_t*)src)[i];
    }
    clac();

    g_uaccess_active = 0;
    g_uaccess_recover_eip = 0;
    if (g_uaccess_faulted) return -EFAULT;
    return 0;

uaccess_fault2:
    clac();
    g_uaccess_active = 0;
    g_uaccess_faulted = 0;
    g_uaccess_recover_eip = 0;
    return -EFAULT;
}
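
/*
 * Usage sketch (illustrative): a syscall-style handler that copies an argument
 * struct in from user space and a result back out. The struct and the handler
 * itself are hypothetical; only copy_from_user()/copy_to_user() come from this
 * file.
 */
#if 0 /* illustrative sketch, not part of the build */
struct sketch_args {
    uint32_t fd;
    uint32_t len;
};

int sketch_syscall(const void* user_args, void* user_result) {
    struct sketch_args args;
    uint32_t result;

    if (copy_from_user(&args, user_args, sizeof(args)) != 0)
        return -EFAULT;                 /* bad or unmapped user pointer */

    result = args.fd + args.len;        /* placeholder for real work */

    if (copy_to_user(user_result, &result, sizeof(result)) != 0)
        return -EFAULT;
    return 0;
}
#endif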