#include "initrd.h"
#include "utils.h"
#include "heap.h"
#include "console.h"
#include "errno.h"
#include "lz4.h"
/* Size in bytes of one tar block. A header occupies one block and the
 * parser advances past file data in whole multiples of this size. */
#define TAR_BLOCK 512
/*
 * On-disk tar header in the ustar field layout. The field sizes add up to
 * exactly 512 bytes (one TAR_BLOCK) and the struct is packed so it can be
 * overlaid directly on the archive bytes.
 *
 * This file only reads `name`, `prefix`, `size` and `typeflag`; the other
 * fields exist to keep the layout correct.
 */
typedef struct {
char name[100];      /* entry path (may be completed by `prefix`) */
char mode[8];
char uid[8];
char gid[8];
char size[12];       /* payload size as octal ASCII digits */
char mtime[12];
char chksum[8];
char typeflag;       /* entry type; 0 is treated as '0', '5' as a directory */
char linkname[100];
char magic[6];
char version[2];
char uname[32];
char gname[32];
char devmajor[8];
char devminor[8];
char prefix[155];    /* optional leading path, joined to `name` with '/' */
char pad[12];        /* filler up to 512 bytes */
} __attribute__((packed)) tar_header_t;
/*
 * In-memory record for one file or directory of the initrd tree.
 * Tree links are indices into the `entries` array; -1 means "none".
 */
typedef struct {
char name[128];        /* single path component (empty for the root) */
uint32_t flags;        /* FS_FILE or FS_DIRECTORY */
uint32_t data_offset;  /* file payload offset, relative to initrd_location_base */
uint32_t length;       /* file payload length in bytes */
int parent;            /* index of the containing directory, -1 for the root */
int first_child;       /* head of this directory's child list */
int next_sibling;      /* next entry in the parent's child list */
} initrd_entry_t;
/* Address of the (possibly decompressed) tar image; file data offsets are
 * relative to this base. */
static uint32_t initrd_location_base = 0;
/* Parallel arrays indexed by entry number: parse-time records and the VFS
 * nodes built from them by initrd_finalize_nodes(). */
static initrd_entry_t* entries = NULL;
static fs_node_t* nodes = NULL;
/* Number of used slots and allocated capacity of both arrays. */
static int entry_count = 0;
static int entry_cap = 0;
/*
 * Decode an octal ASCII field of at most `n` bytes.
 *
 * Scanning stops at the first NUL byte. Any byte that is not an octal digit
 * (for example the space padding tar writers emit) is skipped rather than
 * treated as an error. The result wraps modulo 2^32.
 */
static uint32_t tar_parse_octal(const char* s, size_t n) {
    uint32_t value = 0;
    const char* const end = s + n;

    for (const char* c = s; c != end && *c != '\0'; ++c) {
        if (*c >= '0' && *c <= '7') {
            value = value * 8u + (uint32_t)(*c - '0');
        }
    }
    return value;
}
/*
 * Return 1 when all TAR_BLOCK bytes starting at `p` are zero, otherwise 0.
 * The archive walker uses an all-zero block as its end-of-archive marker.
 */
static int tar_is_zero_block(const uint8_t* p) {
    const uint8_t* const end = p + TAR_BLOCK;

    while (p != end) {
        if (*p++ != 0) {
            return 0;
        }
    }
    return 1;
}
/*
 * Bounded string copy that always NUL-terminates.
 *
 * Copies bytes from `src` until a NUL is seen, `src_n` bytes have been read,
 * or `dst` has room only for the terminator. `src` does not need to be
 * NUL-terminated (tar header fields are not). Does nothing when `dst_sz`
 * is 0.
 */
static void str_copy_n(char* dst, size_t dst_sz, const char* src, size_t src_n) {
    if (dst_sz == 0) {
        return;
    }

    /* Leave one byte for the terminator. */
    size_t limit = dst_sz - 1;
    if (src_n < limit) {
        limit = src_n;
    }

    size_t n = 0;
    while (n < limit && src[n] != '\0') {
        dst[n] = src[n];
        n++;
    }
    dst[n] = '\0';
}
/*
 * Reserve the next slot in the parallel `entries` / `nodes` arrays.
 *
 * Both arrays grow together (32 slots at first, then doubling). The new slot
 * is zeroed and its tree links are set to -1.
 *
 * Returns the slot index, or -ENOMEM when growing fails. On failure the
 * existing arrays are left untouched and nothing is leaked.
 *
 * Growing moves both arrays, so any pointer into `entries` or `nodes` taken
 * before the call is invalid afterwards.
 */
static int entry_alloc(void) {
    if (entry_count == entry_cap) {
        int new_cap = (entry_cap == 0) ? 32 : entry_cap * 2;
        initrd_entry_t* new_entries = (initrd_entry_t*)kmalloc(sizeof(initrd_entry_t) * (size_t)new_cap);
        fs_node_t* new_nodes = (fs_node_t*)kmalloc(sizeof(fs_node_t) * (size_t)new_cap);

        if (!new_entries || !new_nodes) {
            /* Release whichever half did get allocated. */
            if (new_entries) kfree(new_entries);
            if (new_nodes) kfree(new_nodes);
            return -ENOMEM;
        }

        if (entries) {
            memcpy(new_entries, entries, sizeof(initrd_entry_t) * (size_t)entry_count);
            kfree(entries);
        }
        if (nodes) {
            memcpy(new_nodes, nodes, sizeof(fs_node_t) * (size_t)entry_count);
            kfree(nodes);
        }

        entries = new_entries;
        nodes = new_nodes;
        entry_cap = new_cap;
    }

    int idx = entry_count++;
    memset(&entries[idx], 0, sizeof(entries[idx]));
    memset(&nodes[idx], 0, sizeof(nodes[idx]));
    entries[idx].parent = -1;
    entries[idx].first_child = -1;
    entries[idx].next_sibling = -1;
    return idx;
}
/*
 * Look up a direct child of `parent` by exact name.
 * Returns the child's entry index, or -1 when no child has that name.
 */
static int entry_find_child(int parent, const char* name) {
    for (int child = entries[parent].first_child;
         child != -1;
         child = entries[child].next_sibling) {
        if (strcmp(entries[child].name, name) == 0) {
            return child;
        }
    }
    return -1;
}
static int entry_add_child(int parent, const char* name, uint32_t flags) {
int idx = entry_alloc();
if (idx < 0) return idx;
strcpy(entries[idx].name, name);
entries[idx].flags = flags;
entries[idx].parent = parent;
entries[idx].next_sibling = entries[parent].first_child;
entries[parent].first_child = idx;
return idx;
}
/*
 * Return the index of the child of `parent` called `name`, creating it as a
 * directory when no child with that name exists yet. An existing child is
 * returned as-is, whatever its type. Propagates the negative error code if
 * creation fails.
 */
static int ensure_dir(int parent, const char* name) {
    int idx = entry_find_child(parent, name);
    if (idx == -1) {
        idx = entry_add_child(parent, name, FS_DIRECTORY);
    }
    return idx;
}
/*
 * Walk `path` below `root_idx`, creating every intermediate directory, and
 * hand back the final component without creating it.
 *
 * Leading and repeated '/' separators are ignored. Each component is
 * truncated to 127 characters. On success the last component is written to
 * `leaf_out` and the index of its containing directory is returned.
 *
 * Returns -EINVAL for NULL arguments, a zero-sized leaf buffer, or a path
 * with no components; otherwise any negative error from ensure_dir().
 */
static int ensure_path_dirs(int root_idx, const char* path, char* leaf_out, size_t leaf_out_sz) {
    if (!path || !leaf_out || leaf_out_sz == 0) {
        return -EINVAL;
    }

    const char* cursor = path;
    int dir = root_idx;
    char component[128];

    while (*cursor == '/') {
        cursor++;
    }

    for (;;) {
        if (*cursor == '\0') {
            /* Only reachable when the path held no component at all. */
            return -EINVAL;
        }

        /* Collect one component, silently dropping characters that do not fit. */
        size_t len = 0;
        for (; *cursor != '\0' && *cursor != '/'; cursor++) {
            if (len + 1 < sizeof(component)) {
                component[len++] = *cursor;
            }
        }
        component[len] = '\0';

        while (*cursor == '/') {
            cursor++;
        }
        if (len == 0) {
            continue;
        }

        if (*cursor == '\0') {
            /* Nothing follows: this component is the leaf. */
            str_copy_n(leaf_out, leaf_out_sz, component, len);
            return dir;
        }

        dir = ensure_dir(dir, component);
        if (dir < 0) {
            return dir;
        }
    }
}
/*
 * VFS read hook: copy up to `size` bytes of a file's payload, starting at
 * `offset`, into `buffer`.
 *
 * Returns the number of bytes copied. Returns 0 when `node` or `buffer` is
 * NULL, the inode is out of range, the entry is not a file, or `offset` is
 * at or beyond the end of the file.
 */
static uint32_t initrd_read_impl(fs_node_t* node, uint32_t offset, uint32_t size, uint8_t* buffer) {
    if (!node || !buffer) return 0;

    uint32_t idx = node->inode;
    if ((int)idx < 0 || (int)idx >= entry_count) return 0;

    const initrd_entry_t* e = &entries[idx];
    if ((e->flags & FS_FILE) == 0) return 0;
    if (offset >= e->length) return 0;

    /* Clamp against the remaining bytes. The subtraction cannot underflow,
     * whereas `offset + size` could wrap around and bypass the limit. */
    uint32_t remaining = e->length - offset;
    if (size > remaining) size = remaining;

    const uint8_t* src = (const uint8_t*)(uintptr_t)(initrd_location_base + e->data_offset + offset);
    memcpy(buffer, src, size);
    return size;
}
/*
 * Directory lookup hook: find the direct child of `node` whose name equals
 * `name`.
 *
 * Returns the child's VFS node, or 0 when an argument is NULL, the inode is
 * out of range, or no child matches.
 */
static struct fs_node* initrd_finddir(struct fs_node* node, const char* name) {
    if (!node || !name) {
        return 0;
    }

    const int dir = (int)node->inode;
    if (dir < 0 || dir >= entry_count) {
        return 0;
    }

    for (int child = entries[dir].first_child;
         child != -1;
         child = entries[child].next_sibling) {
        if (strcmp(entries[child].name, name) == 0) {
            return &nodes[child];
        }
    }
    return 0;
}
/* File operations attached to regular-file nodes: only reading is provided. */
static const struct file_operations initrd_file_ops = {
.read = initrd_read_impl,
};
/* File operations attached to directory nodes: all hooks left zeroed. */
static const struct file_operations initrd_dir_ops = {0};
/*
 * Directory enumeration hook.
 *
 * Fills `buf` with as many `struct vfs_dirent` records as fit in `buf_len`,
 * starting at logical position `*inout_index`: position 0 is ".", position 1
 * is "..", and position k >= 2 is the (k-2)th entry of the child list.
 * `*inout_index` is advanced past the records written so the next call
 * continues from there.
 *
 * Returns the number of bytes written (0 once the directory is exhausted),
 * or -1 for NULL arguments, a non-directory node, a buffer smaller than one
 * record, or an out-of-range inode.
 */
static int initrd_readdir(struct fs_node* node, uint32_t* inout_index, void* buf, uint32_t buf_len) {
    if (!node || !inout_index || !buf) return -1;
    if (node->flags != FS_DIRECTORY) return -1;
    if (buf_len < sizeof(struct vfs_dirent)) return -1;

    int parent = (int)node->inode;
    if (parent < 0 || parent >= entry_count) return -1;

    uint32_t idx = *inout_index;
    uint32_t cap = buf_len / (uint32_t)sizeof(struct vfs_dirent);
    struct vfs_dirent* ents = (struct vfs_dirent*)buf;
    uint32_t written = 0;

    /* Cursor into the child list. It is positioned once per call and then
     * advanced, so the list is not re-walked from its head for every record. */
    int child = -1;
    int child_positioned = 0;

    while (written < cap) {
        struct vfs_dirent e;
        memset(&e, 0, sizeof(e));

        if (idx == 0) {
            e.d_ino = node->inode;
            e.d_type = FS_DIRECTORY;
            strcpy(e.d_name, ".");
        } else if (idx == 1) {
            /* The root has no parent, so ".." refers to the root itself. */
            int pi = entries[parent].parent;
            e.d_ino = (pi >= 0) ? (uint32_t)pi : node->inode;
            e.d_type = FS_DIRECTORY;
            strcpy(e.d_name, "..");
        } else {
            if (!child_positioned) {
                child = entries[parent].first_child;
                for (uint32_t skip = idx - 2; child != -1 && skip > 0; skip--) {
                    child = entries[child].next_sibling;
                }
                child_positioned = 1;
            }
            if (child == -1) break;

            e.d_ino = (uint32_t)child;
            e.d_type = (uint8_t)entries[child].flags;

            /* Bounded copy: truncate rather than overflow d_name. */
            size_t name_len = strlen(entries[child].name);
            if (name_len >= sizeof(e.d_name)) name_len = sizeof(e.d_name) - 1;
            memcpy(e.d_name, entries[child].name, name_len);
            e.d_name[name_len] = 0;

            child = entries[child].next_sibling;
        }

        e.d_reclen = (uint16_t)sizeof(e);
        ents[written++] = e;
        idx++;
    }

    *inout_index = idx;
    return (int)(written * (uint32_t)sizeof(struct vfs_dirent));
}
/* Inode operations attached to directory nodes: name lookup and listing. */
static const struct inode_operations initrd_dir_iops = {
.lookup = initrd_finddir,
.readdir = initrd_readdir,
};
/*
 * Populate the VFS node for every parsed entry: copy name, length and flags,
 * use the entry index as the inode number, and attach the file or directory
 * operation tables according to the entry's flags.
 */
static void initrd_finalize_nodes(void) {
    for (int i = 0; i < entry_count; i++) {
        const initrd_entry_t* src = &entries[i];
        fs_node_t* dst = &nodes[i];

        strcpy(dst->name, src->name);
        dst->inode = (uint32_t)i;
        dst->length = src->length;
        dst->flags = src->flags;

        if (src->flags & FS_FILE) {
            dst->f_ops = &initrd_file_ops;
            continue;
        }
        if (src->flags & FS_DIRECTORY) {
            dst->f_ops = &initrd_dir_ops;
            dst->i_ops = &initrd_dir_iops;
        }
    }
}
fs_node_t* initrd_init(uint32_t location, uint32_t size) {
const uint8_t* raw = (const uint8_t*)(uintptr_t)location;
uint8_t* decomp_buf = NULL;
/* Detect LZ4-compressed initrd */
uint32_t magic32 = (uint32_t)raw[0] | ((uint32_t)raw[1] << 8) |
((uint32_t)raw[2] << 16) | ((uint32_t)raw[3] << 24);
if (magic32 == LZ4_FRAME_MAGIC) {
/* Official LZ4 Frame format — extract content size from header */
uint8_t flg = raw[4];
uint32_t orig_sz = 0;
if (flg & 0x08) { /* Content Size flag */
orig_sz = (uint32_t)raw[6] | ((uint32_t)raw[7] << 8) |
((uint32_t)raw[8] << 16) | ((uint32_t)raw[9] << 24);
} else {
orig_sz = 4U * 1024U * 1024U;
}
decomp_buf = (uint8_t*)kmalloc(orig_sz);
if (!decomp_buf) {
kprintf("[INITRD] OOM decompressing LZ4 (%u bytes)\n", orig_sz);
return 0;
}
int ret = lz4_decompress_frame(raw, size, decomp_buf, orig_sz);
if (ret < 0) {
kprintf("[INITRD] LZ4 Frame decompress failed (ret=%d)\n", ret);
kfree(decomp_buf);
return 0;
}
kprintf("[INITRD] LZ4: %u -> %d bytes\n", size, ret);
location = (uint32_t)(uintptr_t)decomp_buf;
} else if (magic32 == LZ4B_MAGIC_U32) {
/* Legacy LZ4B format (backward compatibility) */
uint32_t orig_sz = (uint32_t)raw[4] | ((uint32_t)raw[5] << 8) |
((uint32_t)raw[6] << 16) | ((uint32_t)raw[7] << 24);
uint32_t comp_sz = (uint32_t)raw[8] | ((uint32_t)raw[9] << 8) |
((uint32_t)raw[10] << 16) | ((uint32_t)raw[11] << 24);
decomp_buf = (uint8_t*)kmalloc(orig_sz);
if (!decomp_buf) {
kprintf("[INITRD] OOM decompressing LZ4 (%u bytes)\n", orig_sz);
return 0;
}
int ret = lz4_decompress_block(raw + LZ4B_HDR_SIZE, comp_sz,
decomp_buf, orig_sz);
if (ret < 0 || (uint32_t)ret != orig_sz) {
kprintf("[INITRD] LZ4 decompress failed (ret=%d, expected=%u)\n",
ret, orig_sz);
kfree(decomp_buf);
return 0;
}
kprintf("[INITRD] LZ4: %u -> %u bytes\n", comp_sz, orig_sz);
location = (uint32_t)(uintptr_t)decomp_buf;
}
initrd_location_base = location;
// Initialize root
entry_count = 0;
int root = entry_alloc();
if (root < 0) { kfree(decomp_buf); return 0; }
strcpy(entries[root].name, "");
entries[root].flags = FS_DIRECTORY;
entries[root].data_offset = 0;
entries[root].length = 0;
entries[root].parent = -1;
const uint8_t* p = (const uint8_t*)(uintptr_t)location;
int files = 0;
while (1) {
if (tar_is_zero_block(p)) break;
const tar_header_t* h = (const tar_header_t*)p;
char name[256];
name[0] = 0;
if (h->prefix[0]) {
str_copy_n(name, sizeof(name), h->prefix, sizeof(h->prefix));
size_t cur = strlen(name);
if (cur + 1 < sizeof(name)) {
name[cur] = '/';
name[cur + 1] = 0;
}
size_t rem = sizeof(name) - strlen(name) - 1;
str_copy_n(name + strlen(name), rem + 1, h->name, sizeof(h->name));
} else {
str_copy_n(name, sizeof(name), h->name, sizeof(h->name));
}
uint32_t size = tar_parse_octal(h->size, sizeof(h->size));
char tf = h->typeflag;
if (tf == 0) tf = '0';
// Normalize: strip leading './'
if (name[0] == '.' && name[1] == '/') {
size_t l = strlen(name);
for (size_t i = 0; i + 2 <= l; i++) {
name[i] = name[i + 2];
}
}
// Directories in tar often end with '/'
size_t nlen = strlen(name);
if (nlen && name[nlen - 1] == '/') {
name[nlen - 1] = 0;
tf = '5';
}
if (name[0] != 0) {
char leaf[128];
int parent = ensure_path_dirs(root, name, leaf, sizeof(leaf));
if (parent >= 0) {
if (tf == '5') {
(void)ensure_dir(parent, leaf);
} else {
int existing = entry_find_child(parent, leaf);
int idx = existing;
if (idx == -1) {
idx = entry_add_child(parent, leaf, FS_FILE);
}
if (idx >= 0) {
entries[idx].flags = FS_FILE;
entries[idx].data_offset = (uint32_t)((uintptr_t)(p + TAR_BLOCK) - (uintptr_t)location);
entries[idx].length = size;
files++;
}
}
}
}
uint32_t adv = TAR_BLOCK + ((size + (TAR_BLOCK - 1)) & ~(TAR_BLOCK - 1));
p += adv;
}
initrd_finalize_nodes();
kprintf("[INITRD] Found %d files.\n", files);
return &nodes[root];
}