fix: memory: protect kernel code

This commit is contained in:
Starnakin 2024-11-26 12:43:09 +01:00
parent 0c280d971b
commit da804296c6
4 changed files with 63 additions and 35 deletions

View File

@ -16,6 +16,8 @@
#define PAGE_MASK 0xFFFFF000 #define PAGE_MASK 0xFFFFF000
#define HEAP_END 0xC0000000 #define HEAP_END 0xC0000000
#define HEAP_START ((uint32_t) & _kernel_end - HEAP_END) #define HEAP_START ((uint32_t) & _kernel_end - HEAP_END)
#define KERNEL_START ((uint32_t) & _kernel_start)
#define KERNEL_END ((uint32_t) & _kernel_end - HEAP_END)
#define PT_START 256 #define PT_START 256
#define GET_PAGE_ADDR(pd_index, pt_index) \ #define GET_PAGE_ADDR(pd_index, pt_index) \

View File

@ -2,3 +2,5 @@
#define CEIL(x, y) (((x) + (y) - 1) / (y)) #define CEIL(x, y) (((x) + (y) - 1) / (y))
#define ARRAY_SIZE(ptr) (sizeof(ptr) / sizeof(ptr[0])) #define ARRAY_SIZE(ptr) (sizeof(ptr) / sizeof(ptr[0]))
#define ROUND_CEIL(x, y) (CEIL(x, y) * y)
#define ROUND_FLOOR(x, y) ((x / y) * y)

View File

@ -68,13 +68,16 @@ void kernel_main(multiboot_info_t *mbd, uint32_t magic)
"Martin 03:50, 22 March 2009 (UTC)\n"); "Martin 03:50, 22 March 2009 (UTC)\n");
// PRINT_PTR(alloc_frame()); // PRINT_PTR(alloc_frame());
/*void *ptr; if (false) {
while ((ptr = alloc_pages(PAGE_SIZE * 1020))) { void *ptr;
while ((ptr = alloc_pages(PAGE_SIZE * 1))) {
if (ptr) if (ptr)
memset(ptr, ~0, PAGE_SIZE * 1020); memset(ptr, ~0, PAGE_SIZE * 1);
}*/ }
/* vmalloc(10); */ } else {
while (vmalloc(10)) while (vmalloc(10))
; ;
}
/* vmalloc(10); */
shell_init(); shell_init();
} }

View File

@ -60,26 +60,46 @@ static void lst_add_back(struct frame_zone **root, struct frame_zone *element)
static void add_frame_node(multiboot_memory_map_t *mmmt) static void add_frame_node(multiboot_memory_map_t *mmmt)
{ {
static uint32_t index; static uint32_t index;
void *zone = (void *)mmmt->addr;
// Kernel code on the block /**
if (HEAP_START >= zone + mmmt->len) * # = kernel code
* - = blank
*/
uint64_t start_addr = mmmt->addr;
uint64_t end_addr = mmmt->addr + mmmt->len;
uint64_t len = mmmt->len;
/** Kernel code cover all the block
* this situation:
* #######################
*/
if (KERNEL_START <= start_addr && KERNEL_END >= end_addr)
return; return;
// KERNEL code partially on the block /** Kernel code start on the block
if (HEAP_START >= zone) { * this situation:
const uint32_t start_space = * --------###############
CEIL(HEAP_START, PAGE_SIZE) * PAGE_SIZE; */
const uint32_t len = mmmt->len - (start_space - (uint32_t)zone); if (KERNEL_START > start_addr && KERNEL_START <= end_addr) {
mmmt->len = CEIL(len, PAGE_SIZE) * PAGE_SIZE; len = ROUND_FLOOR(KERNEL_START - start_addr, PAGE_SIZE);
zone = (void *)start_space;
} }
/** Kernel code end on the block
* this situation:
* ###############--------
*/
if (KERNEL_START <= start_addr && KERNEL_END > start_addr &&
KERNEL_END <= end_addr) {
len = len - (KERNEL_END - start_addr);
start_addr = ROUND_CEIL(KERNEL_END, PAGE_SIZE);
}
end_addr = ROUND_CEIL(start_addr + len, PAGE_SIZE);
init_page_table(frame_zones_page_table, 0); init_page_table(frame_zones_page_table, 0);
page_directory[1022] = page_directory[1022] =
((uint32_t)frame_zones_page_table - HEAP_END) | 0x03; ((uint32_t)frame_zones_page_table - HEAP_END) | 0x03;
frame_zones_page_table[index] = frame_zones_page_table[index] =
((uint32_t)zone & PAGE_MASK) | INIT_FLAGS; ((uint32_t)start_addr & PAGE_MASK) | INIT_FLAGS;
struct frame_zone *current = struct frame_zone *current =
(struct frame_zone *)GET_PAGE_ADDR(1022, index++); (struct frame_zone *)GET_PAGE_ADDR(1022, index++);
@ -91,7 +111,7 @@ static void add_frame_node(multiboot_memory_map_t *mmmt)
cause we are using non decimal number cause we are using non decimal number
nb_frame = ((size * 8) / (PAGE_SIZE * 8 + 1)) nb_frame = ((size * 8) / (PAGE_SIZE * 8 + 1))
*/ */
const uint32_t nb_frame = ((mmmt->len * 8) / (PAGE_SIZE * 8 + 1)); const uint32_t nb_frame = ((len * 8) / (PAGE_SIZE * 8 + 1));
current->first_free_frame = 0; current->first_free_frame = 0;
current->next = NULL; current->next = NULL;
@ -103,8 +123,9 @@ static void add_frame_node(multiboot_memory_map_t *mmmt)
uint32_t i = 1; uint32_t i = 1;
for (; i < CEIL(nb_frame, PAGE_SIZE); i++) for (; i < CEIL(nb_frame, PAGE_SIZE); i++)
frame_zones_page_table[index + i] = frame_zones_page_table[index + i] =
((uint32_t)zone + i * PAGE_SIZE & PAGE_MASK) | INIT_FLAGS; ((uint32_t)start_addr + i * PAGE_SIZE & PAGE_MASK) |
current->addr = zone + i * PAGE_SIZE; INIT_FLAGS;
current->addr = (void *)start_addr + i * PAGE_SIZE;
index += i - 1; index += i - 1;
lst_add_back(&head, current); lst_add_back(&head, current);
} }