#include "uaccess.h"
#include "errno.h"
#include "interrupts.h"
#include "hal/mm.h"
#include <stdint.h>
/* Global flag set by hal_cpu_detect_features() when SMAP is enabled in CR4. */
extern int g_smap_enabled;
/* STAC: set EFLAGS.AC so supervisor code may access user pages while SMAP
 * is enforced. Emitted as raw opcode bytes (0F 01 CB) so assemblers that
 * predate the mnemonic still accept it. STAC raises #UD on CPUs without
 * SMAP support, so it is issued only when g_smap_enabled is nonzero. The
 * "memory" clobber stops the compiler from moving memory accesses across
 * the toggle. */
static inline void stac(void) {
    if (!g_smap_enabled)
        return;
    __asm__ volatile(".byte 0x0F, 0x01, 0xCB" ::: "memory");
}
/* CLAC: clear EFLAGS.AC, re-enabling SMAP protection after a user access.
 * Raw opcode bytes 0F 01 CA for older assemblers. CLAC raises #UD on CPUs
 * without SMAP support, so it is issued only when g_smap_enabled is nonzero.
 * The "memory" clobber keeps memory accesses from being moved across it. */
static inline void clac(void) {
    if (!g_smap_enabled)
        return;
    __asm__ volatile(".byte 0x0F, 0x01, 0xCA" ::: "memory");
}
/*
 * Cheap bounds checks on a user range, without consulting page tables.
 * An empty range is always accepted. Otherwise the range is rejected if it
 * starts at address 0, wraps around the address space, or has any byte at
 * or above hal_mm_kernel_virt_base(). Returns 1 if acceptable, 0 if not.
 */
static int x86_user_range_basic_ok(uintptr_t uaddr, size_t len) {
    uintptr_t last;

    if (len == 0)
        return 1;
    if (uaddr == 0 || uaddr >= hal_mm_kernel_virt_base())
        return 0;

    last = uaddr + len - 1;
    /* last < uaddr means the addition wrapped */
    return last >= uaddr && last < hal_mm_kernel_virt_base();
}
/*
 * Fault-recovery state shared by copy_from_user()/copy_to_user() and
 * uaccess_try_recover():
 *   g_uaccess_active      - nonzero while a copy loop is accessing user memory
 *   g_uaccess_faulted     - set to 1 by uaccess_try_recover() when it recovers
 *   g_uaccess_recover_eip - code address a recovered fault resumes at
 * NOTE(review): there is a single global instance, not one per CPU or per
 * task, and `volatile` only constrains the compiler (no cross-CPU ordering).
 * If a copy can be preempted, or two CPUs can copy at once, one caller can
 * overwrite another's recovery address. Confirm the kernel is uniprocessor
 * and these copies are non-preemptible, or move this state per-CPU/per-task.
 */
static volatile int g_uaccess_active = 0;
static volatile int g_uaccess_faulted = 0;
static volatile uintptr_t g_uaccess_recover_eip = 0;
/*
 * Fault-time hook for the user-copy routines.
 *
 * If a copy is in progress (g_uaccess_active), a recovery address has been
 * published, and the faulting address lies below hal_mm_kernel_virt_base(),
 * record the fault and rewrite regs->eip to the recovery address. Execution
 * then resumes there when the saved registers are restored.
 *
 * Returns 1 when regs was redirected. Returns 0 otherwise, including for
 * regs == NULL; the caller presumably treats that as an unhandled fault
 * (kernel-address faults should still panic).
 */
int uaccess_try_recover(uintptr_t fault_addr, struct registers* regs) {
    const int recoverable = regs
        && g_uaccess_active
        && g_uaccess_recover_eip
        && fault_addr < hal_mm_kernel_virt_base();

    if (!recoverable)
        return 0;

    g_uaccess_faulted = 1;
    regs->eip = (uint32_t)g_uaccess_recover_eip;
    return 1;
}
/* PAE paging-structure entry bits (Intel SDM Vol. 3A, section 4.4). */
enum {
    UACCESS_PG_PRESENT  = 0x001,
    UACCESS_PG_WRITABLE = 0x002,
    UACCESS_PG_USER     = 0x004,
    UACCESS_PG_LARGE    = 0x080 /* PDE only (PS): entry maps a 2 MiB page */
};

/* Nonzero iff paging entry `entry` has every bit in `need` set. */
static int x86_pae_entry_grants(uint64_t entry, uint64_t need) {
    return (entry & need) == need;
}

/*
 * Walk the active PAE tables for `vaddr`. Returns nonzero iff every level
 * of the translation grants all bits in `need`. `need` must include
 * UACCESS_PG_PRESENT so that a missing PDE ends the walk before the
 * page-table read, which would otherwise fault.
 *
 * Tables are read through the recursive self-map this code assumes:
 * page directory `pi` at 0xFFFFC000 + pi * 0x1000, and the page table for
 * (pi, di) at 0xFF800000 + pi * 0x200000 + di * 0x1000.
 *
 * The CPU combines R/W and U/S across the PDE and the PTE, so both levels
 * are checked; a read-only PDE makes the whole 2 MiB region read-only.
 * A PDE with PS set is itself the leaf of a 2 MiB mapping. There is no page
 * table beneath it: the self-map slot would expose bytes from the large
 * page's memory instead of PTEs, so the walk stops at the PDE.
 */
static int x86_user_page_check(uintptr_t vaddr, uint64_t need) {
    const uint32_t pi = (vaddr >> 30) & 0x3;
    const uint32_t di = (vaddr >> 21) & 0x1FF;
    const uint32_t ti = (vaddr >> 12) & 0x1FF;

    volatile uint64_t* pd = (volatile uint64_t*)(uintptr_t)(0xFFFFC000U + pi * 0x1000U);
    const uint64_t pde = pd[di];
    if (!x86_pae_entry_grants(pde, need)) return 0;
    if (pde & UACCESS_PG_LARGE) return 1;

    volatile uint64_t* pt = (volatile uint64_t*)(uintptr_t)(0xFF800000U + pi * 0x200000U + di * 0x1000U);
    return x86_pae_entry_grants(pt[ti], need);
}

/* Nonzero iff the page containing `vaddr` is present, user-accessible and
 * writable at every paging level. */
static int x86_user_page_writable_user(uintptr_t vaddr) {
    return x86_user_page_check(vaddr, UACCESS_PG_PRESENT | UACCESS_PG_USER | UACCESS_PG_WRITABLE);
}

/* Nonzero iff the page containing `vaddr` is present and user-accessible
 * at every paging level. */
static int x86_user_page_present_and_user(uintptr_t vaddr) {
    return x86_user_page_check(vaddr, UACCESS_PG_PRESENT | UACCESS_PG_USER);
}
/*
 * Nonzero iff [uaddr, uaddr + len) passes x86_user_range_basic_ok() and
 * every 0x1000-aligned page the range touches is present and
 * user-accessible. An empty range returns 1.
 */
static int x86_user_range_mapped_and_user(uintptr_t uaddr, size_t len) {
    const uintptr_t page_mask = ~(uintptr_t)0xFFF;
    uintptr_t page;
    uintptr_t last_page;

    if (!x86_user_range_basic_ok(uaddr, len))
        return 0;
    if (len == 0)
        return 1;

    page = uaddr & page_mask;
    last_page = (uaddr + len - 1) & page_mask;
    while (x86_user_page_present_and_user(page)) {
        if (page == last_page)
            return 1;
        page += 0x1000;
    }
    return 0;
}
/*
 * Nonzero iff [uaddr, uaddr + len) passes x86_user_range_basic_ok() and
 * every 0x1000-aligned page the range touches is present, user-accessible
 * and writable. An empty range returns 1.
 */
static int x86_user_range_writable_user(uintptr_t uaddr, size_t len) {
    const uintptr_t page_mask = ~(uintptr_t)0xFFF;
    uintptr_t page;
    uintptr_t last_page;

    if (!x86_user_range_basic_ok(uaddr, len))
        return 0;
    if (len == 0)
        return 1;

    page = uaddr & page_mask;
    last_page = (uaddr + len - 1) & page_mask;
    while (x86_user_page_writable_user(page)) {
        if (page == last_page)
            return 1;
        page += 0x1000;
    }
    return 0;
}
/*
 * Public readability check for a user buffer. Nonzero iff every page of
 * [user_ptr, user_ptr + len) is a present, user-accessible mapping that
 * lies entirely below the kernel base. A zero-length range passes.
 */
int user_range_ok(const void* user_ptr, size_t len) {
    return x86_user_range_mapped_and_user((uintptr_t)user_ptr, len);
}
/*
 * Copy `len` bytes from user address `src_user` into kernel buffer `dst`.
 *
 * Returns 0 on success (including len == 0). Returns -EFAULT if the source
 * range fails user_range_ok(), or if a fault on a user address interrupts
 * the copy; in that case `dst` may already hold a partial prefix.
 *
 * Fault recovery: the address of the local label `uaccess_fault` (GCC
 * labels-as-values, `&&label`) is published in g_uaccess_recover_eip.
 * uaccess_try_recover() rewrites the saved EIP to that address, so
 * execution resumes at the label instead of retrying the faulting load.
 * NOTE(review): entering a label from a trap is outside what GCC documents
 * for `&&label`. It relies on ESP/EBP and callee-saved registers at the
 * fault point matching what the code after the label expects. Verify this
 * against the generated code, or consider an asm-based exception table.
 */
int copy_from_user(void* dst, const void* src_user, size_t len) {
if (len == 0) return 0;
if (!user_range_ok(src_user, len)) return -EFAULT;
/* Arm fault recovery before opening the SMAP window with stac(). */
g_uaccess_faulted = 0;
g_uaccess_recover_eip = (uintptr_t)&&uaccess_fault;
g_uaccess_active = 1;
stac();
uintptr_t up = (uintptr_t)src_user;
/* Byte-wise copy through a volatile pointer: each user load is a separate
 * volatile access, kept in order relative to the volatile flag stores. */
for (size_t i = 0; i < len; i++) {
((uint8_t*)dst)[i] = ((const volatile uint8_t*)up)[i];
}
/* Success path: close the SMAP window and disarm recovery. */
clac();
g_uaccess_active = 0;
g_uaccess_recover_eip = 0;
/* NOTE(review): in this file a recovered fault always resumes at
 * uaccess_fault, so this check appears to be defensive only. */
if (g_uaccess_faulted) return -EFAULT;
return 0;
/* Reached only through the EIP rewrite in uaccess_try_recover(). */
uaccess_fault:
clac();
g_uaccess_active = 0;
g_uaccess_faulted = 0;
g_uaccess_recover_eip = 0;
return -EFAULT;
}
/*
 * Copy `len` bytes from kernel buffer `src` to user address `dst_user`.
 *
 * Returns 0 on success (including len == 0). Returns -EFAULT in two cases:
 *   - the destination fails x86_user_range_writable_user(), i.e. some
 *     touched page is not present, not user-accessible or not writable;
 *   - a fault on a user address interrupts the copy. A prefix of the
 *     destination may then already have been written.
 * NOTE(review): if CR0.WP is clear, supervisor writes ignore page R/W bits,
 * so this pre-check would be the only guard against writing read-only user
 * pages. Confirm how CR0.WP is configured.
 *
 * Fault recovery: the address of the local label `uaccess_fault2` (GCC
 * labels-as-values) is published in g_uaccess_recover_eip.
 * uaccess_try_recover() rewrites the saved EIP to it, so execution resumes
 * at that label after a fault.
 * NOTE(review): entering a label from a trap relies on the stack and
 * callee-saved register state at the fault point matching what the code
 * after the label expects. Verify against the generated code.
 */
int copy_to_user(void* dst_user, const void* src, size_t len) {
if (len == 0) return 0;
if (!x86_user_range_writable_user((uintptr_t)dst_user, len)) return -EFAULT;
/* Arm fault recovery before opening the SMAP window with stac(). */
g_uaccess_faulted = 0;
g_uaccess_recover_eip = (uintptr_t)&&uaccess_fault2;
g_uaccess_active = 1;
stac();
uintptr_t up = (uintptr_t)dst_user;
/* Byte-wise copy through a volatile pointer: each user store is a separate
 * volatile access, kept in order relative to the volatile flag stores. */
for (size_t i = 0; i < len; i++) {
((volatile uint8_t*)up)[i] = ((const uint8_t*)src)[i];
}
/* Success path: close the SMAP window and disarm recovery. */
clac();
g_uaccess_active = 0;
g_uaccess_recover_eip = 0;
/* NOTE(review): in this file a recovered fault always resumes at
 * uaccess_fault2, so this check appears to be defensive only. */
if (g_uaccess_faulted) return -EFAULT;
return 0;
/* Reached only through the EIP rewrite in uaccess_try_recover(). */
uaccess_fault2:
clac();
g_uaccess_active = 0;
g_uaccess_faulted = 0;
g_uaccess_recover_eip = 0;
return -EFAULT;
}